Классификация узлов с использованием графовой сверточной сети

В этом примере показано, как классифицировать узлы в графике с помощью Graph Convolutional Network (GCN).

Задача классификации узлов является такой, где алгоритм, в этом примере, GCN [1], должен предсказать метки немаркированных узлов в графике. В этом примере графиков представлен молекулой. Атомы в молекуле представляют узлы в графике, а химические связи между атомами представляют ребра в графике. Метки узлов являются типами атомов, например, Carbon. Таким образом, входы GCN являются молекулами, и выходы являются предсказаниями типа атома каждого немеченого атома в молекуле.

Чтобы назначить категориальную метку каждому узлу графика, GCN моделирует функцию f(X,A) на графике G=(V,E), где V обозначает набор узлов и E обозначает набор ребер, таких чтоf(X,A) принимает за вход:

  • X: A функции матрица размерности N×C, где N=|V| - число узлов в G и C количество входа каналов/функций на узел.

  • A: Матрица смежности размерности N×N представление E и описание структуры G.

и возвращает выход:

  • Z: Матрица встраивания или функций размерности N×F, где F - количество выхода функций на узел. Другими словами, Z является предсказаниями сети и F количество классов.

Модель f(X,A) основан на спектральной свертке графика, с весами/фильтром параметров разделенными по всем местоположениям в G. Модель может быть представлена как слоистая модель распространения, такая что выход слоя l+1 выражается как

Zl+1=σ(Dˆ-1/2AˆDˆ-1/2ZlWl),

где

  • σ является функцией активации.

  • Zl - матрица активации слоя l, с Z1=X.

  • Wl - весовая матрица слоя l.

  • Aˆ=A+IN - матрица смежности графика G с добавленными самосоединениями. IN - матрица тождеств.

  • Dˆ - матрица степеней Aˆ.

Выражение Dˆ-1/2AˆDˆ-1/2 может быть названа нормированной матрицей смежности графика.

Модель GCN в этом примере является вариантом стандартной модели GCN, описанной выше. В варианте используются остаточные соединения между слоями [1]. Остаточные соединения позволяют модели переносить информацию с входа предыдущего слоя. Поэтому выход слоя l+1модели GCN в этом примере является

Zl+1=σ(Dˆ-1/2AˆDˆ-1/2ZlWl)+Zl,

Для получения дополнительной информации о модели GCN см. раздел [1].

Этот пример использует набор данных QM7 [2] [3], который представляет собой набор молекулярных данных, состоящий из 7165 молекул, состоящих из до 23 атомов. То есть молекула с наибольшим количеством атомов имеет 23 атома. В целом набор данных состоит из 5 уникальных атомов: углерода, водорода, азота, кислорода и серы.

Загрузка и загрузка QM7 данных

Загрузите QM7 набор данных со следующего URL-адреса:

dataURL = 'http://quantum-machine.org/data/qm7.mat';
outputFolder = fullfile(tempdir,'qm7Data');
dataFile = fullfile(outputFolder,'qm7.mat');

if ~exist(dataFile, 'file')
    mkdir(outputFolder);
    fprintf('Downloading file ''%s'' ...\n', dataFile);
    websave(dataFile, dataURL);
end

Загрузка QM7 данных.

data = load(dataFile)
data = struct with fields:
    X: [7165×23×23 single]
    R: [7165×23×3 single]
    Z: [7165×23 single]
    T: [1×7165 single]
    P: [5×1433 int64]

Данные состоят из пяти различных массивов. Этот пример использует массивы в полях X и Z от struct data. Массив в X представляет представление Кулоновской матрицы [3] каждой молекулы, насчитывающее 7165 молекул, и массива в Z представляет атомарный заряд/число каждого атома в молекулах. Матрицы смежности графиков, представляющие молекулы, и матрицы функций графиков извлекаются из матриц Кулона. Категориальный массив меток извлекается из массива в Z.

Обратите внимание, что данные для любой молекулы, которая не имеет до 23 атомов, содержат заполненные нули. Для примера данные, представляющие атомарные числа атомов в молекуле с индексом 1, являются

data.Z(1,:)
ans = 1×23 single row vector

     6     1     1     1     1     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0

Это показывает, что эта молекула состоит из пяти атомов; один атом с атомарным числом 6 и четыре атома с атомным числом 1, и данные дополнены 18 нулями.

Извлечение и предварительная обработка данных графика

Чтобы извлечь данные графика, получите матрицы Кулона и атомарные числа. Транспозиция данных, представляющих матрицы Кулона, и изменение типа данных на double. Отсортируйте данные, представляющие атомарные заряды, так, чтобы они совпадали с данными, представляющими матрицы Кулона.

coulombData = double(permute(data.X, [2 3 1]));
atomicNumber = sort(data.Z,2,'descend'); 

Переформатируйте представление матрицы Кулона молекул к двоичным матрицам смежности с помощью coloumb2Adjacency функция, присоединенная к этому примеру как вспомогательный файл.

adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);

Обратите внимание, что coloumb2Adjacency функция не удаляет заполненные нули из данных. Они остаются намеренно, чтобы облегчить разделение данных на отдельные молекулы для обучения, валидации и вывода. Поэтому, игнорируя заполненные нули, матрица смежности графика, представляющего молекулу с индексом 1, которая состоит из 5 атомов, является

adjacencyData(1:5,1:5,1)
ans = 5×5

     0     1     1     1     1
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0

Перед предварительной обработкой данных используйте функцию splitData, представленную в конце примера, чтобы случайным образом выбрать и разделить данные на обучающие, валидационные и тестовые данные. Функция использует отношение 80:10:10, чтобы разделить данные.

The adjacencyDataSplit выход splitData функция является adjacencyData входные данные разделены на три различных массива. Аналогично, coulombDataSplit и atomicNumberSplit выходы coulombData и atomicNumber входные данные разделены на три различных массива, соответственно.

[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);

Используйте функцию preprocessData, предоставленную в конце примера, чтобы обработать adjacencyDataSplit, coulombDataSplit, и atomicNumberSplit и возвращает матрицу смежности adjacency, матрица признаков features, и категориальный массив labels.

The preprocessData функция создает разреженную блок-диагональную матрицу матриц смежности различных образцов графика, так что каждый блок в матрице соответствует матрице смежности одного образца графика. Эта предварительная обработка требуется, потому что GCN принимает одну матрицу смежности как вход, в то время как этот пример касается нескольких образцов графика. Функция принимает ненулевые диагональные элементы матриц Кулона и присваивает их как функции. Поэтому количество входа функций на узел в примере составляет 1.

[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);

Просмотрите матрицы смежности обучения, валидации и тестовых данных.

adjacency
adjacency=1×3 cell array
    {88722×88722 double}    {10942×10942 double}    {10986×10986 double}

Это показывает, что в обучающих данных 88722 узла, 10942 узла в данных валидации и 10986 узлов в тестовых данных.

Нормализуйте массив признаков с помощью функции normalizeFeatures, представленной в конце примера.

features = normalizeFeatures(features);

Получите обучающие и валидационные данные.

featureTrain = features{1};
adjacencyTrain = adjacency{1};
targetTrain = labels{1};

featureValidation = features{2};
adjacencyValidation = adjacency{2};
targetValidation = labels{2};

Визуализация статистики данных и данных

Выборка и определение индексов молекул для визуализации.

Для каждого заданного индекса

  • Удалите заполненные нули из данных, представляющих необработанные атомарные числа atomicNumber и необработанная матрица смежности adjacencyData отобранной молекулы. Необработанные данные используются здесь для легкой выборки.

  • Преобразуйте матрицу смежности в график с помощью graph функция.

  • Преобразуйте атомарные числа в символы.

  • Постройте график, используя атомарные символы в качестве меток узлов.

idx = [1 5 300 1159];
for j = 1:numel(idx)
    % Remove padded zeros from the data
    atomicNum = nonzeros(atomicNumber(idx(j),:));
    numOfNodes = numel(atomicNum);
    adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j));
    
    % Convert adjacency matrix to graph
    compound = graph(adj);
    
    % Convert atomic numbers to symbols
    symbols = cell(numOfNodes, 1);
    for i = 1:numOfNodes
        if atomicNum(i) == 1
            symbols{i} = 'H';
        elseif atomicNum(i) == 6
            symbols{i} = 'C';
        elseif atomicNum(i) == 7
            symbols{i} = 'N';
        elseif atomicNum(i) == 8
            symbols{i} = 'O';
        else
            symbols{i} = 'S';
        end
    end
    
    % Plot graph
    subplot(2,2,j)
    plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ...
    'Layout', 'force')
    title("Molecule " + idx(j))
end

Получите все метки и классы.

labelsAll = cat(1,labels{:});
classes = categories(labelsAll)
classes = 5×1 cell
    {'Hydrogen'}
    {'Carbon'  }
    {'Nitrogen'}
    {'Oxygen'  }
    {'Sulphur' }

Визуализируйте частоту каждой категории меток с помощью гистограммы.

figure
histogram(labelsAll)
xlabel('Category')
ylabel('Frequency')
title('Label Counts')

Задайте функцию модели

Создайте модель функции, представленную в конце примера, которая принимает данные о функции dlX, матрица смежности Aи параметры модели parameters как входные и возвращающие предсказания для метки.

Инициализация параметров модели

Установите количество входа функций на узел. Это длина столбца матрицы функций.

numInputFeatures = size(featureTrain,2)
numInputFeatures = 1

Установите количество карт функций для скрытых слоев.

numHiddenFeatureMaps = 32;

Установите количество выхода функций как количество категорий.

numOutputFeatures = numel(classes)
numOutputFeatures = 5

Создайте struct parameters содержащие веса модели. Инициализируйте веса, используя initializeGlorot функция, присоединенная к этому примеру как вспомогательный файл.

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.W1 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.W2 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numOutputFeatures];
numOut = numOutputFeatures;
numIn = numHiddenFeatureMaps;
parameters.W3 = initializeGlorot(sz,numOut,numIn,'double');
parameters
parameters = struct with fields:
    W1: [1×32 dlarray]
    W2: [32×32 dlarray]
    W3: [32×5 dlarray]

Задайте функцию градиентов модели

Создайте модель функции Gradients, представленную в конце примера, которая принимает данные о элементе dlX, матрица смежности adjacencyTrain, одноразовые закодированные целевые объекты T меток и параметров модели parameters как вход и возвращает градиенты потерь относительно параметров, соответствующих потерь и предсказаний.

Настройка опций обучения

Обучите на 1500 эпох и установите скорость обучения для решателя Адама равной 0,01.

numEpochs = 1500;
learnRate = 0.01;

Проверяйте сеть через каждые 300 эпох.

validationFrequency = 300;

Визуализируйте процесс обучения на графике.

plots = "training-progress";

Чтобы обучиться на графическом процессоре, если он доступен, задайте окружение выполнения "auto". Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox) (Parallel Computing Toolbox).

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. В обучении используется полный пакетный градиентный спуск.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    
    % Accuracy.
    subplot(2,1,1)
    lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]);
    lineAccuracyValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Epoch")
    ylabel("Accuracy")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Epoch")
    ylabel("Loss")
    grid on
end

Инициализируйте параметры для Adam.

trailingAvg = [];
trailingAvgSq = [];

Преобразуйте данные функций обучения и валидации в dlarray.

dlX = dlarray(featureTrain);
dlXValidation = dlarray(featureValidation);

Для обучения графический процессор преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Закодируйте данные о метках обучения и валидации с помощью onehotencode.

T = onehotencode(targetTrain, 2, 'ClassNames', classes);
TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);

Обучите модель.

Для каждой эпохи

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функция.

  • Обновляйте параметры сети с помощью adamupdate.

  • Вычислите счет точности обучения с помощью функции точности, предоставленной в конце примера. Функция принимает сетевые предсказания, цель, содержащая метки и категории classes как вводит и возвращает счет точности.

  • При необходимости проверьте сеть, сделав предсказания с помощью model функция и вычисление потерь валидации и счета точности валидации с помощью crossentropy и accuracyфункциональные .

  • Обновите график обучения.

start = tic;
% Loop over epochs.
for epoch = 1:numEpochs
    
    % Evaluate the model gradients and loss using dlfeval and the
    % modelGradients function.
    [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters);
    
    % Update the network parameters using the Adam optimizer.
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);
    
    % Display the training progress.
    if plots == "training-progress"
        subplot(2,1,1)
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        % Loss.
        addpoints(lineLossTrain,epoch,double(gather(extractdata(loss))))

        % Accuracy score.
        score = accuracy(dlYPred, targetTrain, classes);
        addpoints(lineAccuracyTrain,epoch,double(gather(score)))

        drawnow

        % Display validation metrics.
        if epoch == 1 || mod(epoch,validationFrequency) == 0
            % Loss.
            dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters);
            lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC');
            addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation))))

            % Accuracy score.
            scoreValidation = accuracy(dlYPredValidation, targetValidation, classes);
            addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation)))

            drawnow
        end
    end
end

Экспериментальная модель

Протестируйте модель с помощью тестовых данных.

featureTest = features{3};
adjacencyTest = adjacency{3};
targetTest = labels{3};

Преобразуйте данные тестовых функций в dlarray.

dlXTest = dlarray(featureTest);

Делайте предсказания по данным.

dlYPredTest = model(dlXTest, adjacencyTest, parameters);

Вычислите счет точности с помощью accuracy функция. The accuracy функция также возвращает декодированные сетевые предсказания predTest как метки классов. Сетевые предсказания декодируются с помощью onehotdecode.

[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);

Просмотрите счет точности.

scoreTest
scoreTest = 0.9053

Визуализация предсказаний

Чтобы визуализировать счет точности для каждой категории, вычислите счета точности по классам и визуализируйте их с помощью гистограммы.

numOfSamples = numel(targetTest);
classTarget = zeros(numOfSamples, numOutputFeatures);
classPred = zeros(numOfSamples, numOutputFeatures);
for i = 1:numOutputFeatures
    classTarget(:,i) = targetTest==categorical(classes(i));
    classPred(:,i) = predTest==categorical(classes(i));
end

% Compute class-wise accuracy score
classAccuracy = sum(classPred == classTarget)./numOfSamples;

% Visualize class-wise accuracy score
figure
[~,idx] = sort(classAccuracy,'descend');
histogram('Categories',classes(idx), ...
    'BinCounts',classAccuracy(idx), ...
    'Barwidth',0.8)
xlabel("Category")
ylabel("Accuracy")
title("Class Accuracy Score")

Классовая точность счетов показа, как модель делает правильные предсказания, используя как истинные положительные, так и истинные отрицательные значения. Истинный позитив является результатом, где модель правильно предсказывает класс как присутствующий в наблюдении. Истинный минус является результатом, где модель правильно предсказывает класс как отсутствующий в наблюдении.

Чтобы визуализировать, как модель делает неправильные предсказания и оценивает модель на основе классовой точности и классового отзыва, вычислите матрицу неточностей, используя confusionmat и визуализируйте результаты с помощью confusionchart.

Классовая точность - это отношение истинных срабатываний к общему количеству положительных предсказаний для класса. Общие положительные предсказания включают истинные срабатывания и ложные срабатывания. Ложный позитив является результатом, где модель неправильно предсказывает класс как присутствующий в наблюдении.

Классово-мудрый отзыв, также известный как истинные положительные скорости, является отношением истинных положительных результатов к общим положительным наблюдениям для класса. Общее положительное наблюдение включает истинные срабатывания и ложные срабатывания. Ложный минус является результатом, где модель неправильно предсказывает класс как отсутствующий в наблюдении.

[confusionMatrix, order] = confusionmat(targetTest, predTest);
figure
cm = confusionchart(confusionMatrix, classes, ...
    'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized', ...
    'Title', 'GCN QM7 Confusion Chart');

Классовая точность - эти счета в первой строке ' column summary ' графика, а классовая ретрансляция - эти счета в первом столбце ' row summary ' графика.

Функция разделения данных

The splitData функция принимает adjacencyData, coloumbData, и atomicNumber данные и случайным образом разделяют их на обучающие, валидационные и тестовые данные в соотношении 80:10:10. Функция возвращает соответствующие данные разделения adjacencyDataSplit, coulombDataSplit, atomicNumberSplit как массивы ячеек.

function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber)

adjacencyDataSplit = cell(1,3);
coulombDataSplit = cell(1,3);
atomicNumberSplit = cell(1,3);

numMolecules = size(adjacencyData, 3);

% Set initial random state for example reproducibility.
rng(0);

% Get training data
idx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules));
adjacencyDataSplit{1} = adjacencyData(:,:,idx);
coulombDataSplit{1} = coulombData(:,:,idx);
atomicNumberSplit{1} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get validation data
idx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules));
adjacencyDataSplit{2} = adjacencyData(:,:,idx);
coulombDataSplit{2} = coulombData(:,:,idx);
atomicNumberSplit{2} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get test data
adjacencyDataSplit{3} = adjacencyData;
coulombDataSplit{3} = coulombData;
atomicNumberSplit{3} = atomicNumber;

end

Функция предварительной обработки данных

The preprocessData функция предварительно обрабатывает входные данные следующим образом:

Для каждого графа/молекулы

  • Удалите заполненные нули из atomicNumber.

  • Сцепите атомарные числовые данные с атомарными числовыми данными других образцов графика. Необходимо конкатенировать данные, поскольку пример касается нескольких образцов графика.

  • Удалите заполненные нули из adjacencyData.

  • Создайте разреженную блочно-диагональную матрицу матриц смежности различных образцов графика. Каждый блок в матрице соответствует матрице смежности одного образца графика. Этот шаг также необходим, потому что в примере есть несколько образцы графика.

  • Извлечение массива функций из coulombData. Массив функции является ненулевыми диагональными элементами матрицы Кулона в coulombData.

  • Сцепите массив объектов с массивами функций других образцов графика.

Затем функция преобразует атомарные числовые данные в категориальные массивы.

function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber)

adjacency = sparse([]);
features = [];
labels = [];
for i = 1:size(adjacencyData, 3)
    % Remove padded zeros from atomicNumber
    tmpLabels = nonzeros(atomicNumber(i,:));
    labels = [labels; tmpLabels];
    
    % Get the indices of the un-padded data
    validIdx = 1:numel(tmpLabels);
    
    % Use the indices for un-padded data to remove padded zeros
    % from the adjacency data
    tmpAdjacency = adjacencyData(validIdx, validIdx, i);
    
    % Build the adjacency matrix into a block diagonal matrix
    adjacency = blkdiag(adjacency, tmpAdjacency);
    
    % Remove padded zeros from coulombData and extract the
    % feature array
    tmpFeatures = diag(coulombData(validIdx, validIdx, i));
    features = [features; tmpFeatures];
end

% Convert labels to categorical array
atomicNumbers = unique(labels);
atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"];
labels = categorical(labels, atomicNumbers, atomNames);

end

Функция нормализации функций

The normalizeFeatures функция стандартизирует данные о входе обучении, валидации и тестовых функциях features использование среднего значения и отклонения обучающих данных.

function features = normalizeFeatures(features)

% Get the mean and variance from the training data
meanFeatures = mean(features{1});
varFeatures = var(features{1}, 1);

% Standardize training, validation and test data
for i = 1:3
    features{i} = (features{i} - meanFeatures)./sqrt(varFeatures);
end

end

Моделируйте функцию

The model функция принимает матрицу признаков dlX, матрица смежности Aи параметры модели parameters и возвращает сетевые предсказания. На этапе предварительной обработки model функция вычисляет нормированную матрицу смежности, описанную ранее, используя normalizeAdjacency функция, представленная ниже.

function dlY = model(dlX, A, parameters)

% Normalize adjacency matrix
L = normalizeAdjacency(A);

Z1 = dlX;

Z2 = L * Z1 * parameters.W1;
Z2 = relu(Z2) + Z1;

Z3 = L * Z2 * parameters.W2;
Z3 = relu(Z3) + Z2;

Z4 = L * Z3 * parameters.W3;
dlY = softmax(Z4, 'DataFormat', 'BC');

end

Нормализуйте функцию Ajacency

The normalizeAdjacency функция вычисляет и возвращает нормированную матрицу смежности normAdjacency матрицы входа смежности adjacency.

function normAdjacency = normalizeAdjacency(adjacency)

% Add self connections to adjacency matrix
adjacency = adjacency + speye(size(adjacency));

% Compute degree of nodes
degree = sum(adjacency, 2);

% Compute inverse square root of degree
degreeInvSqrt = sparse(sqrt(1./degree));

% Normalize adjacency matrix
normAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt);

end

Функция градиентов модели

The modelGradients функция принимает матрицу признаков dlX, матрица смежности adjacencyTrain, одноразовые закодированные целевые данные Tи параметры модели parameters, и возвращает градиенты потерь относительно параметров модели, соответствующих потерь и предсказаний.

function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters)

dlYPred = model(dlX, adjacencyTrain, parameters);

loss = crossentropy(dlYPred, T, 'DataFormat', 'BC');

gradients = dlgradient(loss, parameters);

end

Функция точности

Функция точности декодирует сетевые предсказания YPredи вычисляет точность, используя декодированные предсказания и целевые данные target. Функция возвращает вычисленный счет точности и декодированные предсказания prediction.

function [score, prediction] = accuracy(YPred, target, classes)

% Decode probability vectors into class labels
prediction = onehotdecode(YPred, classes, 2);
score = sum(prediction == target)/numel(target);

end

Ссылки

  1. Т. Н. Кипф и М. Веллинг. Полууправляемая классификация с графовыми сверточными сетями. В ICLR, 2016.

  2. Л. К. Блюм, Дж. -Л. Реймонд, 970 млн. Druglike Small Molecules for Virtual Screening in the Chemical Universe Database GDB-13, J. Am. Chem. Soc., 131:8732, 2009.

  3. М. Рупп, А. Ткатченко, К.-Р. Мюллер О. А. фон Лилиенфельд: быстрое и точное моделирование энергий молекулярной атомизации с помощью машинного обучения, Букв физического обзора, 108 (5): 058301, 2012.

См. также

| | |

Похожие темы