exponenta event banner

Классификация узлов с использованием сверточной сети графов

В этом примере показано, как классифицировать узлы в графике с помощью сверточной сети Graph (GCN).

Задача классификации узлов является задачей, в которой алгоритм, в этом примере GCN [1], должен предсказывать метки немаркированных узлов в графе. В этом примере граф представлен молекулой. Атомы в молекуле представляют узлы в графе, а химические связи между атомами представляют рёбра в графе. Метки узлов - это типы атомов, например, Carbon. Таким образом, вход в GCN представляет собой молекулы, а выходы представляют собой предсказания типа атома каждого немеченого атома в молекуле.

Чтобы назначить категориальную метку каждому узлу графа, GCN моделирует функцию f (X, A) на графе G = (V, E), где V обозначает набор узлов, а E обозначает набор рёбер, такой thatf (X, A) принимает за вход:

  • X: Матрица признаков размерности N × C, где N = | V | - количество узлов в G и C - количество входных каналов/признаков на узел.

  • A: Матрица смежности размерности N × N, представляющая E и описывающая структуру G.

и возвращает выходные данные:

  • Z: Вложение или матрица элементов размерности N × F, где F - количество выходных элементов на узел. Другими словами, Z - предсказания сети, а F - количество классов.

Модель f (X, A) основана на свертке спектрального графа, с весами/параметрами фильтра, общими для всех местоположений G. Модель может быть представлена в виде модели распространения по слоям, так что выходной сигнал слоя l + 1 выражается как

Zl + 1 = λ (Dˆ-1/2AˆDˆ-1/2ZlWl),

где

  • λ - функция активации.

  • Z1 - матрица активации слоя 1 с Z1 = X.

  • W1 - весовая матрица слоя 1.

  • Aˆ=A+IN - матрица смежности графа G с добавленными самосоединениями. IN является единичной матрицей.

  • - матрица степеней .

Выражение Dˆ-1/2AˆDˆ-1/2 можно назвать нормализованной матрицей смежности графа.

Модель GCN в этом примере является вариантом стандартной модели GCN, описанной выше. Вариант использует остаточные соединения между слоями [1]. Остаточные соединения позволяют модели переносить информацию из входных данных предыдущего уровня. Поэтому выход уровня l + 1 модели GCN в этом примере равен

Zl + 1 = λ (Dˆ-1/2AˆDˆ-1/2ZlWl) + Zl,

Для получения дополнительной информации о модели GCN см. [1].

В этом примере используется набор данных QM7 [2] [3], который является молекулярным набором данных, состоящим из 7165 молекул, состоящих из до 23 атомов. То есть молекула с наибольшим числом атомов имеет 23 атома. В целом набор данных состоит из 5 уникальных атомов: углерода, водорода, азота, кислорода и серы.

Загрузка и загрузка данных QM7

Загрузите набор данных QM7 со следующего URL-адреса:

dataURL = 'http://quantum-machine.org/data/qm7.mat';
outputFolder = fullfile(tempdir,'qm7Data');
dataFile = fullfile(outputFolder,'qm7.mat');

if ~exist(dataFile, 'file')
    mkdir(outputFolder);
    fprintf('Downloading file ''%s'' ...\n', dataFile);
    websave(dataFile, dataURL);
end

Загрузить данные QM7.

data = load(dataFile)
data = struct with fields:
    X: [7165×23×23 single]
    R: [7165×23×3 single]
    Z: [7165×23 single]
    T: [1×7165 single]
    P: [5×1433 int64]

Данные состоят из пяти различных массивов. В этом примере используются массивы в полях X и Z структуры data. Массив в X представляет собой представление матрицы Кулона [3] каждой молекулы, в общей сложности 7165 молекул, и массив в Z представляет атомный заряд/число каждого атома в молекулах. Матрицы смежности графов, представляющих молекулы, и матрицы признаков графов извлекаются из матриц Кулона. Категориальный массив меток извлекается из массива в Z.

Следует отметить, что данные для любой молекулы, которая не имеет до 23 атомов, содержат заполненные нули. Например, данные, представляющие атомные номера атомов в молекуле с индексом 1, равны

data.Z(1,:)
ans = 1×23 single row vector

     6     1     1     1     1     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0

Это показывает, что эта молекула состоит из пяти атомов; один атом с атомным номером 6 и четыре атома с атомным номером 1, и данные заполняются 18 нулями.

Извлечение и предварительная обработка данных графика

Чтобы извлечь данные графа, получите матрицы Кулона и атомные числа. Переставьте данные, представляющие матрицы Кулона, и измените тип данных на двойной. Сортируйте данные, представляющие атомные заряды, так, чтобы они совпадали с данными, представляющими матрицы Кулона.

coulombData = double(permute(data.X, [2 3 1]));
atomicNumber = sort(data.Z,2,'descend'); 

Переформатируйте матричное представление Кулона молекул в двоичные матрицы смежности, используя coloumb2Adjacency функция, присоединенная к этому примеру в качестве вспомогательного файла.

adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);

Обратите внимание, что coloumb2Adjacency функция не удаляет заполненные нули из данных. Они оставлены намеренно, чтобы облегчить разделение данных на отдельные молекулы для обучения, проверки и вывода. Поэтому, игнорируя заполненные нули, матрица смежности графа, представляющего молекулу с индексом 1, которая состоит из 5 атомов, равна

adjacencyData(1:5,1:5,1)
ans = 5×5

     0     1     1     1     1
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0

Перед предварительной обработкой данных используйте функцию splitData, предоставленную в конце примера, для случайного выбора и разделения данных на обучающие, проверочные и тестовые данные. Функция использует отношение 80:10:10 для разделения данных.

adjacencyDataSplit выходные данные splitData функция - adjacencyData входные данные разделены на три различных массива. Аналогично, coulombDataSplit и atomicNumberSplit выходами являются coulombData и atomicNumber входные данные разбиты на три различных массива соответственно.

[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);

Используйте функцию preprocessData, предоставленную в конце примера, для обработки adjacencyDataSplit, coulombDataSplit, и atomicNumberSplit и вернуть матрицу смежности adjacency, матрица элементов featuresи категориальный массив labels.

preprocessData функция строит разреженную блок-диагональную матрицу матриц смежности различных экземпляров графа, так что каждый блок в матрице соответствует матрице смежности одного экземпляра графа. Эта предварительная обработка необходима, поскольку GCN принимает одну матрицу смежности в качестве входных данных, тогда как в этом примере рассматривается несколько экземпляров графа. Функция принимает ненулевые диагональные элементы кулоновских матриц и назначает их в качестве признаков. Поэтому количество входных элементов на узел в примере равно 1.

[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);

Просмотрите матрицы смежности обучающих, валидационных и тестовых данных.

adjacency
adjacency=1×3 cell array
    {88722×88722 double}    {10942×10942 double}    {10986×10986 double}

Это показывает, что в обучающих данных имеется 88722 узла, в валидационных данных - 10942 узла и в тестовых данных - 10986 узлов.

Нормализуйте массив элементов с помощью функции normalityFeatures, предоставленной в конце примера.

features = normalizeFeatures(features);

Получите данные обучения и проверки.

featureTrain = features{1};
adjacencyTrain = adjacency{1};
targetTrain = labels{1};

featureValidation = features{2};
adjacencyValidation = adjacency{2};
targetValidation = labels{2};

Визуализация данных и статистики данных

Отобрать и указать индексы молекул для визуализации.

Для каждого указанного индекса

  • Удаление заполненных нулей из данных, представляющих необработанные атомные номера atomicNumber и необработанная матрица смежности adjacencyData отобранной молекулы. Необработанные данные используются здесь для простой выборки.

  • Преобразование матрицы смежности в график с помощью graph функция.

  • Преобразуйте атомарные номера в символы.

  • Постройте график, используя атомарные символы в качестве меток узлов.

idx = [1 5 300 1159];
for j = 1:numel(idx)
    % Remove padded zeros from the data
    atomicNum = nonzeros(atomicNumber(idx(j),:));
    numOfNodes = numel(atomicNum);
    adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j));
    
    % Convert adjacency matrix to graph
    compound = graph(adj);
    
    % Convert atomic numbers to symbols
    symbols = cell(numOfNodes, 1);
    for i = 1:numOfNodes
        if atomicNum(i) == 1
            symbols{i} = 'H';
        elseif atomicNum(i) == 6
            symbols{i} = 'C';
        elseif atomicNum(i) == 7
            symbols{i} = 'N';
        elseif atomicNum(i) == 8
            symbols{i} = 'O';
        else
            symbols{i} = 'S';
        end
    end
    
    % Plot graph
    subplot(2,2,j)
    plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ...
    'Layout', 'force')
    title("Molecule " + idx(j))
end

Получите все этикетки и классы.

labelsAll = cat(1,labels{:});
classes = categories(labelsAll)
classes = 5×1 cell
    {'Hydrogen'}
    {'Carbon'  }
    {'Nitrogen'}
    {'Oxygen'  }
    {'Sulphur' }

Визуализация частоты каждой категории меток с помощью гистограммы.

figure
histogram(labelsAll)
xlabel('Category')
ylabel('Frequency')
title('Label Counts')

Определение функции модели

Создайте функциональную модель, предоставленную в конце примера, которая принимает данные элемента. dlX, матрица смежности Aи параметры модели parameters в качестве входных данных и возвращает прогнозы для метки.

Инициализация параметров модели

Задайте количество входных элементов на узел. Это длина столбца матрицы элементов.

numInputFeatures = size(featureTrain,2)
numInputFeatures = 1

Задайте количество сопоставлений элементов для скрытых слоев.

numHiddenFeatureMaps = 32;

Задайте количество выходных элементов в качестве количества категорий.

numOutputFeatures = numel(classes)
numOutputFeatures = 5

Создание структуры parameters содержащий веса модели. Инициализируйте веса с помощью initializeGlorot функция, присоединенная к этому примеру в качестве вспомогательного файла.

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.W1 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.W2 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numOutputFeatures];
numOut = numOutputFeatures;
numIn = numHiddenFeatureMaps;
parameters.W3 = initializeGlorot(sz,numOut,numIn,'double');
parameters
parameters = struct with fields:
    W1: [1×32 dlarray]
    W2: [32×32 dlarray]
    W3: [32×5 dlarray]

Определение функции градиентов модели

Создайте функцию ModelGradients, предоставленную в конце примера, которая принимает данные элемента dlX, матрица смежности adjacencyTrain, однонаправленные закодированные цели T меток и параметров модели parameters в качестве входных данных и возвращает градиенты потерь относительно параметров, соответствующих потерь и сетевых прогнозов.

Укажите параметры обучения

Потренироваться на 1500 эпох и установить скорость обучения для решателя Адама 0,01.

numEpochs = 1500;
learnRate = 0.01;

Проверка сети после каждых 300 периодов.

validationFrequency = 300;

Визуализация хода обучения на графике.

plots = "training-progress";

Для обучения на графическом процессоре, если он доступен, укажите среду выполнения "auto". Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox) (Parallel Computing Toolbox).

executionEnvironment = "auto";

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. В тренировке используется полнопериодный градиентный спуск.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    
    % Accuracy.
    subplot(2,1,1)
    lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]);
    lineAccuracyValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Epoch")
    ylabel("Accuracy")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Epoch")
    ylabel("Loss")
    grid on
end

Инициализация параметров для Adam.

trailingAvg = [];
trailingAvgSq = [];

Преобразование данных обучающих и проверочных функций в dlarray.

dlX = dlarray(featureTrain);
dlXValidation = dlarray(featureValidation);

Для обучения GPU преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Кодирование данных меток обучения и проверки с помощью onehotencode.

T = onehotencode(targetTrain, 2, 'ClassNames', classes);
TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);

Тренируйте модель.

Для каждой эпохи

  • Оценка градиентов и потерь модели с помощью dlfeval и modelGradients функция.

  • Обновление параметров сети с помощью adamupdate.

  • Вычислите оценку точности обучения, используя функцию точности, предоставленную в конце примера. Функция принимает сетевые прогнозы, цель, содержащую метки, и категории classes в качестве входных данных и возвращает оценку точности.

  • При необходимости проверьте сеть, сделав прогнозы с помощью model функция и вычисление потери проверки и показателя точности проверки с использованием crossentropy и accuracy функция.

  • Обновите график обучения.

start = tic;
% Loop over epochs.
for epoch = 1:numEpochs
    
    % Evaluate the model gradients and loss using dlfeval and the
    % modelGradients function.
    [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters);
    
    % Update the network parameters using the Adam optimizer.
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);
    
    % Display the training progress.
    if plots == "training-progress"
        subplot(2,1,1)
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        % Loss.
        addpoints(lineLossTrain,epoch,double(gather(extractdata(loss))))

        % Accuracy score.
        score = accuracy(dlYPred, targetTrain, classes);
        addpoints(lineAccuracyTrain,epoch,double(gather(score)))

        drawnow

        % Display validation metrics.
        if epoch == 1 || mod(epoch,validationFrequency) == 0
            % Loss.
            dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters);
            lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC');
            addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation))))

            % Accuracy score.
            scoreValidation = accuracy(dlYPredValidation, targetValidation, classes);
            addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation)))

            drawnow
        end
    end
end

Тестовая модель

Протестируйте модель, используя данные теста.

featureTest = features{3};
adjacencyTest = adjacency{3};
targetTest = labels{3};

Преобразуйте данные тестового элемента в массив dlarray.

dlXTest = dlarray(featureTest);

Делать прогнозы по данным.

dlYPredTest = model(dlXTest, adjacencyTest, parameters);

Вычислите оценку точности с помощью accuracy функция. accuracy функция также возвращает декодированные предсказания сети predTest в качестве меток класса. Предсказания сети декодируются с использованием onehotdecode.

[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);

Просмотр оценки точности.

scoreTest
scoreTest = 0.9053

Визуализация прогнозов

Чтобы визуализировать оценку точности для каждой категории, вычислите оценки точности по классу и визуализируйте их с помощью гистограммы.

numOfSamples = numel(targetTest);
classTarget = zeros(numOfSamples, numOutputFeatures);
classPred = zeros(numOfSamples, numOutputFeatures);
for i = 1:numOutputFeatures
    classTarget(:,i) = targetTest==categorical(classes(i));
    classPred(:,i) = predTest==categorical(classes(i));
end

% Compute class-wise accuracy score
classAccuracy = sum(classPred == classTarget)./numOfSamples;

% Visualize class-wise accuracy score
figure
[~,idx] = sort(classAccuracy,'descend');
histogram('Categories',classes(idx), ...
    'BinCounts',classAccuracy(idx), ...
    'Barwidth',0.8)
xlabel("Category")
ylabel("Accuracy")
title("Class Accuracy Score")

Оценки точности по классу показывают, как модель делает правильные прогнозы, используя как истинные положительные, так и истинные отрицательные. Истинным положительным результатом является результат, в котором модель правильно предсказывает класс как присутствующий в наблюдении. Истинный негатив - это результат, в котором модель правильно предсказывает, что класс отсутствует в наблюдении.

Для визуализации того, как модель делает неверные прогнозы, и оценки модели на основе точности по классу и отзыва по классу, вычислите матрицу путаницы с помощью confusionmat и визуализировать результаты с помощью confusionchart.

Точность по классу - это отношение истинных положительных значений к суммарным положительным прогнозам для класса. Общие положительные прогнозы включают в себя истинные положительные и ложные положительные. Ложноположительный результат - это результат, когда модель неверно предсказывает класс как присутствующий в наблюдении.

Классовый отзыв, также известный как истинные положительные показатели, - это отношение истинных положительных результатов к общим положительным наблюдениям для класса. Общее положительное наблюдение включает в себя истинные положительные и ложные отрицательные. Ложноотрицательный результат, когда модель неверно предсказывает класс как отсутствующий в наблюдении.

[confusionMatrix, order] = confusionmat(targetTest, predTest);
figure
cm = confusionchart(confusionMatrix, classes, ...
    'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized', ...
    'Title', 'GCN QM7 Confusion Chart');

Точность по классу - это баллы в первой строке «сводки столбцов» диаграммы, а выборка по классу - это баллы в первом столбце «сводки строк» диаграммы.

Функция разделения данных

splitData функция принимает adjacencyData, coloumbData, и atomicNumber данные и случайным образом разбивает их на обучающие, валидационные и тестовые данные в соотношении 80:10:10. Функция возвращает соответствующие данные разделения adjacencyDataSplit, coulombDataSplit, atomicNumberSplit в виде массивов ячеек.

function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber)

adjacencyDataSplit = cell(1,3);
coulombDataSplit = cell(1,3);
atomicNumberSplit = cell(1,3);

numMolecules = size(adjacencyData, 3);

% Set initial random state for example reproducibility.
rng(0);

% Get training data
idx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules));
adjacencyDataSplit{1} = adjacencyData(:,:,idx);
coulombDataSplit{1} = coulombData(:,:,idx);
atomicNumberSplit{1} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get validation data
idx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules));
adjacencyDataSplit{2} = adjacencyData(:,:,idx);
coulombDataSplit{2} = coulombData(:,:,idx);
atomicNumberSplit{2} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get test data
adjacencyDataSplit{3} = adjacencyData;
coulombDataSplit{3} = coulombData;
atomicNumberSplit{3} = atomicNumber;

end

Функция данных предварительной обработки

preprocessData функция предварительно обрабатывает входные данные следующим образом:

Для каждого графа/молекулы

  • Удалить заполненные нули из atomicNumber.

  • Объедините данные атомного номера с данными атомного номера других экземпляров графа. Необходимо объединить данные, поскольку в примере рассматривается несколько экземпляров графа.

  • Удалить заполненные нули из adjacencyData.

  • Построение разреженной блок-диагональной матрицы матриц смежности различных экземпляров графа. Каждый блок в матрице соответствует матрице смежности одного экземпляра графа. Этот шаг также необходим, поскольку в примере имеется несколько экземпляров графа.

  • Извлечение массива элементов из coulombData. Массив элементов представляет собой ненулевые диагональные элементы матрицы Кулона в coulombData.

  • Соедините массив элементов с массивами элементов других экземпляров графика.

Затем функция преобразует данные атомного номера в категориальные массивы.

function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber)

adjacency = sparse([]);
features = [];
labels = [];
for i = 1:size(adjacencyData, 3)
    % Remove padded zeros from atomicNumber
    tmpLabels = nonzeros(atomicNumber(i,:));
    labels = [labels; tmpLabels];
    
    % Get the indices of the un-padded data
    validIdx = 1:numel(tmpLabels);
    
    % Use the indices for un-padded data to remove padded zeros
    % from the adjacency data
    tmpAdjacency = adjacencyData(validIdx, validIdx, i);
    
    % Build the adjacency matrix into a block diagonal matrix
    adjacency = blkdiag(adjacency, tmpAdjacency);
    
    % Remove padded zeros from coulombData and extract the
    % feature array
    tmpFeatures = diag(coulombData(validIdx, validIdx, i));
    features = [features; tmpFeatures];
end

% Convert labels to categorical array
atomicNumbers = unique(labels);
atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"];
labels = categorical(labels, atomicNumbers, atomNames);

end

Нормализация функций

normalizeFeatures функция стандартизирует входные данные обучения, проверки и тестирования features используя среднее значение и дисперсию данных обучения.

function features = normalizeFeatures(features)

% Get the mean and variance from the training data
meanFeatures = mean(features{1});
varFeatures = var(features{1}, 1);

% Standardize training, validation and test data
for i = 1:3
    features{i} = (features{i} - meanFeatures)./sqrt(varFeatures);
end

end

Функция модели

model функция принимает матрицу признаков dlX, матрица смежности Aи параметры модели parameters и возвращает предсказания сети. На этапе предварительной обработки model функция вычисляет нормализованную матрицу смежности, описанную ранее, с помощью normalizeAdjacency функция, представленная ниже.

function dlY = model(dlX, A, parameters)

% Normalize adjacency matrix
L = normalizeAdjacency(A);

Z1 = dlX;

Z2 = L * Z1 * parameters.W1;
Z2 = relu(Z2) + Z1;

Z3 = L * Z2 * parameters.W2;
Z3 = relu(Z3) + Z2;

Z4 = L * Z3 * parameters.W3;
dlY = softmax(Z4, 'DataFormat', 'BC');

end

Нормализация функции Ajacency

normalizeAdjacency функция вычисляет и возвращает нормализованную матрицу смежности normAdjacency входной матрицы смежности adjacency.

function normAdjacency = normalizeAdjacency(adjacency)

% Add self connections to adjacency matrix
adjacency = adjacency + speye(size(adjacency));

% Compute degree of nodes
degree = sum(adjacency, 2);

% Compute inverse square root of degree
degreeInvSqrt = sparse(sqrt(1./degree));

% Normalize adjacency matrix
normAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt);

end

Функция градиентов модели

modelGradients функция принимает матрицу признаков dlX, матрица смежности adjacencyTrain, однокодированные кодированные целевые данные Tи параметры модели parameters, и возвращает градиенты потерь относительно параметров модели, соответствующих потерь и сетевых прогнозов.

function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters)

dlYPred = model(dlX, adjacencyTrain, parameters);

loss = crossentropy(dlYPred, T, 'DataFormat', 'BC');

gradients = dlgradient(loss, parameters);

end

Функция точности

Функция точности декодирует предсказания сети YPredи вычисляет точность с использованием декодированных прогнозов и целевых данных target. Функция возвращает вычисленную оценку точности и декодированные предсказания. prediction.

function [score, prediction] = accuracy(YPred, target, classes)

% Decode probability vectors into class labels
prediction = onehotdecode(YPred, classes, 2);
score = sum(prediction == target)/numel(target);

end

Ссылки

  1. Т. Н. Кипф и М. Веллинг. Полууправляемая классификация с графовыми сверточными сетями. В ICLR, 2016.

  2. Л. К. Блюм, Дж. -Л. Реймонд, 970 миллионов маленьких молекул, подобных Druglike, для виртуального скрининга в базе данных химической вселенной GDB-13, J. Am. Chem. Soc., 131:8732, 2009.

  3. М. Рупп, А. Ткатченко, К.-Р. Мюллер, О. А. фон Лилиенфельд: Быстрое и точное моделирование энергий молекулярного атомизации с машинным обучением, физические обзорные письма, 108 (5): 058301, 2012.

См. также

| | |

Связанные темы