Обучите сиамскую сеть, чтобы сравнить изображения

В этом примере показано, как обучить сиамскую сеть, чтобы идентифицировать подобные изображения рукописных символов.

Сиамская сеть является типом нейронной сети для глубокого обучения, которая использует две или больше идентичных подсети, которые имеют ту же архитектуру и совместно используют те же параметры и веса. Сиамские сети обычно используются в задачах, которые включают нахождение отношения между двумя сопоставимыми вещами. Некоторые распространенные приложения для сиамских сетей включают распознавание лиц, верификацию подписи [1], или перефразируют идентификацию [2]. Сиамские сети выполняют хорошо в этих задачах, потому что их разделяемые веса означают, что существует меньше параметров, чтобы учиться во время обучения, и они могут привести к хорошим результатам с относительно небольшим количеством обучающих данных.

Сиамские сети особенно полезны в случаях, где существуют большие количества классов с небольшими числами наблюдений за каждым. В таких случаях существует недостаточно данных, чтобы обучить глубокую сверточную нейронную сеть классифицировать изображения в эти классы. Вместо этого сиамская сеть может определить, находятся ли два изображения в том же классе.

Это использование в качестве примера набор данных Omniglot [3], чтобы обучить сиамскую сеть, чтобы сравнить изображения рукописных символов [4]. Набор данных Omniglot содержит наборы символов для 50 алфавитов, разделенных на 30 используемых для обучения и 20 для тестирования. Каждый алфавит содержит много символов от 14 для Ojibwe (абориген Canadia Саллэбикс) к 55 для Tifinagh. Наконец, каждый символ имеет 20 рукописных наблюдений. Этот пример обучает сеть, чтобы идентифицировать, являются ли два рукописных наблюдения различными экземплярами того же символа.

Можно также использовать сиамские сети, чтобы идентифицировать подобные изображения с помощью сокращения размерности. Для примера смотрите, Обучают сиамскую Сеть для Сокращения Размерности.

Загрузите и предварительно обработайте обучающие данные

Загрузите и извлеките обучающий набор данных Omniglot.

url = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"images_background.zip");

dataFolderTrain = fullfile(downloadFolder,'images_background');
if ~exist(dataFolderTrain,"dir")
    disp("Downloading Omniglot training data (4.5 MB)...")
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Training data downloaded.")
Training data downloaded.

Загрузите обучающие данные как datastore изображений с помощью imageDatastore функция. Задайте метки вручную путем извлечения меток из имен файлов и установки Labels свойство.

imdsTrain = imageDatastore(dataFolderTrain, ...
    'IncludeSubfolders',true, ...
    'LabelSource','none');

files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),'_');
imdsTrain.Labels = categorical(labels);

Обучающий набор данных Omniglot состоит из черных и белых рукописных символов от 30 алфавитов с 20 наблюдениями за каждым символом. Изображения имеют размер 105 105 1, и значения каждого пикселя между 0 и 1.

Отобразите случайный выбор изображений.

idxs = randperm(numel(imdsTrain.Files),8);

for i = 1:numel(idxs)
    subplot(4,2,i)
    imshow(readimage(imdsTrain,idxs(i)))
    title(imdsTrain.Labels(idxs(i)), "Interpreter","none");
end

Создайте пары подобных и отличающихся изображений

Чтобы обучить сеть, данные должны быть сгруппированы в пары изображений, которые или подобны или отличаются. Здесь, подобные изображения являются различными рукописными экземплярами того же символа, которые имеют ту же метку, в то время как отличающиеся изображения различных символов имеют различные метки. Функциональный getSiameseBatch (заданный в разделе Supporting Functions этого примера), создает рандомизированные пары подобных или отличающихся изображений, pairImage1 и pairImage2. Функция также возвращает метку pairLabel, который идентифицирует, подобна ли пара изображений или отличается друг от друга. Подобные пары изображений имеют pairLabel = 1, в то время как отличающиеся пары имеют pairLabel = 0.

Как пример, создайте маленький представительный набор пяти пар изображений

batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(imdsTrain,batchSize);

Отобразите сгенерированные пары изображений.

for i = 1:batchSize    
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    subplot(2,5,i)   
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    title(s)
end

В этом примере новый пакет 180 парных изображений создается для каждой итерации учебного цикла. Это гарантирует, что сеть обучена на большом количестве случайных пар изображений приблизительно с равными пропорциями подобных и отличающихся пар.

Архитектура сети Define

Сиамская сетевая архитектура проиллюстрирована в следующей схеме.

Чтобы сравнить два изображения, каждое изображение передается через одну из двух идентичных подсетей та доля веса. Подсети преобразуют каждого 105 105 1 изображением к 4096-мерному характеристическому вектору. Изображения того же класса имеют подобные 4096-мерные представления. Выходные характеристические векторы от каждой подсети объединены посредством вычитания, и результат передается через fullyconnect операция с одним выходом. sigmoid операция преобразует это значение в вероятность между 0 и 1, указание на предсказание сети того, подобны ли изображения или отличаются. Бинарная потеря перекрестной энтропии между сетевым предсказанием и истинной меткой используется, чтобы обновить сеть во время обучения.

В этом примере две идентичных подсети заданы как dlnetwork объект. Итоговый fullyconnect и sigmoid операции выполняются как функциональные операции на подсети выходные параметры.

Создайте подсеть как серию слоев, которая принимает 105 105 1 изображением и выводит характеристический вектор размера 4096.

Для convolution2dLayer объекты, используйте узкое нормальное распределение, чтобы инициализировать веса и смещение.

Для maxPooling2dLayer объекты, набор шаг к 2.

Для итогового fullyConnectedLayer возразите, задайте выходной размер 4 096 и используйте узкое нормальное распределение, чтобы инициализировать веса и смещение.

layers = [
    imageInputLayer([105 105 1],'Name','input1','Normalization','none')
    convolution2dLayer(10,64,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu1')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool1')
    convolution2dLayer(7,128,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu2')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool2')
    convolution2dLayer(4,128,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu3')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool3')
    convolution2dLayer(5,256,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu4')
    fullyConnectedLayer(4096,'Name','fc1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')];

lgraph = layerGraph(layers);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnet = dlnetwork(lgraph);

Создайте веса для итогового fullyconnect операция. Инициализируйте веса путем выборки случайного выбора от узкого нормального распределения со стандартным отклонением 0,01.

fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));

fcParams = struct(...
    "FcWeights",fcWeights,...
    "FcBias",fcBias);

Чтобы использовать сеть, создайте функциональный forwardSiamese (заданный в разделе Supporting Functions этого примера), который задает как две подсети и вычитание, fullyconnect, и sigmoid операции объединены. Функциональный forwardSiamese принимает сеть, структура, содержащая параметры для fullyconnect операция и два учебных изображения. forwardSiamese функционируйте выводит предсказание о подобии двух изображений.

Функция градиентов модели Define

Создайте функциональный modelGradients (заданный в разделе Supporting Functions этого примера). modelGradients функционируйте берет сиамскую подсеть dlnet, структура параметра для fullyconnect операция и мини-пакет входных данных X1 и X2 с их маркирует pairLabels. Функция возвращает значения потерь и градиенты потери относительно настраиваемых параметров сети.

Цель сиамской сети состоит в том, чтобы отличить между двумя входными параметрами X1 и X2. Выход сети является вероятностью между 0 и 1, где значение ближе к 0 указывает на предсказание, что изображения отличаются, и значение ближе к 1 то, что изображения подобны. Потеря дана бинарной перекрестной энтропией между предсказанным счетом и истинным значением метки:

loss=-tlog(y)-(1-t)log(1-y),

где истинная метка t может быть 0 или 1 и y предсказанная метка.

Задайте опции обучения

Задайте опции, чтобы использовать во время обучения. Обучайтесь для 10 000 итераций.

numIterations = 10000;
miniBatchSize = 180;

Задайте опции для оптимизации ADAM:

  • Установите скорость обучения на 0.00006.

  • Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [] для обоих dlnet и fcParams.

  • Установитесь коэффициент затухания градиента на 0.9 и градиент в квадрате затухает фактор к 0.99.

learningRate = 6e-5;
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];
gradDecay = 0.9;
gradDecaySq = 0.99;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Автоматически обнаружить, если вы имеете графический процессор в наличии и помещаете соответствующие данные по графическому процессору, устанавливаете значение executionEnvironment к "auto". Если вы не имеете графического процессора или не хотите использовать один для обучения, устанавливать значение executionEnvironment к "cpu". Чтобы убедиться вы используете графический процессор для обучения, устанавливаете значение executionEnvironment к "gpu".

executionEnvironment = "auto";

Чтобы контролировать процесс обучения, можно построить учебную потерю после каждой итерации. Создайте переменную plots это содержит "training-progress". Если вы не хотите строить процесс обучения, установите это значение к "none".

plots = "training-progress";

Инициализируйте параметры графика для учебного графика прогресса потерь.

plotRatio = 16/9;

if plots == "training-progress"
    trainingPlot = figure;
    trainingPlot.Position(3) = plotRatio*trainingPlot.Position(4);
    trainingPlot.Visible = 'on';
    
    trainingPlotAxes = gca;
    
    lineLossTrain = animatedline(trainingPlotAxes);
    xlabel(trainingPlotAxes,"Iteration")
    ylabel(trainingPlotAxes,"Loss")
    title(trainingPlotAxes,"Loss During Training")
end

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации.

Для каждой итерации:

  • Извлеките пакет пар изображений и меток с помощью getSiameseBatch функция, определяемая в разделе Create Batches of Image Pairs.

  • Преобразуйте данные в dlarray объекты с базовым типом single и укажите, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет) для данных изображения и 'CB' (образуйте канал, пакет) для меток.

  • Для обучения графического процессора преобразуйте данные в gpuArray объекты.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate функция.

% Loop over mini-batches.
for iteration = 1:numIterations
    
    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(imdsTrain,miniBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    
    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end  
    
    % Evaluate the model gradients and the generator state using
    % dlfeval and the modelGradients function listed at the end of the
    % example.
    [gradientsSubnet, gradientsParams,loss] = dlfeval(@modelGradients,dlnet,fcParams,dlX1,dlX2,pairLabels);
    lossValue = double(gather(extractdata(loss)));
    
    % Update the Siamese subnetwork parameters.
    [dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
        adamupdate(dlnet,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
    
    % Update the fullyconnect parameters.
    [fcParams,trailingAvgParams,trailingAvgSqParams] = ...
        adamupdate(fcParams,gradientsParams, ...
        trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);
      
    % Update the training loss progress plot.
    if plots == "training-progress"
        addpoints(lineLossTrain,iteration,lossValue);
    end
    drawnow;
end

Оцените точность сети

Загрузите и извлеките тестовый набор данных Omniglot.

url = 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'images_evaluation.zip');

dataFolderTest = fullfile(downloadFolder,'images_evaluation');
if ~exist(dataFolderTest,'dir')
    disp('Downloading Omniglot test data (3.2 MB)...')
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Test data downloaded.")
Test data downloaded.

Загрузите тестовые данные как datastore изображений с помощью imageDatastore функция. Задайте метки вручную путем извлечения меток из имен файлов и установки Labels свойство.

imdsTest = imageDatastore(dataFolderTest, ...
    'IncludeSubfolders',true, ...
    'LabelSource','none');    

files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),'_');
imdsTest.Labels = categorical(labels);

Тестовый набор данных содержит 20 алфавитов, которые отличаются от тех, на которых была обучена сеть. Всего, там 659 различных классов в тестовом наборе данных.

numClasses = numel(unique(imdsTest.Labels))
numClasses = 659

Чтобы вычислить точность сети, создайте набор пяти случайных мини-пакетов тестовых пар. Используйте predictSiamese функция (заданный в разделе Supporting Functions этого примера), чтобы оценить сетевые предсказания и вычислить среднюю точность по мини-пакетам.

accuracy = zeros(1,5);
accuracyBatchSize = 150;

for i = 1:5
    
    % Extract mini-batch of image pairs and pair labels
    [XAcc1,XAcc2,pairLabelsAcc] = getSiameseBatch(imdsTest,accuracyBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data.
    dlXAcc1 = dlarray(single(XAcc1),'SSCB');
    dlXAcc2 = dlarray(single(XAcc2),'SSCB');
    
    % If using a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
       dlXAcc1 = gpuArray(dlXAcc1);
       dlXAcc2 = gpuArray(dlXAcc2);
    end    
    
    % Evaluate predictions using trained network
    dlY = predictSiamese(dlnet,fcParams,dlXAcc1,dlXAcc2);
   
    % Convert predictions to binary 0 or 1
    Y = gather(extractdata(dlY));
    Y = round(Y);
    
    % Compute average accuracy for the minibatch
    accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end

% Compute accuracy over all minibatches
averageAccuracy = mean(accuracy)*100
averageAccuracy = 88.6667

Отобразите набор тестов изображений с предсказаниями

Чтобы визуально проверять, идентифицирует ли сеть правильно подобные и отличающиеся пары, создайте маленький пакет пар изображений, чтобы протестировать. Используйте функцию predictSiamese, чтобы получить предсказание для каждой тестовой пары. Отобразите пару изображений с предсказанием, счетом вероятности и меткой, указывающей, было ли предсказание правильным или неправильным.

testBatchSize = 10;

[XTest1,XTest2,pairLabelsTest] = getSiameseBatch(imdsTest,testBatchSize);
    
% Convert test batch of data to dlarray. Specify the dimension labels
% 'SSCB' (spatial, spatial, channel, batch) for image data and 'CB' 
% (channel, batch) for labels
dlXTest1 = dlarray(single(XTest1),'SSCB');
dlXTest2 = dlarray(single(XTest2),'SSCB');

% If using a GPU, then convert data to gpuArray
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
   dlXTest1 = gpuArray(dlXTest1);
   dlXTest2 = gpuArray(dlXTest2);
end

% Calculate the predicted probability
dlYScore = predictSiamese(dlnet,fcParams,dlXTest1,dlXTest2);
YScore = gather(extractdata(dlYScore));

% Convert predictions to binary 0 or 1
YPred = round(YScore);    

% Extract data to plot
XTest1 = extractdata(dlXTest1);
XTest2 = extractdata(dlXTest2);

% Plot images with predicted label and predicted score
testingPlot = figure;
testingPlot.Position(3) = plotRatio*testingPlot.Position(4);
testingPlot.Visible = 'on';
    
for i = 1:numel(pairLabelsTest)
     
    if YPred(i) == 1
        predLabel = "similar";
    else
        predLabel = "dissimilar" ;
    end
    
    if pairLabelsTest(i) == YPred(i)
        testStr = "\bf\color{darkgreen}Correct\rm\newline";
        
    else
        testStr = "\bf\color{red}Incorrect\rm\newline";
    end
    
    subplot(2,5,i)        
    imshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);        
    
    title(testStr + "\color{black}Predicted: " + predLabel + "\newlineScore: " + YScore(i)); 
end

Сеть может сравнить тестовые изображения, чтобы определить их подобие, даже при том, что ни одно из этих изображений не было в обучающем наборе данных.

Вспомогательные Функции

Функции модели для обучения и предсказания

Функциональный forwardSiamese используется во время сетевого обучения. Функция задает как подсети и fullyconnect и sigmoid операции объединяются, чтобы сформировать полную сиамскую сеть. forwardSiamese принимает структуру сети и два учебных изображения и выводит предсказание о подобии двух изображений. В этом примере, функциональном forwardSiamese введен в разделе Define Network Architecture.

function Y = forwardSiamese(dlnet,fcParams,dlX1,dlX2)
% forwardSiamese accepts the network and pair of training images, and returns a
% prediction of the probability of the pair being similar (closer to 1) or 
% dissimilar (closer to 0). Use forwardSiamese during training.

    % Pass the first image through the twin subnetwork
    F1 = forward(dlnet,dlX1);
    F1 = sigmoid(F1);
    
    % Pass the second image through the twin subnetwork
    F2 = forward(dlnet,dlX2);
    F2 = sigmoid(F2);
    
    % Subtract the feature vectors
    Y = abs(F1 - F2);
    
    % Pass the result through a fullyconnect operation
    Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);
    
    % Convert to probability between 0 and 1.
    Y = sigmoid(Y);
end

Функциональный predictSiamese использует обучивший сеть, чтобы сделать предсказания о подобии двух изображений. Функция похожа на функциональный forwardSiamese, заданный ранее. Однако predictSiamese использует predict функция с сетью вместо forward функция, потому что некоторые слои глубокого обучения ведут себя по-другому во время обучения и предсказания. В этом примере, функциональном predictSiamese введен в разделе Evaluate the Accuracy Сети.

function Y = predictSiamese(dlnet,fcParams,dlX1,dlX2)
% predictSiamese accepts the network and pair of images, and returns a
% prediction of the probability of the pair being similar (closer to 1)
% or dissimilar (closer to 0). Use predictSiamese during prediction.

    % Pass the first image through the twin subnetwork
    F1 = predict(dlnet,dlX1);
    F1 = sigmoid(F1);
    
    % Pass the second image through the twin subnetwork
    F2 = predict(dlnet,dlX2);
    F2 = sigmoid(F2);
    
    % Subtract the feature vectors
    Y = abs(F1 - F2);
    
    % Pass result through a fullyconnect operation
    Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);
    
    % Convert to probability between 0 and 1.
    Y = sigmoid(Y);
end

Функция градиентов модели

Функциональный modelGradients берет сиамский dlnetwork объект net, пара мини-данных пакетного ввода X1 и X2, и метка, указывающая, подобны ли они или отличаются. Функция возвращает градиенты потери относительно настраиваемых параметров в сети и бинарной потере перекрестной энтропии между предсказанием и основной истиной. В этом примере, функциональном modelGradients введен в разделе Define Model Gradients Function.

function [gradientsSubnet,gradientsParams,loss] = modelGradients(dlnet,fcParams,dlX1,dlX2,pairLabels)
% The modelGradients function calculates the binary cross-entropy loss between the
% paired images and returns the loss and the gradients of the loss with respect to
% the network learnable parameters

    % Pass the image pair through the network 
    Y = forwardSiamese(dlnet,fcParams,dlX1,dlX2);
    
    % Calculate binary cross-entropy loss
    loss = binarycrossentropy(Y,pairLabels);
       
    % Calculate gradients of the loss with respect to the network learnable
    % parameters
    [gradientsSubnet,gradientsParams] = dlgradient(loss,dlnet.Learnables,fcParams);
end

function loss = binarycrossentropy(Y,pairLabels)
    % binarycrossentropy accepts the network's prediction Y, the true
    % label, and pairLabels, and returns the binary cross-entropy loss value.
    
    % Get precision of prediction to prevent errors due to floating
    % point precision    
    precision = underlyingType(Y);
      
    % Convert values less than floating point precision to eps.
    Y(Y < eps(precision)) = eps(precision);
    %convert values between 1-eps and 1 to 1-eps.
    Y(Y > 1 - eps(precision)) = 1 - eps(precision);
    
    % Calculate binary cross-entropy loss for each pair
    loss = -pairLabels.*log(Y) - (1 - pairLabels).*log(1 - Y);
    
    % Sum over all pairs in minibatch and normalize.
    loss = sum(loss)/numel(pairLabels);
end

Создайте пакеты пар изображений

Следующие функции создают рандомизированные пары изображений, которые подобны или отличаются, на основе их меток. В этом примере, функциональном getSiameseBatch введен в разделе Create Pairs of Similar и Dissimilar Images.

function [X1,X2,pairLabels] = getSiameseBatch(imds,miniBatchSize)
% getSiameseBatch returns a randomly selected batch or paired images. On
% average, this function produces a balanced set of similar and dissimilar
% pairs.

    pairLabels = zeros(1,miniBatchSize);
    imgSize = size(readimage(imds,1));
    X1 = zeros([imgSize 1 miniBatchSize]);
    X2 = zeros([imgSize 1 miniBatchSize]);

    for i = 1:miniBatchSize
        choice = rand(1);
        if choice < 0.5
            [pairIdx1,pairIdx2,pairLabels(i)] = getSimilarPair(imds.Labels);
        else
            [pairIdx1,pairIdx2,pairLabels(i)] = getDissimilarPair(imds.Labels);
        end
        X1(:,:,:,i) = imds.readimage(pairIdx1);
        X2(:,:,:,i) = imds.readimage(pairIdx2);
    end
end

function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)
% getSimilarSiamesePair returns a random pair of indices for images
% that are in the same class and the similar pair label = 1.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose a class randomly which will be used to get a similar pair.
    classChoice = randi(numel(classes));
    
    % Find the indices of all the observations from the chosen class.
    idxs = find(classLabel==classes(classChoice));
    
    % Randomly choose two different images from the chosen class.
    pairIdxChoice = randperm(numel(idxs),2);
    pairIdx1 = idxs(pairIdxChoice(1));
    pairIdx2 = idxs(pairIdxChoice(2));
    pairLabel = 1;
end

function  [pairIdx1,pairIdx2,label] = getDissimilarPair(classLabel)
% getDissimilarSiamesePair returns a random pair of indices for images
% that are in different classes and the dissimilar pair label = 0.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose two different classes randomly which will be used to get a dissimilar pair.
    classesChoice = randperm(numel(classes),2);
    
    % Find the indices of all the observations from the first and second classes.
    idxs1 = find(classLabel==classes(classesChoice(1)));
    idxs2 = find(classLabel==classes(classesChoice(2)));
    
    % Randomly choose one image from each class.
    pairIdx1Choice = randi(numel(idxs1));
    pairIdx2Choice = randi(numel(idxs2));
    pairIdx1 = idxs1(pairIdx1Choice);
    pairIdx2 = idxs2(pairIdx2Choice);
    label = 0;
end

Ссылки

[1] Бромли, J. i. Guyon, И. Лекун, Э. Зекингер и Р. Шах. "Верификация подписи с помощью "сиамской" Нейронной сети С временной задержкой". В Продолжениях 6-й Международной конференции по вопросам Нейронных Систем обработки информации (NIPS 1993), 1994, pp737-744. Доступный при Верификации Подписи с помощью "сиамской" Нейронной сети С временной задержкой на веб-сайте Продолжений NIPS.

[2] Wenpeg, Y. и H Schütze. "Сверточная нейронная сеть для Идентификации Пересказа". В Продолжениях 2 015 Конференций североамериканской Главы ACL, 2015, pp901-911. Доступный в Сверточной нейронной сети для Идентификации Пересказа на веб-сайте Антологии ACL

[3] Озеро, B. M. Салахутдинов, R. и Tenenbaum, J. B. "Концепция человеческого уровня, учащаяся посредством вероятностной индукции программы". Наука, 350 (6266), (2015) pp1332-1338.

[4] Кох, G., Zemel, R. и Салахутдинов, R. (2015). "Сиамские нейронные сети для распознавания изображений с одним выстрелом". В Продолжениях 32-й Международной конференции по вопросам Машинного обучения, 37 (2015). Доступный в сиамских Нейронных сетях для Распознавания Изображений С одним выстрелом на ICML '15 веб-сайтов.

Смотрите также

| | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте