Обучите сиамскую сеть уменьшению размерности

Этот пример показывает, как обучить сиамскую сеть сравнивать рукописные цифры с помощью уменьшения размерности.

Сиамская сеть - это тип нейронной сети для глубокого обучения, которая использует две или более идентичных подсетей, которые имеют одинаковую архитектуру и имеют одинаковые параметры и веса. Сиамские сети обычно используются в задачах, которые включают нахождение связи между двумя сопоставимыми вещами. Некоторые распространенные приложения для сиамских сетей включают распознавание лиц, верификацию подписей [1] или идентификацию перефразирования [2]. Сиамские сети хорошо работают в этих задачах, потому что их общие веса означают, что во время обучения меньше параметров, и они могут дать хорошие результаты с относительно небольшим объемом обучающих данных.

Сиамские сети особенно полезны в тех случаях, когда существует большое количество классов с небольшим количеством наблюдений за каждым. В таких случаях недостаточно данных, чтобы обучить глубокую сверточную нейронную сеть классифицировать изображения в эти классы. Вместо этого сиамская сеть может определить, находятся ли два изображения в одном классе. Сеть делает это, уменьшая размерность обучающих данных и используя дистанционную функцию затрат, чтобы дифференцировать классы.

Этот пример использует сиамскую сеть для уменьшения размерности набора изображений рукописных цифр. Сиамская архитектура уменьшает размерность путем отображения изображений с тем же классом на близлежащие точки в низкомерном пространстве. Представление уменьшенных признаков затем используется, чтобы извлечь изображения из набора данных, которые наиболее похожи на тестовое изображение. Обучающие данные в этом примере - изображения размера 28 на 28 на 1, давая начальную размерность функции 784. Сиамская сеть уменьшает размерность входных изображений до двух признаков и обучена выводить аналогичные уменьшенные функции для изображений с той же меткой.

Можно также использовать сиамские сети для идентификации аналогичных изображений путем их прямого сравнения. Для получения примера смотрите Train сиамской сети для сравнения изображений.

Загрузка и предварительная обработка обучающих данных

Загрузите обучающие данные, которые состоят из изображений рукописных цифр. Функция digitTrain4DArrayData загружает изображения цифр и их метки.

[XTrain,YTrain] = digitTrain4DArrayData;

XTrain массив 28 на 28 на 1 на 5000, содержащий 5 000 одноканальных изображений, каждый размер 28 на 28. Значения каждого пикселя находятся между 0 и 1. YTrain - категориальный вектор, содержащий метки для каждого наблюдения, которые являются числами от 0 до 9, соответствующими значению записанной цифры.

Отобразите случайный выбор изображений.

perm = randperm(numel(YTrain), 9);
imshow(imtile(XTrain(:,:,:,perm),"ThumbnailSize",[100 100]));

Создайте пары похожих и непохожих изображений

Чтобы обучить сеть, данные должны быть сгруппированы в пары изображений, которые либо похожи, либо отличаются друг от друга. Здесь аналогичные изображения определяются как имеющие одну и ту же метку, в то время как различные изображения имеют различные метки. Функция getSiameseBatch (определено в разделе Support Functions этого примера) создает рандомизированные пары подобных или разнородных изображений, pairImage1 и pairImage2. Функция также возвращает метку pairLabel, который определяет, похожа ли пара изображений или отличается друг от друга. Подобные пары изображений имеют pairLabel = 1, в то время как разнородные пары имеют pairLabel = 0.

В качестве примера создайте небольшой репрезентативный набор из пяти пар изображений

batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(XTrain,YTrain,batchSize);

Отображение сгенерированных пар изображений.

for i = 1:batchSize
subplot(2,5,i)
imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
if pairLabel(i) == 1
    s = "similar";
else
    s = "dissimilar";
end
title(s)
end

В этом примере создается новый пакет из 180 парных изображений для каждой итерации цикла обучения. Это гарантирует, что сеть обучена на большом количестве случайных пар изображений с приблизительно равными пропорциями аналогичных и разнородных пар.

Определение сетевой архитектуры

Сиамская сетевая архитектура проиллюстрирована на следующей схеме.

В этом примере две идентичные подсети определяются как серия полносвязных слоев со слоями ReLU. Создайте сеть, которая принимает изображения 28 на 28 на 1 и производит два вектора функции, используемые для уменьшенного представления функции. Сеть уменьшает размерность входных изображений до двух, значение, которое легче построить и визуализировать, чем начальная размерность 784.

Для первых двух полносвязных слоев задайте размер выхода 1024 и используйте инициализатор веса He.

Для последнего полносвязного слоя задайте выход два и используйте инициализатор весов He.

layers = [
    imageInputLayer([28 28],'Name','input1','Normalization','none')
    fullyConnectedLayer(1024,'Name','fc1','WeightsInitializer','he')
    reluLayer('Name','relu1')
    fullyConnectedLayer(1024,'Name','fc2','WeightsInitializer','he')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2,'Name','fc3','WeightsInitializer','he')];

lgraph = layerGraph(layers);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnet = dlnetwork(lgraph);

Задайте функцию градиентов модели

Создайте функцию modelGradients (определено в разделе Вспомогательные функции этого примера). The modelGradients функция принимает сиамский dlnetwork dlnet объекта and мини-пакет входных данных dlX1 и dlX2 со своими метками pairLabels. Функция возвращает значения потерь и градиенты потерь относительно настраиваемых параметров сети.

Цель сиамской сети состоит в том, чтобы вывести вектор функций для каждого изображения таким образом, чтобы векторы функций были одинаковыми для аналогичных изображений и, в частности, отличались для разнородных изображений. Таким образом, сеть может различать эти два входа.

Найдите контрастные потери между выходами последнего полносвязного слоя, векторов функций features1 и features1 от pairImage1 и pairImage2, соответственно. Контрастные потери для пары даются [3]

loss=12yd2+12(1-y)max(margin-d,0)2,

где y - значение метки пары (y=1 для аналогичных изображений;y=0 для разнородных изображений), и d - евклидово расстояние между двумя векторами функций f1 и f2: d=f1-f22.

margin параметр используется для ограничения: если два изображения в паре отличаются друг от друга, то их расстояние должно быть как минимум marginили будет понесен убыток.

Контрастные потери имеют два условия, но только один всегда является ненулевым для заданной пары изображений. В случае аналогичных изображений первый член может быть ненулевым и минимизируется путем уменьшения расстояния между функциями изображения f1 и f2. В случае разнородных изображений второй член может быть ненулевым и минимизируется путем увеличения расстояния между функциями изображения, по меньшей мере, до расстояния, равного margin. Тем меньше значение marginтем меньше ограничивает то, насколько близка может быть разнородная пара до возникновения убытка.

Настройка опций обучения

Задайте значение margin использовать во время обучения.

margin = 0.3;

Задайте опции, которые будут использоваться во время обучения. Обучите для 3000 итераций.

numIterations = 3000;
miniBatchSize = 180;

Задайте опции для оптимизации ADAM:

  • Установите скорость обучения равной 0.0001.

  • Инициализируйте конечный средний градиент и конечный средний градиент-квадратные скорости распада с [].

  • Установите коэффициент градиентного распада равным 0.9 и квадратный коэффициент градиента распада, чтобы 0.99.

learningRate = 1e-4;
trailingAvg = [];
trailingAvgSq = [];
gradDecay = 0.9;
gradDecaySq = 0.99;

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Чтобы автоматически определить, доступен ли графический процессор, и поместить соответствующие данные на графический процессор, установите значение executionEnvironment на "auto". Если у вас нет графический процессор или вы не хотите использовать его для обучения, задайте значение executionEnvironment на "cpu". Чтобы убедиться, что вы используете графический процессор для обучения, задайте значение executionEnvironment на "gpu".

executionEnvironment = "auto";

Чтобы контролировать процесс обучения, можно построить график потерь обучения после каждой итерации. Создайте переменную plots который содержит "training-progress". Если вы не хотите строить график процесса обучения, задайте это значение "none".

plots = "training-progress";

Инициализируйте параметры графика для графика прогресса потерь обучения.

plotRatio = 16/9;

if plots == "training-progress"
    trainingPlot = figure;
    trainingPlot.Position(3) = plotRatio*trainingPlot.Position(4);
    trainingPlot.Visible = 'on';
    
    trainingPlotAxes = gca;
    
    lineLossTrain = animatedline(trainingPlotAxes);
    xlabel(trainingPlotAxes,"Iteration")
    ylabel(trainingPlotAxes,"Loss")
    title(trainingPlotAxes,"Loss During Training")
end

Чтобы оценить, насколько хорошо сеть работает при уменьшении размерности, вычислите и постройте графики уменьшенных функций набора тестовых данных после каждой итерации. Загрузите тестовые данные, которые состоят из изображений рукописных цифр, аналогичных обучающим данным. Преобразуйте тестовые данные в dlarray и задайте метки размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный). Если вы используете графический процессор, преобразуйте тестовые данные в gpuArray.

[XTest,YTest] = digitTest4DArrayData;
dlXTest = dlarray(single(XTest),'SSCB');

% If training on a GPU, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);           
end 

Инициализируйте параметры графика для сокращенного графика тестовых данных.

dimensionPlot = figure;
dimensionPlot.Position(3) = plotRatio*dimensionPlot.Position(4);
dimensionPlot.Visible = 'on';

dimensionPlotAxes = gca;

uniqueGroups = unique(YTest);
colors = hsv(length(uniqueGroups));

Инициализируйте счетчик, чтобы отслеживать общее количество итераций.

iteration = 1;

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации.

Для каждой итерации:

  • Извлечение пакета пар изображений и меток с помощью getSiameseBatch функция, определенная в разделе «Создание пакетов пар изображений».

  • Преобразуйте данные изображения в dlarray объекты с базовым типом single и задайте метки размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Для обучения графический процессор преобразуйте данные изображения в gpuArray объекты.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновляйте параметры сети с помощью adamupdate функция.

% Loop over mini-batches.
for iteration = 1:numIterations
    
    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(XTrain,YTrain,miniBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    
    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end       
    
    % Evaluate the model gradients and the generator state using
    % dlfeval and the modelGradients function listed at the end of the
    % example.
    [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX1,dlX2,pairLabels,margin);
    lossValue = double(gather(extractdata(loss)));
    
    % Update the Siamese network parameters.
    [dlnet.Learnables,trailingAvg,trailingAvgSq] = ...
        adamupdate(dlnet.Learnables,gradients, ...
        trailingAvg,trailingAvgSq,iteration,learningRate,gradDecay,gradDecaySq);
    
    % Update the training loss progress plot.
    if plots == "training-progress"
        addpoints(lineLossTrain,iteration,lossValue);
    end
            
    % Update the reduced-feature plot of the test data.        
    % Compute reduced features of the test data:
    dlFTest = predict(dlnet,dlXTest);
    FTest = extractdata(dlFTest);
       
    figure(dimensionPlot);
    for k = 1:length(uniqueGroups)
        % Get indices of each image in test data with the same numeric 
        % label (defined by the unique group):
        ind = YTest==uniqueGroups(k);
        % Plot this group:
        plot(dimensionPlotAxes,gather(FTest(1,ind)'),gather(FTest(2,ind)'),'.','color',...
            colors(k,:));
        hold on
    end
    
    legend(uniqueGroups)
    
    % Update title of reduced-feature plot with training progress information.
    title(dimensionPlotAxes,"2-D Feature Representation of Digits Images. Iteration = " +...
        iteration);
    legend(dimensionPlotAxes,'Location','eastoutside');
    xlabel(dimensionPlotAxes,"Feature 1")
    ylabel(dimensionPlotAxes,"Feature 2")
    
    hold off    
    drawnow    
end

Теперь сеть научилась представлять каждое изображение как 2-D векторы. Как видно из графика сокращенных признаков тестовых данных, изображения похожих цифр кластеризуются близко друг к другу в этом 2-D представлении.

Используйте обученную сеть для поиска похожих изображений

Можно использовать обученную сеть, чтобы найти набор изображений, которые похожи друг на друга из группы. В этом случае используйте тестовые данные как группу изображений. Преобразуйте группу изображений в dlarray объекты и gpuArray объекты, если вы используете графический процессор.

groupX = XTest;

dlGroupX = dlarray(single(groupX),'SSCB');

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlGroupX = gpuArray(dlGroupX);           
end 

Извлеките из группы одно тестовое изображение и отобразите его. Удалите тестовое изображение из группы, чтобы оно не появилось в наборе аналогичных изображений.

testIdx = randi(5000);
testImg = dlGroupX(:,:,:,testIdx);

trialImgDisp = extractdata(testImg);

figure
imshow(trialImgDisp, 'InitialMagnification', 500);

dlGroupX(:,:,:,testIdx) = [];

Найдите уменьшенные функции тестового изображения, используя predict.

trialF = predict(dlnet,testImg);

Найдите 2-D представление сокращённых функций каждого из изображений в группе с помощью обученной сети.

FGroupX = predict(dlnet,dlGroupX);

Используйте представление сокращённой функции, чтобы найти девять изображений в группе, которые наиболее близки к тестовому изображению, с помощью метрики Евклидова расстояния. Отобразите изображения.

distances = vecnorm(extractdata(trialF - FGroupX));
[~,idx] = sort(distances);
sortedImages = groupX(:,:,:,idx);

figure
imshow(imtile(sortedImages(:,:,:,1:9)), 'InitialMagnification', 500);

Путем сокращения изображений до более низкой размерности сеть способна идентифицировать изображения, которые аналогичны пробному изображению. Представление уменьшенных функций позволяет сети различать изображения, которые похожи и отличаются друг от друга. Сиамские сети часто используются в контексте распознавания лица или подписи. Например, можно обучить сиамскую сеть принимать изображение лица в качестве входов и возвращать набор наиболее похожих граней из базы данных.

Вспомогательные функции

Функция градиентов модели

Функция modelGradients принимает сиамский dlnetwork dlnet объекта, пару мини-пакетных входных данных X1 и X2, и метки pairLabels. Функция возвращает градиенты потерь относительно настраиваемых параметров в сети, а также контрастные потери между функциями уменьшенной размерности парных изображений. В этом примере функция modelGradients введено в раздел Define Model Gradients Function.

function [gradients, loss] = modelGradients(net,X1,X2,pairLabel,margin)
% The modelGradients function calculates the contrastive loss between the
% paired images and returns the loss and the gradients of the loss with 
% respect to the network learnable parameters

    % Pass first half of image pairs forward through the network
    F1 = forward(net,X1);
    % Pass second set of image pairs forward through the network
    F2 = forward(net,X2);
    
    % Calculate contrastive loss
    loss = contrastiveLoss(F1,F2,pairLabel,margin);
    
    % Calculate gradients of the loss with respect to the network learnable
    % parameters
    gradients = dlgradient(loss, net.Learnables);

end

function loss = contrastiveLoss(F1,F2,pairLabel,margin)
% The contrastiveLoss function calculates the contrastive loss between
% the reduced features of the paired images 
    
    % Define small value to prevent taking square root of 0
    delta = 1e-6;
    
    % Find Euclidean distance metric
    distances = sqrt(sum((F1 - F2).^2,1) + delta);
    
    % label(i) = 1 if features1(:,i) and features2(:,i) are features
    % for similar images, and 0 otherwise
    lossSimilar = pairLabel.*(distances.^2);
 
    lossDissimilar = (1 - pairLabel).*(max(margin - distances, 0).^2);
    
    loss = 0.5*sum(lossSimilar + lossDissimilar,'all');
end

Создание пакетов пар изображений

Следующие функции создают рандомизированные пары изображений, которые похожи или отличаются друг от друга, на основе их меток. В этом примере функция getSiameseBatch введено в раздел «Создание пар похожих и непохожих изображений».

function [X1,X2,pairLabels] = getSiameseBatch(X,Y,miniBatchSize)
% getSiameseBatch returns a randomly selected batch of paired images. 
% On average, this function produces a balanced set of similar and 
% dissimilar pairs.
    pairLabels = zeros(1, miniBatchSize);
    imgSize = size(X(:,:,:,1));
    X1 = zeros([imgSize 1 miniBatchSize]);
    X2 = zeros([imgSize 1 miniBatchSize]);
    
    for i = 1:miniBatchSize
        choice = rand(1);
        if choice < 0.5
            [pairIdx1, pairIdx2, pairLabels(i)] = getSimilarPair(Y);
        else
            [pairIdx1, pairIdx2, pairLabels(i)] = getDissimilarPair(Y);
        end
        X1(:,:,:,i) = X(:,:,:,pairIdx1);
        X2(:,:,:,i) = X(:,:,:,pairIdx2);
    end
    
end

function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)
% getSimilarPair returns a random pair of indices for images
% that are in the same class and the similar pair label = 1.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose a class randomly which will be used to get a similar pair.
    classChoice = randi(numel(classes));
    
    % Find the indices of all the observations from the chosen class.
    idxs = find(classLabel==classes(classChoice));
    
    % Randomly choose two different images from the chosen class.
    pairIdxChoice = randperm(numel(idxs),2);
    pairIdx1 = idxs(pairIdxChoice(1));
    pairIdx2 = idxs(pairIdxChoice(2));
    pairLabel = 1;
end

function  [pairIdx1,pairIdx2,pairLabel] = getDissimilarPair(classLabel)
% getDissimilarPair returns a random pair of indices for images
% that are in different classes and the dissimilar pair label = 0.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose two different classes randomly which will be used to get a dissimilar pair.
    classesChoice = randperm(numel(classes), 2);
    
    % Find the indices of all the observations from the first and second classes.
    idxs1 = find(classLabel==classes(classesChoice(1)));
    idxs2 = find(classLabel==classes(classesChoice(2)));
    
    % Randomly choose one image from each class.
    pairIdx1Choice = randi(numel(idxs1));
    pairIdx2Choice = randi(numel(idxs2));
    pairIdx1 = idxs1(pairIdx1Choice);
    pairIdx2 = idxs2(pairIdx2Choice);
    pairLabel = 0;
end

Ссылки

[1] Bromley, J., I. Guyon, Y. LeCunn, E. Säckinger, and R. Shah. «Проверка подписи с использованием» сиамской «нейронной сети с задержкой по времени». В трудах 6-й Международной конференции по нейронным системам обработки информации (NIPS 1993), 1994, pp737-744. Доступно при верификации подписи с помощью «сиамской» нейронной сети задержки на веб-сайте NIPS Processions.

[2] Wenpeg, Y., and H Schütze. Сверточная нейронная сеть для идентификации парафразов. В трудах 2015 года Конференция североамериканского Cahapter ACL, 2015, pp901-911. Доступно в сверточной нейронной сети для идентификации парафразов на веб-сайте ACL Anthology.

[3] Hadsell, R., S. Chopra, and Y. LeCunn. «Уменьшение размерности путем изучения инвариантного отображения». В материалах Конференции компьютерного общества IEEE 2006 по компьютерному зрению и распознаванию шаблонов (CVPR 2006), 2006, pp1735-1742.

См. также

| | | |

Похожие темы