В этом примере показано, как классифицировать узлы в графике с помощью Graph Convolutional Network (GCN).
Задача классификации узлов является такой, где алгоритм, в этом примере, GCN [1], должен предсказать метки немаркированных узлов в графике. В этом примере графиков представлен молекулой. Атомы в молекуле представляют узлы в графике, а химические связи между атомами представляют ребра в графике. Метки узлов являются типами атомов, например, Carbon. Таким образом, входы GCN являются молекулами, и выходы являются предсказаниями типа атома каждого немеченого атома в молекуле.
Чтобы назначить категориальную метку каждому узлу графика, GCN моделирует функцию на графике , где обозначает набор узлов и обозначает набор ребер, таких что принимает за вход:
: A функции матрица размерности , где - число узлов в и количество входа каналов/функций на узел.
: Матрица смежности размерности представление и описание структуры .
и возвращает выход:
: Матрица встраивания или функций размерности , где - количество выхода функций на узел. Другими словами, является предсказаниями сети и количество классов.
Модель основан на спектральной свертке графика, с весами/фильтром параметров разделенными по всем местоположениям в . Модель может быть представлена как слоистая модель распространения, такая что выход слоя выражается как
,
где
является функцией активации.
- матрица активации слоя , с .
- весовая матрица слоя .
- матрица смежности графика с добавленными самосоединениями. - матрица тождеств.
- матрица степеней .
Выражение может быть названа нормированной матрицей смежности графика.
Модель GCN в этом примере является вариантом стандартной модели GCN, описанной выше. В варианте используются остаточные соединения между слоями [1]. Остаточные соединения позволяют модели переносить информацию с входа предыдущего слоя. Поэтому выход слоя модели GCN в этом примере является
,
Для получения дополнительной информации о модели GCN см. раздел [1].
Этот пример использует набор данных QM7 [2] [3], который представляет собой набор молекулярных данных, состоящий из 7165 молекул, состоящих из до 23 атомов. То есть молекула с наибольшим количеством атомов имеет 23 атома. В целом набор данных состоит из 5 уникальных атомов: углерода, водорода, азота, кислорода и серы.
Загрузите QM7 набор данных со следующего URL-адреса:
dataURL = 'http://quantum-machine.org/data/qm7.mat'; outputFolder = fullfile(tempdir,'qm7Data'); dataFile = fullfile(outputFolder,'qm7.mat'); if ~exist(dataFile, 'file') mkdir(outputFolder); fprintf('Downloading file ''%s'' ...\n', dataFile); websave(dataFile, dataURL); end
Загрузка QM7 данных.
data = load(dataFile)
data = struct with fields:
X: [7165×23×23 single]
R: [7165×23×3 single]
Z: [7165×23 single]
T: [1×7165 single]
P: [5×1433 int64]
Данные состоят из пяти различных массивов. Этот пример использует массивы в полях X
и Z
от struct data
. Массив в X
представляет представление Кулоновской матрицы [3] каждой молекулы, насчитывающее 7165 молекул, и массива в Z
представляет атомарный заряд/число каждого атома в молекулах. Матрицы смежности графиков, представляющие молекулы, и матрицы функций графиков извлекаются из матриц Кулона. Категориальный массив меток извлекается из массива в Z
.
Обратите внимание, что данные для любой молекулы, которая не имеет до 23 атомов, содержат заполненные нули. Для примера данные, представляющие атомарные числа атомов в молекуле с индексом 1, являются
data.Z(1,:)
ans = 1×23 single row vector
6 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Это показывает, что эта молекула состоит из пяти атомов; один атом с атомарным числом 6 и четыре атома с атомным числом 1, и данные дополнены 18 нулями.
Чтобы извлечь данные графика, получите матрицы Кулона и атомарные числа. Транспозиция данных, представляющих матрицы Кулона, и изменение типа данных на double. Отсортируйте данные, представляющие атомарные заряды, так, чтобы они совпадали с данными, представляющими матрицы Кулона.
coulombData = double(permute(data.X, [2 3 1]));
atomicNumber = sort(data.Z,2,'descend');
Переформатируйте представление матрицы Кулона молекул к двоичным матрицам смежности с помощью coloumb2Adjacency
функция, присоединенная к этому примеру как вспомогательный файл.
adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);
Обратите внимание, что coloumb2Adjacency
функция не удаляет заполненные нули из данных. Они остаются намеренно, чтобы облегчить разделение данных на отдельные молекулы для обучения, валидации и вывода. Поэтому, игнорируя заполненные нули, матрица смежности графика, представляющего молекулу с индексом 1, которая состоит из 5 атомов, является
adjacencyData(1:5,1:5,1)
ans = 5×5
0 1 1 1 1
1 0 0 0 0
1 0 0 0 0
1 0 0 0 0
1 0 0 0 0
Перед предварительной обработкой данных используйте функцию splitData, представленную в конце примера, чтобы случайным образом выбрать и разделить данные на обучающие, валидационные и тестовые данные. Функция использует отношение 80:10:10, чтобы разделить данные.
The adjacencyDataSplit
выход splitData
функция является adjacencyData
входные данные разделены на три различных массива. Аналогично, coulombDataSplit
и atomicNumberSplit
выходы coulombData
и atomicNumber
входные данные разделены на три различных массива, соответственно.
[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);
Используйте функцию preprocessData, предоставленную в конце примера, чтобы обработать adjacencyDataSplit
, coulombDataSplit,
и atomicNumberSplit
и возвращает матрицу смежности adjacency
, матрица признаков features
, и категориальный массив labels
.
The preprocessData
функция создает разреженную блок-диагональную матрицу матриц смежности различных образцов графика, так что каждый блок в матрице соответствует матрице смежности одного образца графика. Эта предварительная обработка требуется, потому что GCN принимает одну матрицу смежности как вход, в то время как этот пример касается нескольких образцов графика. Функция принимает ненулевые диагональные элементы матриц Кулона и присваивает их как функции. Поэтому количество входа функций на узел в примере составляет 1.
[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);
Просмотрите матрицы смежности обучения, валидации и тестовых данных.
adjacency
adjacency=1×3 cell array
{88722×88722 double} {10942×10942 double} {10986×10986 double}
Это показывает, что в обучающих данных 88722 узла, 10942 узла в данных валидации и 10986 узлов в тестовых данных.
Нормализуйте массив признаков с помощью функции normalizeFeatures, представленной в конце примера.
features = normalizeFeatures(features);
Получите обучающие и валидационные данные.
featureTrain = features{1}; adjacencyTrain = adjacency{1}; targetTrain = labels{1}; featureValidation = features{2}; adjacencyValidation = adjacency{2}; targetValidation = labels{2};
Выборка и определение индексов молекул для визуализации.
Для каждого заданного индекса
Удалите заполненные нули из данных, представляющих необработанные атомарные числа atomicNumber
и необработанная матрица смежности adjacencyData
отобранной молекулы. Необработанные данные используются здесь для легкой выборки.
Преобразуйте матрицу смежности в график с помощью graph
функция.
Преобразуйте атомарные числа в символы.
Постройте график, используя атомарные символы в качестве меток узлов.
idx = [1 5 300 1159]; for j = 1:numel(idx) % Remove padded zeros from the data atomicNum = nonzeros(atomicNumber(idx(j),:)); numOfNodes = numel(atomicNum); adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j)); % Convert adjacency matrix to graph compound = graph(adj); % Convert atomic numbers to symbols symbols = cell(numOfNodes, 1); for i = 1:numOfNodes if atomicNum(i) == 1 symbols{i} = 'H'; elseif atomicNum(i) == 6 symbols{i} = 'C'; elseif atomicNum(i) == 7 symbols{i} = 'N'; elseif atomicNum(i) == 8 symbols{i} = 'O'; else symbols{i} = 'S'; end end % Plot graph subplot(2,2,j) plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ... 'Layout', 'force') title("Molecule " + idx(j)) end
Получите все метки и классы.
labelsAll = cat(1,labels{:}); classes = categories(labelsAll)
classes = 5×1 cell
{'Hydrogen'}
{'Carbon' }
{'Nitrogen'}
{'Oxygen' }
{'Sulphur' }
Визуализируйте частоту каждой категории меток с помощью гистограммы.
figure histogram(labelsAll) xlabel('Category') ylabel('Frequency') title('Label Counts')
Создайте модель функции, представленную в конце примера, которая принимает данные о функции dlX
, матрица смежности A
и параметры модели parameters
как входные и возвращающие предсказания для метки.
Установите количество входа функций на узел. Это длина столбца матрицы функций.
numInputFeatures = size(featureTrain,2)
numInputFeatures = 1
Установите количество карт функций для скрытых слоев.
numHiddenFeatureMaps = 32;
Установите количество выхода функций как количество категорий.
numOutputFeatures = numel(classes)
numOutputFeatures = 5
Создайте struct parameters
содержащие веса модели. Инициализируйте веса, используя initializeGlorot
функция, присоединенная к этому примеру как вспомогательный файл.
sz = [numInputFeatures numHiddenFeatureMaps]; numOut = numHiddenFeatureMaps; numIn = numInputFeatures; parameters.W1 = initializeGlorot(sz,numOut,numIn,'double'); sz = [numHiddenFeatureMaps numHiddenFeatureMaps]; numOut = numHiddenFeatureMaps; numIn = numHiddenFeatureMaps; parameters.W2 = initializeGlorot(sz,numOut,numIn,'double'); sz = [numHiddenFeatureMaps numOutputFeatures]; numOut = numOutputFeatures; numIn = numHiddenFeatureMaps; parameters.W3 = initializeGlorot(sz,numOut,numIn,'double'); parameters
parameters = struct with fields:
W1: [1×32 dlarray]
W2: [32×32 dlarray]
W3: [32×5 dlarray]
Создайте модель функции Gradients, представленную в конце примера, которая принимает данные о элементе dlX
, матрица смежности adjacencyTrain
, одноразовые закодированные целевые объекты T
меток и параметров модели parameters
как вход и возвращает градиенты потерь относительно параметров, соответствующих потерь и предсказаний.
Обучите на 1500 эпох и установите скорость обучения для решателя Адама равной 0,01.
numEpochs = 1500; learnRate = 0.01;
Проверяйте сеть через каждые 300 эпох.
validationFrequency = 300;
Визуализируйте процесс обучения на графике.
plots = "training-progress";
Чтобы обучиться на графическом процессоре, если он доступен, задайте окружение выполнения "auto"
. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox) (Parallel Computing Toolbox).
executionEnvironment = "auto";
Обучите модель с помощью пользовательского цикла обучения. В обучении используется полный пакетный градиентный спуск.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure % Accuracy. subplot(2,1,1) lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]); lineAccuracyValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 1]) xlabel("Epoch") ylabel("Accuracy") grid on % Loss. subplot(2,1,2) lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Epoch") ylabel("Loss") grid on end
Инициализируйте параметры для Adam.
trailingAvg = []; trailingAvgSq = [];
Преобразуйте данные функций обучения и валидации в dlarray.
dlX = dlarray(featureTrain); dlXValidation = dlarray(featureValidation);
Для обучения графический процессор преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end
Закодируйте данные о метках обучения и валидации с помощью onehotencode
.
T = onehotencode(targetTrain, 2, 'ClassNames', classes); TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);
Обучите модель.
Для каждой эпохи
Оцените градиенты модели и потери с помощью dlfeval
и modelGradients
функция.
Обновляйте параметры сети с помощью adamupdate
.
Вычислите счет точности обучения с помощью функции точности, предоставленной в конце примера. Функция принимает сетевые предсказания, цель, содержащая метки и категории classes
как вводит и возвращает счет точности.
При необходимости проверьте сеть, сделав предсказания с помощью model
функция и вычисление потерь валидации и счета точности валидации с помощью crossentropy
и accuracy
функциональные .
Обновите график обучения.
start = tic; % Loop over epochs. for epoch = 1:numEpochs % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters); % Update the network parameters using the Adam optimizer. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ... trailingAvg,trailingAvgSq,epoch,learnRate); % Display the training progress. if plots == "training-progress" subplot(2,1,1) D = duration(0,0,toc(start),'Format','hh:mm:ss'); title("Epoch: " + epoch + ", Elapsed: " + string(D)) % Loss. addpoints(lineLossTrain,epoch,double(gather(extractdata(loss)))) % Accuracy score. score = accuracy(dlYPred, targetTrain, classes); addpoints(lineAccuracyTrain,epoch,double(gather(score))) drawnow % Display validation metrics. if epoch == 1 || mod(epoch,validationFrequency) == 0 % Loss. dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters); lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC'); addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation)))) % Accuracy score. scoreValidation = accuracy(dlYPredValidation, targetValidation, classes); addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation))) drawnow end end end
Протестируйте модель с помощью тестовых данных.
featureTest = features{3}; adjacencyTest = adjacency{3}; targetTest = labels{3};
Преобразуйте данные тестовых функций в dlarray.
dlXTest = dlarray(featureTest);
Делайте предсказания по данным.
dlYPredTest = model(dlXTest, adjacencyTest, parameters);
Вычислите счет точности с помощью accuracy
функция. The accuracy
функция также возвращает декодированные сетевые предсказания predTest
как метки классов. Сетевые предсказания декодируются с помощью onehotdecode
.
[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);
Просмотрите счет точности.
scoreTest
scoreTest = 0.9053
Чтобы визуализировать счет точности для каждой категории, вычислите счета точности по классам и визуализируйте их с помощью гистограммы.
numOfSamples = numel(targetTest); classTarget = zeros(numOfSamples, numOutputFeatures); classPred = zeros(numOfSamples, numOutputFeatures); for i = 1:numOutputFeatures classTarget(:,i) = targetTest==categorical(classes(i)); classPred(:,i) = predTest==categorical(classes(i)); end % Compute class-wise accuracy score classAccuracy = sum(classPred == classTarget)./numOfSamples; % Visualize class-wise accuracy score figure [~,idx] = sort(classAccuracy,'descend'); histogram('Categories',classes(idx), ... 'BinCounts',classAccuracy(idx), ... 'Barwidth',0.8) xlabel("Category") ylabel("Accuracy") title("Class Accuracy Score")
Классовая точность счетов показа, как модель делает правильные предсказания, используя как истинные положительные, так и истинные отрицательные значения. Истинный позитив является результатом, где модель правильно предсказывает класс как присутствующий в наблюдении. Истинный минус является результатом, где модель правильно предсказывает класс как отсутствующий в наблюдении.
Чтобы визуализировать, как модель делает неправильные предсказания и оценивает модель на основе классовой точности и классового отзыва, вычислите матрицу неточностей, используя confusionmat
и визуализируйте результаты с помощью confusionchart
.
Классовая точность - это отношение истинных срабатываний к общему количеству положительных предсказаний для класса. Общие положительные предсказания включают истинные срабатывания и ложные срабатывания. Ложный позитив является результатом, где модель неправильно предсказывает класс как присутствующий в наблюдении.
Классово-мудрый отзыв, также известный как истинные положительные скорости, является отношением истинных положительных результатов к общим положительным наблюдениям для класса. Общее положительное наблюдение включает истинные срабатывания и ложные срабатывания. Ложный минус является результатом, где модель неправильно предсказывает класс как отсутствующий в наблюдении.
[confusionMatrix, order] = confusionmat(targetTest, predTest); figure cm = confusionchart(confusionMatrix, classes, ... 'ColumnSummary','column-normalized', ... 'RowSummary','row-normalized', ... 'Title', 'GCN QM7 Confusion Chart');
Классовая точность - эти счета в первой строке ' column summary ' графика, а классовая ретрансляция - эти счета в первом столбце ' row summary ' графика.
The splitData
функция принимает adjacencyData
, coloumbData
, и atomicNumber
данные и случайным образом разделяют их на обучающие, валидационные и тестовые данные в соотношении 80:10:10. Функция возвращает соответствующие данные разделения adjacencyDataSplit,
coulombDataSplit, atomicNumberSplit
как массивы ячеек.
function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber) adjacencyDataSplit = cell(1,3); coulombDataSplit = cell(1,3); atomicNumberSplit = cell(1,3); numMolecules = size(adjacencyData, 3); % Set initial random state for example reproducibility. rng(0); % Get training data idx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules)); adjacencyDataSplit{1} = adjacencyData(:,:,idx); coulombDataSplit{1} = coulombData(:,:,idx); atomicNumberSplit{1} = atomicNumber(idx,:); adjacencyData(:,:,idx) = []; coulombData(:,:,idx) = []; atomicNumber(idx,:) = []; % Get validation data idx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules)); adjacencyDataSplit{2} = adjacencyData(:,:,idx); coulombDataSplit{2} = coulombData(:,:,idx); atomicNumberSplit{2} = atomicNumber(idx,:); adjacencyData(:,:,idx) = []; coulombData(:,:,idx) = []; atomicNumber(idx,:) = []; % Get test data adjacencyDataSplit{3} = adjacencyData; coulombDataSplit{3} = coulombData; atomicNumberSplit{3} = atomicNumber; end
The preprocessData
функция предварительно обрабатывает входные данные следующим образом:
Для каждого графа/молекулы
Удалите заполненные нули из atomicNumber
.
Сцепите атомарные числовые данные с атомарными числовыми данными других образцов графика. Необходимо конкатенировать данные, поскольку пример касается нескольких образцов графика.
Удалите заполненные нули из adjacencyData
.
Создайте разреженную блочно-диагональную матрицу матриц смежности различных образцов графика. Каждый блок в матрице соответствует матрице смежности одного образца графика. Этот шаг также необходим, потому что в примере есть несколько образцы графика.
Извлечение массива функций из coulombData
. Массив функции является ненулевыми диагональными элементами матрицы Кулона в coulombData
.
Сцепите массив объектов с массивами функций других образцов графика.
Затем функция преобразует атомарные числовые данные в категориальные массивы.
function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber) adjacency = sparse([]); features = []; labels = []; for i = 1:size(adjacencyData, 3) % Remove padded zeros from atomicNumber tmpLabels = nonzeros(atomicNumber(i,:)); labels = [labels; tmpLabels]; % Get the indices of the un-padded data validIdx = 1:numel(tmpLabels); % Use the indices for un-padded data to remove padded zeros % from the adjacency data tmpAdjacency = adjacencyData(validIdx, validIdx, i); % Build the adjacency matrix into a block diagonal matrix adjacency = blkdiag(adjacency, tmpAdjacency); % Remove padded zeros from coulombData and extract the % feature array tmpFeatures = diag(coulombData(validIdx, validIdx, i)); features = [features; tmpFeatures]; end % Convert labels to categorical array atomicNumbers = unique(labels); atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"]; labels = categorical(labels, atomicNumbers, atomNames); end
The normalizeFeatures
функция стандартизирует данные о входе обучении, валидации и тестовых функциях features
использование среднего значения и отклонения обучающих данных.
function features = normalizeFeatures(features) % Get the mean and variance from the training data meanFeatures = mean(features{1}); varFeatures = var(features{1}, 1); % Standardize training, validation and test data for i = 1:3 features{i} = (features{i} - meanFeatures)./sqrt(varFeatures); end end
The model
функция принимает матрицу признаков dlX
, матрица смежности A
и параметры модели parameters
и возвращает сетевые предсказания. На этапе предварительной обработки model
функция вычисляет нормированную матрицу смежности, описанную ранее, используя normalizeAdjacency
функция, представленная ниже.
function dlY = model(dlX, A, parameters) % Normalize adjacency matrix L = normalizeAdjacency(A); Z1 = dlX; Z2 = L * Z1 * parameters.W1; Z2 = relu(Z2) + Z1; Z3 = L * Z2 * parameters.W2; Z3 = relu(Z3) + Z2; Z4 = L * Z3 * parameters.W3; dlY = softmax(Z4, 'DataFormat', 'BC'); end
The normalizeAdjacency
функция вычисляет и возвращает нормированную матрицу смежности normAdjacency
матрицы входа смежности adjacency
.
function normAdjacency = normalizeAdjacency(adjacency) % Add self connections to adjacency matrix adjacency = adjacency + speye(size(adjacency)); % Compute degree of nodes degree = sum(adjacency, 2); % Compute inverse square root of degree degreeInvSqrt = sparse(sqrt(1./degree)); % Normalize adjacency matrix normAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt); end
The modelGradients
функция принимает матрицу признаков dlX
, матрица смежности adjacencyTrain
, одноразовые закодированные целевые данные T
и параметры модели parameters,
и возвращает градиенты потерь относительно параметров модели, соответствующих потерь и предсказаний.
function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters) dlYPred = model(dlX, adjacencyTrain, parameters); loss = crossentropy(dlYPred, T, 'DataFormat', 'BC'); gradients = dlgradient(loss, parameters); end
Функция точности декодирует сетевые предсказания YPred
и вычисляет точность, используя декодированные предсказания и целевые данные target.
Функция возвращает вычисленный счет точности и декодированные предсказания prediction.
function [score, prediction] = accuracy(YPred, target, classes) % Decode probability vectors into class labels prediction = onehotdecode(YPred, classes, 2); score = sum(prediction == target)/numel(target); end
Т. Н. Кипф и М. Веллинг. Полууправляемая классификация с графовыми сверточными сетями. В ICLR, 2016.
Л. К. Блюм, Дж. -Л. Реймонд, 970 млн. Druglike Small Molecules for Virtual Screening in the Chemical Universe Database GDB-13, J. Am. Chem. Soc., 131:8732, 2009.
М. Рупп, А. Ткатченко, К.-Р. Мюллер О. А. фон Лилиенфельд: быстрое и точное моделирование энергий молекулярной атомизации с помощью машинного обучения, Букв физического обзора, 108 (5): 058301, 2012.
dlarray
| dlfeval
| dlgradient
| minibatchqueue