Классификация узлов Используя график сверточная сеть

В этом примере показано, как классифицировать узлы на график с помощью Графика сверточной сети (GCN).

Задачей классификации узлов является та, где алгоритм, в этом примере, GCN [1], должен предсказать метки непомеченных узлов в графике. В этом примере график представлен молекулой. Атомы в молекуле представляют узлы в графике, и химические связи между атомами представляют ребра в графике. Метки узла являются типами атома, например, Углерода. По сути, вход к GCN молекулы, и выходные параметры являются предсказаниями типа атома каждого непомеченного атома в молекуле.

Присваивать категориальную метку каждой вершине графика, моделей GCN функция f(X,A) на графике G=(V,E), где V обозначает набор узлов и E обозначает набор ребер, таких чтоf(X,A) берет в качестве входа:

  • X: Матрица функции размерности N×C, где N=|V| количество узлов в G и C количество входных каналов/функций на узел.

  • A: Матрица смежности размерности N×N представление E и описывая структуру G.

и возвращает выходной параметр:

  • Z: Встраивание или матрица функции размерности N×F, где F количество выходных функций на узел. Другими словами, Z предсказания сети и F количество классов.

Модель f(X,A) основан на спектральной свертке графика, параметрами весов/фильтра, совместно использованными по всем местоположениям в G. Модель может быть представлена как мудрая слоем модель распространения, такая что выход слоя l+1 описывается как

Zl+1=σ(Dˆ-1/2AˆDˆ-1/2ZlWl),

где

  • σ функция активации.

  • Zl матрица активации слоя l, с Z1=X.

  • Wl матрица веса слоя l.

  • Aˆ=A+IN матрица смежности графика G с добавленными самосвязями. IN единичная матрица.

  • Dˆ матрица степени Aˆ.

Выражение Dˆ-1/2AˆDˆ-1/2 может упоминаться как нормированная матрица смежности графика.

Модель GCN в этом примере является вариантом стандартной модели GCN, описанной выше. Вариант использует остаточные связи между слоями [1]. Остаточные связи позволяют модели перенести информацию от входа предыдущего слоя. Поэтому выход слоя l+1из модели GCN в этом примере

Zl+1=σ(Dˆ-1/2AˆDˆ-1/2ZlWl)+Zl,

См. [1] для получения дополнительной информации о модели GCN.

Этот пример использует набор данных QM7 [2] [3], который является молекулярным набором данных, состоящим из 7 165 молекул, состоявших максимум из 23 атомов. Таким образом, молекула с самым большим количеством атомов имеет 23 атома. В целом, набор данных состоит из 5 уникальных атомов: Углерод, Водород, Азот, Кислород и Сера.

Загрузите и Загрузка данные QM7

Загрузите набор данных QM7 со следующего URL:

dataURL = 'http://quantum-machine.org/data/qm7.mat';
outputFolder = fullfile(tempdir,'qm7Data');
dataFile = fullfile(outputFolder,'qm7.mat');

if ~exist(dataFile, 'file')
    mkdir(outputFolder);
    fprintf('Downloading file ''%s'' ...\n', dataFile);
    websave(dataFile, dataURL);
end

Загрузите данные QM7.

data = load(dataFile)
data = struct with fields:
    X: [7165×23×23 single]
    R: [7165×23×3 single]
    Z: [7165×23 single]
    T: [1×7165 single]
    P: [5×1433 int64]

Данные состоят из пяти различных массивов. Этот пример использует массивы в полях X и Z из struct data. Массив в X представляет кулоново матричное [3] представление каждой молекулы, в общей сложности 7 165 молекул и массив в Z представляет атомный заряд / количество каждого атома в молекулах. Матрицы смежности графиков, представляющих молекулы и матрицы функции графиков, извлечены из кулоновых матриц. Категориальный массив меток извлечен из массива в Z.

Обратите внимание на то, что данные, для любой молекулы, которая не имеет до 23 атомов, содержат дополненные нули. Например, данные, представляющие атомные числа атомов в молекуле в индексе 1,

data.Z(1,:)
ans = 1×23 single row vector

     6     1     1     1     1     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0

Это показывает, что эта молекула состоит из пяти атомов; один атом с атомным числом 6 и четыре атома с атомным числом 1, и данные дополнен 18 нулями.

Извлеките и предварительно обработайте данные о графике

Чтобы извлечь данные о графике, получите кулоновы матрицы и атомные числа. Переставьте данные, представляющие кулоновы матрицы, и измените тип данных, чтобы удвоиться. Сортировка данных, представляющих атомные заряды так, чтобы это совпадало с данными, представляющими кулоновы матрицы.

coulombData = double(permute(data.X, [2 3 1]));
atomicNumber = sort(data.Z,2,'descend'); 

Переформатируйте кулоново матричное представление молекул к бинарным матрицам смежности с помощью coloumb2Adjacency функция присоединяется к этому примеру как вспомогательный файл.

adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);

Обратите внимание на то, что coloumb2Adjacency функция не удаляет дополненные нули из данных. Их оставляют намеренно облегчить разделять данные в отдельные молекулы для обучения, валидации и вывода. Поэтому игнорируя заполненные нули, матрица смежности графика, представляющего молекулу в индексе 1, который состоит из 5 атомов,

adjacencyData(1:5,1:5,1)
ans = 5×5

     0     1     1     1     1
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0

Прежде, чем предварительно обработать данные, используйте splitData функция, обеспеченная в конце примера, чтобы случайным образом выбрать и разделить данные в обучение, валидацию и тестовые данные. Функция использует отношение 80:10:10, чтобы разделить данные.

adjacencyDataSplit выход splitData функцией является adjacencyData разделение входных данных в три различных массива. Аналогично, coulombDataSplit и atomicNumberSplit выходными параметрами является coulombData и atomicNumber разделение входных данных в три различных массива соответственно.

[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);

Используйте preprocessData функция, обеспеченная в конце примера, чтобы обработать adjacencyDataSplit, coulombDataSplit, and atomicNumberSplit и возвратите матрицу смежности adjacency, матрица функции features, и категориальный массив labels.

preprocessData функционируйте создает разреженную блочно диагональную матрицу матриц смежности различных экземпляров графика, таких, что, каждый блок в матрице соответствует матрице смежности одного экземпляра графика. Эта предварительная обработка требуется, потому что GCN принимает одну матрицу смежности, как введено, тогда как этот пример имеет дело с несколькими экземплярами графика. Функция берет ненулевые диагональные элементы кулоновых матриц и присваивает их как функции. Поэтому количество входных функций на узел в примере равняется 1.

[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);

Просмотрите матрицы смежности обучения, валидации и тестовых данных.

adjacency
adjacency=1×3 cell array
    {88722×88722 double}    {10942×10942 double}    {10986×10986 double}

Это показывает, что существует 88 722 узла в обучающих данных, 10 942 узла в данных о валидации и 10 986 узлов в тестовых данных.

Нормируйте массив функции с помощью normalizeFeatures функция обеспечивается в конце примера.

features = normalizeFeatures(features);

Получите обучение и данные о валидации.

featureTrain = features{1};
adjacencyTrain = adjacency{1};
targetTrain = labels{1};

featureValidation = features{2};
adjacencyValidation = adjacency{2};
targetValidation = labels{2};

Визуализируйте статистику данных и данных

Выборка и задает индексы молекул, чтобы визуализировать.

Для каждого заданного индекса

  • Удалите дополненные нули из данных, представляющих необработанные атомные числа atomicNumber и необработанная матрица смежности adjacencyData из произведенной молекулы. Необработанные данные используются здесь для легкой выборки.

  • Преобразуйте матрицу смежности в график с помощью graph функция.

  • Преобразуйте атомные числа в символы.

  • Постройте график с помощью атомарных символов в качестве меток узла.

idx = [1 5 300 1159];
for j = 1:numel(idx)
    % Remove padded zeros from the data
    atomicNum = nonzeros(atomicNumber(idx(j),:));
    numOfNodes = numel(atomicNum);
    adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j));
    
    % Convert adjacency matrix to graph
    compound = graph(adj);
    
    % Convert atomic numbers to symbols
    symbols = cell(numOfNodes, 1);
    for i = 1:numOfNodes
        if atomicNum(i) == 1
            symbols{i} = 'H';
        elseif atomicNum(i) == 6
            symbols{i} = 'C';
        elseif atomicNum(i) == 7
            symbols{i} = 'N';
        elseif atomicNum(i) == 8
            symbols{i} = 'O';
        else
            symbols{i} = 'S';
        end
    end
    
    % Plot graph
    subplot(2,2,j)
    plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ...
    'Layout', 'force')
    title("Molecule " + idx(j))
end

Получите все метки и классы.

labelsAll = cat(1,labels{:});
classes = categories(labelsAll)
classes = 5×1 cell
    {'Hydrogen'}
    {'Carbon'  }
    {'Nitrogen'}
    {'Oxygen'  }
    {'Sulphur' }

Визуализируйте частоту каждой категории меток с помощью гистограммы.

figure
histogram(labelsAll)
xlabel('Category')
ylabel('Frequency')
title('Label Counts')

Функция модели Define

Создайте модель функции, при условии в конце примера, который берет данные о функции dlX, матрица смежности A, и параметры модели parameters как введено и возвращает предсказания для метки.

Инициализируйте параметры модели

Определите номер входных функций на узел. Это - длина столбца матрицы функции.

numInputFeatures = size(featureTrain,2)
numInputFeatures = 1

Определите номер карт функции для скрытых слоев.

numHiddenFeatureMaps = 32;

Определите номер выходных функций как количество категорий.

numOutputFeatures = numel(classes)
numOutputFeatures = 5

Создайте struct parameters содержа веса модели. Инициализируйте веса с помощью initializeGlorot функция присоединяется к этому примеру как вспомогательный файл.

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.W1 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.W2 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numOutputFeatures];
numOut = numOutputFeatures;
numIn = numHiddenFeatureMaps;
parameters.W3 = initializeGlorot(sz,numOut,numIn,'double');
parameters
parameters = struct with fields:
    W1: [1×32 dlarray]
    W2: [32×32 dlarray]
    W3: [32×5 dlarray]

Функция градиентов модели Define

Создайте функцию modelGradients, при условии в конце примера, который берет данные о функции dlX, матрица смежности adjacencyTrain, одногорячие закодированные цели T из меток и параметров модели parameters как введено и возвращает градиенты потери относительно параметров, соответствующей потери и сетевых предсказаний.

Задайте опции обучения

Обучайтесь в течение 1 500 эпох и установите изучить уровень для решателя Адама к 0,01.

numEpochs = 1500;
learnRate = 0.01;

Проверьте сеть после каждых 300 эпох.

validationFrequency = 300;

Визуализируйте процесс обучения в графике.

plots = "training-progress";

Чтобы обучаться на графическом процессоре, если вы доступны, задайте среду выполнения "auto". Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) (Parallel Computing Toolbox).

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Обучение использует градиентный спуск полного пакета.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    
    % Accuracy.
    subplot(2,1,1)
    lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]);
    lineAccuracyValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Epoch")
    ylabel("Accuracy")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Epoch")
    ylabel("Loss")
    grid on
end

Инициализируйте параметры для Адама.

trailingAvg = [];
trailingAvgSq = [];

Преобразуйте данные о функции обучения и валидации в dlarray.

dlX = dlarray(featureTrain);
dlXValidation = dlarray(featureValidation);

Для обучения графического процессора преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Закодируйте данные о метке обучения и валидации с помощью onehotencode.

T = onehotencode(targetTrain, 2, 'ClassNames', classes);
TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);

Обучите модель.

В течение каждой эпохи

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate.

  • Вычислите учебный счет точности с помощью функции точности, обеспеченной в конце примера. Функция берет сетевые предсказания, цель, содержащая метки и категории classes как вводит и возвращает счет точности.

  • При необходимости проверьте сеть путем создания предсказаний с помощью model функция и вычисление потери валидации и счета точности валидации с помощью crossentropy и the accuracy функция.

  • Обновите учебный график.

start = tic;
% Loop over epochs.
for epoch = 1:numEpochs
    
    % Evaluate the model gradients and loss using dlfeval and the
    % modelGradients function.
    [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters);
    
    % Update the network parameters using the Adam optimizer.
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);
    
    % Display the training progress.
    if plots == "training-progress"
        subplot(2,1,1)
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        % Loss.
        addpoints(lineLossTrain,epoch,double(gather(extractdata(loss))))

        % Accuracy score.
        score = accuracy(dlYPred, targetTrain, classes);
        addpoints(lineAccuracyTrain,epoch,double(gather(score)))

        drawnow

        % Display validation metrics.
        if epoch == 1 || mod(epoch,validationFrequency) == 0
            % Loss.
            dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters);
            lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC');
            addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation))))

            % Accuracy score.
            scoreValidation = accuracy(dlYPredValidation, targetValidation, classes);
            addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation)))

            drawnow
        end
    end
end

Тестовая модель

Протестируйте модель с помощью тестовых данных.

featureTest = features{3};
adjacencyTest = adjacency{3};
targetTest = labels{3};

Преобразуйте тестовые данные о функции в dlarray.

dlXTest = dlarray(featureTest);

Сделайте предсказания на данных.

dlYPredTest = model(dlXTest, adjacencyTest, parameters);

Вычислите счет точности с помощью accuracy функция. accuracy функционируйте также возвращает декодируемые сетевые предсказания predTest как метки класса. Сетевые предсказания декодируются с помощью onehotdecode.

[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);

Просмотрите счет точности.

scoreTest
scoreTest = 0.9053

Визуализируйте предсказания

Чтобы визуализировать счет точности к каждой категории, вычислите мудрые классом баллы точности и визуализируйте их использующий гистограмму.

numOfSamples = numel(targetTest);
classTarget = zeros(numOfSamples, numOutputFeatures);
classPred = zeros(numOfSamples, numOutputFeatures);
for i = 1:numOutputFeatures
    classTarget(:,i) = targetTest==categorical(classes(i));
    classPred(:,i) = predTest==categorical(classes(i));
end

% Compute class-wise accuracy score
classAccuracy = sum(classPred == classTarget)./numOfSamples;

% Visualize class-wise accuracy score
figure
[~,idx] = sort(classAccuracy,'descend');
histogram('Categories',classes(idx), ...
    'BinCounts',classAccuracy(idx), ...
    'Barwidth',0.8)
xlabel("Category")
ylabel("Accuracy")
title("Class Accuracy Score")

Мудрые классом баллы точности показывают, как модель делает правильные предсказания с помощью и истинных положительных сторон и истинных отрицательных сторон. Положительная истина является результатом, где модель правильно предсказывает класс как существующий в наблюдении. Истинное отрицание является результатом, где модель правильно предсказывает класс как отсутствующий в наблюдении.

Чтобы визуализировать, как модель делает неправильные предсказания и оценивает основанное на модели на мудрой классом точности и мудром классом отзыве, вычислите матрицу беспорядка использование confusionmat и визуализируйте результаты с помощью confusionchart.

Мудрая классом точность является отношением истинных положительных сторон, чтобы составить положительные предсказания для класса. Общие положительные предсказания включают истинные положительные стороны и ложные положительные стороны. Положительная ложь является результатом, где модель неправильно предсказывает класс как существующий в наблюдении.

Мудрый классом отзыв, также известный как истинные положительные уровни, является отношением истинных положительных сторон, чтобы составить положительные наблюдения для класса. Общее положительное наблюдение включает истинные положительные стороны и ложные отрицательные стороны. Ложное отрицание является результатом, где модель неправильно предсказывает класс как отсутствующий в наблюдении.

[confusionMatrix, order] = confusionmat(targetTest, predTest);
figure
cm = confusionchart(confusionMatrix, classes, ...
    'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized', ...
    'Title', 'GCN QM7 Confusion Chart');

Мудрая классом точность является баллами в первой строке 'сводных данных столбца' графика, и мудрый классом отзыв баллы в первом столбце 'сводных данных строки' графика.

Функция Разделения данных

splitData функционируйте берет adjacencyData, coloumbData, и atomicNumber данные и случайным образом разделяют их в обучение, валидацию и тестовые данные в отношении 80:10:10. Функция возвращает соответствующие данные о разделении adjacencyDataSplit, coulombDataSplit, atomicNumberSplit как массивы ячеек.

function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber)

adjacencyDataSplit = cell(1,3);
coulombDataSplit = cell(1,3);
atomicNumberSplit = cell(1,3);

numMolecules = size(adjacencyData, 3);

% Set initial random state for example reproducibility.
rng(0);

% Get training data
idx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules));
adjacencyDataSplit{1} = adjacencyData(:,:,idx);
coulombDataSplit{1} = coulombData(:,:,idx);
atomicNumberSplit{1} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get validation data
idx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules));
adjacencyDataSplit{2} = adjacencyData(:,:,idx);
coulombDataSplit{2} = coulombData(:,:,idx);
atomicNumberSplit{2} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get test data
adjacencyDataSplit{3} = adjacencyData;
coulombDataSplit{3} = coulombData;
atomicNumberSplit{3} = atomicNumber;

end

Предварительно обработайте функцию данных

preprocessData функция предварительно обрабатывает входные данные можно следующим образом:

Для каждого графика/молекулы

  • Удалите дополненные нули из atomicNumber.

  • Конкатенация данных об атомном числе с данными об атомном числе других экземпляров графика. Необходимо конкатенировать данные начиная с соглашений в качестве примера с несколькими экземплярами графика.

  • Удалите дополненные нули из adjacencyData.

  • Создайте разреженную блочно диагональную матрицу матриц смежности различных экземпляров графика. Каждый блок в матрице соответствует матрице смежности одного экземпляра графика. Этот шаг также необходим, потому что существует несколько экземпляров графика в примере.

  • Извлеките массив функции из coulombData. Массив функции является ненулевыми диагональными элементами кулоновой матрицы в coulombData.

  • Конкатенация массива функции с массивами функции других экземпляров графика.

Функция затем преобразует данные об атомном числе в категориальные массивы.

function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber)

adjacency = sparse([]);
features = [];
labels = [];
for i = 1:size(adjacencyData, 3)
    % Remove padded zeros from atomicNumber
    tmpLabels = nonzeros(atomicNumber(i,:));
    labels = [labels; tmpLabels];
    
    % Get the indices of the un-padded data
    validIdx = 1:numel(tmpLabels);
    
    % Use the indices for un-padded data to remove padded zeros
    % from the adjacency data
    tmpAdjacency = adjacencyData(validIdx, validIdx, i);
    
    % Build the adjacency matrix into a block diagonal matrix
    adjacency = blkdiag(adjacency, tmpAdjacency);
    
    % Remove padded zeros from coulombData and extract the
    % feature array
    tmpFeatures = diag(coulombData(validIdx, validIdx, i));
    features = [features; tmpFeatures];
end

% Convert labels to categorical array
atomicNumbers = unique(labels);
atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"];
labels = categorical(labels, atomicNumbers, atomNames);

end

Нормируйте функцию функций

normalizeFeatures функция стандартизирует входное обучение, валидацию и тестовые данные о функции features использование среднего значения и отклонения обучающих данных.

function features = normalizeFeatures(features)

% Get the mean and variance from the training data
meanFeatures = mean(features{1});
varFeatures = var(features{1}, 1);

% Standardize training, validation and test data
for i = 1:3
    features{i} = (features{i} - meanFeatures)./sqrt(varFeatures);
end

end

Функция модели

model функционируйте берет матрицу функции dlX, матрица смежности A, и параметры модели parameters и возвращает сетевые предсказания. На шаге предварительной обработки, model функция вычисляет, нормированная матрица смежности описала более раннее использование normalizeAdjacency функция обеспечивается.

function dlY = model(dlX, A, parameters)

% Normalize adjacency matrix
L = normalizeAdjacency(A);

Z1 = dlX;

Z2 = L * Z1 * parameters.W1;
Z2 = relu(Z2) + Z1;

Z3 = L * Z2 * parameters.W2;
Z3 = relu(Z3) + Z2;

Z4 = L * Z3 * parameters.W3;
dlY = softmax(Z4, 'DataFormat', 'BC');

end

Нормируйте функцию смежности

normalizeAdjacency функция вычисляет и возвращает нормированную матрицу смежности normAdjacency из входной матрицы смежности adjacency.

function normAdjacency = normalizeAdjacency(adjacency)

% Add self connections to adjacency matrix
adjacency = adjacency + speye(size(adjacency));

% Compute degree of nodes
degree = sum(adjacency, 2);

% Compute inverse square root of degree
degreeInvSqrt = sparse(sqrt(1./degree));

% Normalize adjacency matrix
normAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt);

end

Функция градиентов модели

modelGradients функционируйте берет матрицу функции dlX, матрица смежности adjacencyTrain, одногорячие закодированные целевые данные T, и параметры модели parameters, и возвращает градиенты потери относительно параметров модели, соответствующей потери и сетевых предсказаний.

function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters)

dlYPred = model(dlX, adjacencyTrain, parameters);

loss = crossentropy(dlYPred, T, 'DataFormat', 'BC');

gradients = dlgradient(loss, parameters);

end

Функция точности

Функция точности декодирует сетевые предсказания YPred и вычисляет точность с помощью декодируемых предсказаний и целевых данных target. Функция возвращает вычисленный счет точности и декодируемые предсказания prediction.

function [score, prediction] = accuracy(YPred, target, classes)

% Decode probability vectors into class labels
prediction = onehotdecode(YPred, classes, 2);
score = sum(prediction == target)/numel(target);

end

Ссылки

  1. Т. Н. Кипф и М. Веллинг. Полуконтролируемая классификация с графиком сверточные сети. В ICLR, 2016.

  2. Л. К. Блум, J.-L. Реймонд, 970 миллионов подобных препарату маленьких молекул для виртуального экранирования в химической базе данных вселенной GDB-13, J. Chem. Soc., 131:8732, 2009.

  3. М. Рупп, А. Ткаченко, K.-R. Мюллер, О. А. фон Лилинфельд: Быстрое и Точное Моделирование Молекулярных энергий Распыления с Машинным обучением, Physical Review Letters, 108 (5):058301, 2012.

Copyright 2021, The MathWorks, Inc.

Смотрите также

| | |

Похожие темы