Подписывание изображений с использованием внимания

В этом примере показано, как обучить модель глубокого обучения для подписывания изображений с использованием внимания.

Большинство предварительно обученных нейронных сетей для глубокого обучения сконфигурированы для однометковой классификации. Для примера, учитывая изображение типового офисного стола, сеть может предсказать один класс «клавиатура» или «мышь». Напротив, модель подписывания изображений комбинирует сверточные и повторяющиеся операции, чтобы получить текстовое описание того, что находится в изображении, а не одну метку.

Эта модель, обученная в этом примере, использует архитектуру энкодера-декодера. Энкодер является предварительно обученной Inception-v3 сетью, используемой в качестве извлечения функций. Декодер является рекуррентной нейронной сетью (RNN), которая принимает извлеченные функции как входные и генерирует подпись. Декодер содержит механизм внимания, который позволяет декодеру фокусироваться на частях кодированного входа при генерации подписи.

Модель энкодера является предварительно обученной Inception-v3 моделью, которая извлекает функции из "mixed10" слой с последующим полным подключением и операциями ReLU.

Модель декодера состоит из встраивания слова, механизма внимания, стробируемого рекуррентного модуля (GRU) и двух полностью связанных операций.

Загрузка предварительно обученной сети

Загрузка предварительно обученной Inception-v3 сети. Для этого шага требуется пакет Deep Learning Toolbox™ Model for Inception-v3 Network поддержки. Если у вас нет установленного необходимого пакета поддержки, то программное обеспечение предоставляет ссылку на загрузку.

net = inceptionv3;
inputSizeNet = net.Layers(1).InputSize;

Преобразуйте сеть в dlnetwork объект для редукции данных и удаления последних четырех слоев, оставив "mixed10" слой как последний слой.

lgraph = layerGraph(net);
lgraph = removeLayers(lgraph,["avg_pool" "predictions" "predictions_softmax" "ClassificationLayer_predictions"]);

Просмотрите вход слой сети. В Inception-v3 сети используется нормализация симметричного преобразования с минимальным значением 0 и максимальным значением 255.

lgraph.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'input_1'
                 InputSize: [299 299 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'rescale-symmetric'
    NormalizationDimension: 'auto'
                       Max: 255
                       Min: 0

Пользовательское обучение не поддерживает эту нормализацию, поэтому необходимо отключить нормализацию в сети и выполнить нормализацию в пользовательском цикле обучения. Сохраните минимальное и максимальное значения как двойные в переменных с именем inputMin и inputMax, соответственно, и заменить слой входа на слой входа изображений без нормализации.

inputMin = double(lgraph.Layers(1).Min);
inputMax = double(lgraph.Layers(1).Max);
layer = imageInputLayer(inputSizeNet,"Normalization","none",'Name','input');
lgraph = replaceLayer(lgraph,'input_1',layer);

Определите выходной размер сети. Используйте analyzeNetwork функция для просмотра размеров активации последнего слоя. Анализатор нейронной сети для глубокого обучения показывает некоторые проблемы с сетью, которые можно безопасно игнорировать для пользовательских рабочих процессов обучения.

analyzeNetwork(lgraph)

Создайте переменную с именем outputSizeNet содержит размер выходного сигнала сети.

outputSizeNet = [8 8 2048];

Преобразуйте график слоев в dlnetwork объект и просмотр слоя выхода. Слой выхода является "mixed10" слой Inception-v3 сети.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [311×1 nnet.cnn.layer.Layer]
    Connections: [345×2 table]
     Learnables: [376×3 table]
          State: [188×3 table]
     InputNames: {'input'}
    OutputNames: {'mixed10'}

Импорт набора данных COCO

Загрузите изображения и аннотации из наборов данных «2014 Train image» и «2014 Train/val annotations», соответственно, из https://cocodataset.org/#download. Извлеките изображения и аннотации в папку с именем "coco". Набор данных COCO 2014 был собран Coco Consortium.

Извлеките подписи из файла "captions_train2014.json" использование jsondecode функция.

dataFolder = fullfile(tempdir,"coco");
filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json");
str = fileread(filename);
data = jsondecode(str)
data = struct with fields:
           info: [1×1 struct]
         images: [82783×1 struct]
       licenses: [8×1 struct]
    annotations: [414113×1 struct]

The annotations поле struct содержит данные, необходимые для подписывания изображений.

data.annotations
ans=414113×1 struct array with fields:
    image_id
    id
    caption

Набор данных содержит несколько подписей для каждого изображения. Чтобы гарантировать, что одни и те же изображения не появятся как в наборах обучения, так и в наборах валидации, идентифицируйте уникальные изображения в наборе данных с помощью unique функция при помощи идентификаторов в image_id поле поля аннотаций данных, затем просмотрите количество уникальных изображений.

numObservationsAll = numel(data.annotations)
numObservationsAll = 414113
imageIDs = [data.annotations.image_id];
imageIDsUnique = unique(imageIDs);
numUniqueImages = numel(imageIDsUnique)
numUniqueImages = 82783

Каждое изображение имеет по меньшей мере пять подписей. Создайте struct annotationsAll с этими полями:

  • ImageID ⁠ - Идентификатор изображения

  • Filename ⁠ - Имя файла изображения

  • Captions ⁠ - Строковые массивы необработанных подписей

  • CaptionIDs ⁠ - Вектор индексов соответствующих подписей в data.annotations

Чтобы упростить слияние, отсортируйте аннотации по идентификаторам изображений.

[~,idx] = sort([data.annotations.image_id]);
data.annotations = data.annotations(idx);

Наведите циклы на аннотации и при необходимости объедините несколько аннотаций.

i = 0;
j = 0;
imageIDPrev = 0;
while i < numel(data.annotations)
    i = i + 1;
    
    imageID = data.annotations(i).image_id;
    caption = string(data.annotations(i).caption);
    
    if imageID ~= imageIDPrev
        % Create new entry
        j = j + 1;
        annotationsAll(j).ImageID = imageID;
        annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,'left','0') + ".jpg");
        annotationsAll(j).Captions = caption;
        annotationsAll(j).CaptionIDs = i;
    else
        % Append captions
        annotationsAll(j).Captions = [annotationsAll(j).Captions; caption];
        annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i];
    end
    
    imageIDPrev = imageID;
end

Разделите данные на наборы для обучения и валидации. Продержитесь 5% наблюдений для проверки.

cvp = cvpartition(numel(annotationsAll),'HoldOut',0.05);
idxTrain = training(cvp);
idxTest = test(cvp);
annotationsTrain = annotationsAll(idxTrain);
annotationsTest = annotationsAll(idxTest);

Этот struct содержит три поля:

  • id - Уникальный идентификатор подписи

  • caption - Image caption Заданный в виде символьного вектора

  • image_id - Уникальный идентификатор изображения, соответствующего подписи

Чтобы просмотреть изображение и соответствующий заголовок, найдите файл изображения с именем файла "train2014\COCO_train2014_XXXXXXXXXXXX.jpg", где "XXXXXXXXXXXX" соответствует идентификатору изображения, заполненному нулями, чтобы иметь длину 12.

imageID = annotationsTrain(1).ImageID;
captions = annotationsTrain(1).Captions;
filename = annotationsTrain(1).Filename;

Чтобы просмотреть изображение, используйте imread и imshow функций.

img = imread(filename);
figure
imshow(img)
title(captions)

Подготовка данных к обучению

Подготовьте подписи для обучения и проверки. Извлеките текст из Captions поле struct, содержащее как обучающие, так и тестовые данные (annotationsAll), стереть пунктуацию и преобразовать текст в строчный.

captionsAll = cat(1,annotationsAll.Captions);
captionsAll = erasePunctuation(captionsAll);
captionsAll = lower(captionsAll);

В порядок для генерации подписей декодер RNN требует специальных лексем запуска и остановки, чтобы указать, когда начать и остановить генерацию текста, соответственно. Добавьте пользовательские лексемы "<start>" и "<stop>" к началам и концам подписей, соответственно.

captionsAll = "<start>" + captionsAll + "<stop>";

Токенизация подписей с помощью tokenizedDocument и задайте стартовые и стоповые лексемы используя 'CustomTokens' опция.

documentsAll = tokenizedDocument(captionsAll,'CustomTokens',["<start>" "<stop>"]);

Создайте wordEncoding объект, который преобразует слова в числовые индексы и назад. Уменьшите требования к памяти путем определения размера словаря 5000, соответствующего наиболее часто наблюдаемым словам в обучающих данных. Чтобы избежать смещения, используйте только документы, соответствующие наборы обучающих данных.

enc = wordEncoding(documentsAll(idxTrain),'MaxNumWords',5000,'Order','frequency');

Создайте хранилище данных дополненных изображений, содержащее изображения, соответствующие подписям. Установите размер выходного сигнала, так чтобы он совпадал с размером входа сверточной сети. Чтобы сохранить изображения синхронизированными с подписями, задайте таблицу имен файлов для datastore путем восстановления имен файлов с помощью идентификатора изображения. Чтобы вернуть полутоновые изображения как 3-канальные изображения RGB, установите 'ColorPreprocessing' опция для 'gray2rgb'.

tblFilenames = table(cat(1,annotationsTrain.Filename));
augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,'ColorPreprocessing','gray2rgb')
augimdsTrain = 
  augmentedImageDatastore with properties:

         NumObservations: 78644
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Инициализация параметров модели

Инициализируйте параметры модели. Задайте 512 скрытые модули с размерностью встраивания слов 256.

embeddingDimension = 256;
numHiddenUnits = 512;

Инициализируйте struct, содержащую параметры для модели энкодера.

  • Инициализируйте веса полносвязных операций с помощью инициализатора Glorot, заданного initializeGlorot функции, перечисленной в конце примера. Задайте размер выхода, чтобы соответствовать размерности встраивания декодера (256) и размеру входа, чтобы соответствовать количеству выхода каналов предварительно обученной сети. The 'mixed10' слой Inception-v3 сети выводит данные с 2048 каналами.

numFeatures = outputSizeNet(1) * outputSizeNet(2);
inputSizeEncoder = outputSizeNet(3);
parametersEncoder = struct;

% Fully connect
parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder));
parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],'single'));

Инициализируйте struct, содержащую параметры для модели декодера.

  • Инициализируйте веса встраивания слов с размером, заданным размерностью встраивания и размером словаря плюс один, где дополнительная запись соответствует значению дополнения.

  • Инициализируйте веса и смещения для механизма внимания Бахданау с размерами, соответствующими количеству скрытых модулей операции ГРУ.

  • Инициализируйте веса и смещение операции GRU.

  • Инициализируйте веса и смещения двух полносвязных операций.

Для параметров декодера модели инициализируйте каждое из взвешиваний и смещений с помощью инициализатора Glorot и нули, соответственно.

inputSizeDecoder = enc.NumWords + 1;
parametersDecoder = struct;

% Word embedding
parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder));

% Attention
parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension));
parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits));
parametersDecoder.attention.BiasV = dlarray(zeros(1,1,'single'));

% GRU
parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension));
parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

% Fully connect
parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],'single'));

% Fully connect
parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits));
parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],'single'));

Задайте функции модели

Создайте функции modelEncoder и modelDecoder, перечисленных в конце примера, которые вычисляют выходы моделей энкодера и декодера, соответственно.

The modelEncoder функция, перечисленная в разделе Модель Function примера, принимает за вход массив активаций dlX с выхода предварительно обученной сети и пропускает ее через полностью подключенную операцию и операцию ReLU. Поскольку предварительно обученная сеть не должна быть прослежена для автоматической дифференциации, извлечение признаков вне функции модели энкодера является более в вычислительном отношении эффективным.

The modelDecoder функция, перечисленная в разделе Decoder Model Function примера, принимает за вход один входной временной шаг, соответствующий входному слову, параметрам модели декодера, функциям от энкодера и состоянию сети, и возвращает предсказания для следующего временного шага, обновленное состояние сети и веса внимания.

Настройка опций обучения

Укажите опции обучения. Train на 30 эпох с мини-пакетом размером 128 и отображением процесса обучения на участке.

miniBatchSize = 128;
numEpochs = 30;
plots = "training-progress";

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Обучите сеть

Обучите сеть с помощью пользовательского цикла обучения.

В начале каждой эпохи тасуйте входные данные. Чтобы сохранить изображения в дополненном datastore изображений и подписях синхронизированными, создайте массив перетасованных индексов, который индексируется в оба набора данных.

Для каждого мини-пакета:

  • Измените размер изображений, ожидаемый предварительно обученной сетью.

  • Для каждого изображения выберите случайную подпись.

  • Преобразуйте заголовки в последовательности индексов слов. Задайте правое заполнение последовательностей со значением заполнения, соответствующим индексу лексемы заполнения.

  • Преобразуйте данные в dlarray объекты. Для изображений задайте метки размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Для обучения графический процессор преобразуйте данные в gpuArray объекты.

  • Извлеките функции изображений с помощью предварительно обученной сети и измените их на размер, ожидаемый энкодером.

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функций.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Отображение процесса обучения на графике.

Инициализируйте параметры для оптимизатора Адама.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];

trailingAvgDecoder = [];
trailingAvgSqDecoder = [];

Инициализируйте график процесса обучения. Создайте анимированную линию, которая строит график потерь относительно соответствующей итерации.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    xlabel("Iteration")
    ylabel("Loss")
    ylim([0 inf])
    grid on
end

Обучите модель.

iteration = 0;
numObservationsTrain = numel(annotationsTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    idxShuffle = randperm(numObservationsTrain);
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Determine mini-batch indices.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        idxMiniBatch = idxShuffle(idx);
        
        % Read mini-batch of data.
        tbl = readByIndex(augimdsTrain,idxMiniBatch);
        X = cat(4,tbl.input{:});
        annotations = annotationsTrain(idxMiniBatch);
        
        % For each image, select random caption.
        idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs});
        documents = documentsAll(idx);
        
        % Create batch of data.
        [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment);
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlX, dlT);
        
        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration);
        
        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            drawnow
        end
    end
end

Предсказание новых подписей

Процесс генерации заголовка отличается от процесса обучения. Во время обучения на каждом временном шаге декодер использует в качестве входных данных истинное значение предыдущего временного шага. Это известно как «принуждение учителя». При выполнении предсказаний на новых данных декодер использует предыдущие предсказанные значения вместо истинных значений.

Предсказание наиболее вероятного слова для каждого шага в последовательности может привести к неоптимальным результатам. Например, если декодер предсказывает первое слово подписи «а» при выдаче изображения слона, то вероятность предсказания «слона» для следующего слова становится гораздо более маловероятной из-за крайне низкой вероятности появления фразы «слон» в английском тексте.

Чтобы решить эту проблему, можно использовать алгоритм поиска луча: вместо того, чтобы делать наиболее вероятное предсказание для каждого шага в последовательности, берите верхние k предсказаний (индекс луча) и для каждого следующего шага, сохраняйте верхние k предсказанных последовательностей до сих пор в соответствии с общим счетом.

Сгенерируйте заголовок нового изображения путем извлечения функций изображения, ввода их в энкодер, а затем с помощью beamSearch функция, перечисленная в разделе «Функция поиска луча» примера.

img = imread("laika_sitting.jpg");
dlX = extractImageFeatures(dlnet,img,inputMin,inputMax,executionEnvironment);

beamIndex = 3;
maxNumWords = 20;
[words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
caption = join(words)
caption = 
"a dog is standing on a tile floor"

Отобразите изображение с подписью.

figure
imshow(img)
title(caption)

Предсказание подписей для набора данных

Чтобы предсказать подписи для набора изображений, закольцовывайте мини-пакеты данных в datastore и извлеките функции из изображений с помощью extractImageFeatures функция. Затем закольцовывайте изображения в мини-пакете и сгенерируйте подписи, используя beamSearch функция.

Создайте хранилище datastore дополненного изображения и установите размер выхода, соответствующий размеру входа сверточной сети. Чтобы выводить полутоновые изображения как 3-канальные изображения RGB, установите 'ColorPreprocessing' опция для 'gray2rgb'.

tblFilenamesTest = table(cat(1,annotationsTest.Filename));
augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,'ColorPreprocessing','gray2rgb')
augimdsTest = 
  augmentedImageDatastore with properties:

         NumObservations: 4139
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Сгенерируйте подписи для тестовых данных. Предсказание подписей на большом наборе данных может занять некоторое время. Если у вас есть Parallel Computing Toolbox™, то можно делать предсказания параллельно, генерируя подписи внутри parfor смотри. Если у вас нет Parallel Computing Toolbox. затем parfor цикл запускается последовательно.

beamIndex = 2;
maxNumWords = 20;

numObservationsTest = numel(annotationsTest);
numIterationsTest = ceil(numObservationsTest/miniBatchSize);

captionsTestPred = strings(1,numObservationsTest);
documentsTestPred = tokenizedDocument(strings(1,numObservationsTest));

for i = 1:numIterationsTest
    % Mini-batch indices.
    idxStart = (i-1)*miniBatchSize+1;
    idxEnd = min(i*miniBatchSize,numObservationsTest);
    idx = idxStart:idxEnd;
    
    sz = numel(idx);
    
    % Read images.
    tbl = readByIndex(augimdsTest,idx);
    
    % Extract image features.
    X = cat(4,tbl.input{:});
    dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);
    
    % Generate captions.
    captionsPredMiniBatch = strings(1,sz);
    documentsPredMiniBatch = tokenizedDocument(strings(1,sz));
    
    parfor j = 1:sz
        words = beamSearch(dlX(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
        captionsPredMiniBatch(j) = join(words);
        documentsPredMiniBatch(j) = tokenizedDocument(words,'TokenizeMethod','none');
    end
    
    captionsTestPred(idx) = captionsPredMiniBatch;
    documentsTestPred(idx) = documentsPredMiniBatch;
end
Analyzing and transferring files to the workers ...done.

Чтобы просмотреть тестовое изображение с соответствующей подписью, используйте imshow и установите заголовок предсказанной подписи.

idx = 1;
tbl = readByIndex(augimdsTest,idx);
img = tbl.input{1};
figure
imshow(img)
title(captionsTestPred(idx))

Оцените точность модели

Чтобы оценить точность подписей с помощью счета BLEU, вычислите счет BLEU для каждой подписи (кандидата) против соответствующих подписей в тестовом наборе (ссылках), используя bleuEvaluationScore функция. Использование bleuEvaluationScore можно сравнить один документ-кандидат с несколькими ссылочными документами.

The bleuEvaluationScore функция по умолчанию оценивает подобие, используя n-граммы длины от одного до четырех. Поскольку подписи коротки, это поведение может привести к неинформативным результатам, так как большинство счетов близки к нулю. Установите длину n-грамма от одного до двух путем установки 'NgramWeights' опция для двухэлементного вектора с равными весами.

ngramWeights = [0.5 0.5];

for i = 1:numObservationsTest
    annotation = annotationsTest(i);
    
    captionIDs = annotation.CaptionIDs;
    candidate = documentsTestPred(i);
    references = documentsAll(captionIDs);
    
    score = bleuEvaluationScore(candidate,references,'NgramWeights',ngramWeights);
    
    scores(i) = score;
end

Просмотрите средний счет BLEU.

scoreMean = mean(scores)
scoreMean = 0.4224

Визуализируйте счета в гистограмме.

figure
histogram(scores)
xlabel("BLEU Score")
ylabel("Frequency")

Функция внимания

The attention функция вычисляет вектор контекста и веса внимания, используя внимание Бахданау.

function [contextVector, attentionWeights] = attention(hidden,features,weights1, ...
    bias1,weights2,bias2,weightsV,biasV)

% Model dimensions.
[embeddingDimension,numFeatures,miniBatchSize] = size(features);
numHiddenUnits = size(weights1,1);

% Fully connect.
dlY1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize);
dlY1 = fullyconnect(dlY1,weights1,bias1,'DataFormat','CB');
dlY1 = reshape(dlY1,numHiddenUnits,numFeatures,miniBatchSize);

% Fully connect.
dlY2 = fullyconnect(hidden,weights2,bias2,'DataFormat','CB');
dlY2 = reshape(dlY2,numHiddenUnits,1,miniBatchSize);

% Addition, tanh.
scores = tanh(dlY1 + dlY2);
scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize);

% Fully connect, softmax.
attentionWeights = fullyconnect(scores,weightsV,biasV,'DataFormat','CB');
attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize);
attentionWeights = softmax(attentionWeights,'DataFormat','SCB');

% Context.
contextVector = attentionWeights .* features;
contextVector = squeeze(sum(contextVector,2));

end

Функция встраивания

The embedding функция преобразует массив индексов в последовательность векторов встраивания.

function Z = embedding(X, weights)

% Reshape inputs into a vector
[N, T] = size(X, 1:2);
X = reshape(X, N*T, 1);

% Index into embedding matrix
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions
Z = reshape(Z, [], N, T);

end

Функция редукции данных

The extractImageFeatures функция принимает как вход обученную dlnetwork объект, входное изображение, статистика для сканирования изображений и среда выполнения и возвращает dlarray содержащие функции, извлеченный из предварительно обученной сети.

function dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment)

% Resize and rescale.
inputSize = dlnet.Layers(1).InputSize(1:2);
X = imresize(X,inputSize);
X = rescale(X,-1,1,'InputMin',inputMin,'InputMax',inputMax);

% Convert to dlarray.
dlX = dlarray(X,'SSCB');

% Convert to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

% Extract features and reshape.
dlX = predict(dlnet,dlX);
sz = size(dlX);
numFeatures = sz(1) * sz(2);
inputSizeEncoder = sz(3);
miniBatchSize = sz(4);
dlX = reshape(dlX,[numFeatures inputSizeEncoder miniBatchSize]);

end

Функция создания пакетов

The createBatch функция принимает за вход мини-пакет данных, токенизированные подписи, предварительно обученную сеть, статистику для перемасштабирования изображений, кодирование слов и окружение выполнения и возвращает мини-пакет данных, соответствующих извлеченным функциям изображения и подписям для обучения.

function [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment)

dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);

% Convert documents to sequences of word indices.
T = doc2sequence(enc,documents,'PaddingDirection','right','PaddingValue',enc.NumWords+1);
T = cat(1,T{:});

% Convert mini-batch of data to dlarray.
dlT = dlarray(T);

% If training on a GPU, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlT = gpuArray(dlT);
end

end

Функция модели энкодера

The modelEncoder функция принимает за вход массив активаций dlX и передает его через полносвязную операцию и операцию ReLU. Для полностью подключенной операции работайте только по размерности канала. Чтобы применить операцию, полностью связанную только с размерностью канала, выровняйте другие каналы в одну размерность и задайте эту размерность как размерность пакета, используя 'DataFormat' опция fullyconnect функция.

function dlY = modelEncoder(dlX,parametersEncoder)

[numFeatures,inputSizeEncoder,miniBatchSize] = size(dlX);

% Fully connect
weights = parametersEncoder.fc.Weights;
bias = parametersEncoder.fc.Bias;
embeddingDimension = size(weights,1);

dlX = permute(dlX,[2 1 3]);
dlX = reshape(dlX,inputSizeEncoder,numFeatures*miniBatchSize);
dlY = fullyconnect(dlX,weights,bias,'DataFormat','CB');
dlY = reshape(dlY,embeddingDimension,numFeatures,miniBatchSize);

% ReLU
dlY = relu(dlY);

end

Функция модели декодера

The modelDecoder функция принимает как вход один временной шаг dlX, параметры модели декодера, функции от энкодера и состояние сети и возвращает предсказания для следующего временного шага, обновленное состояние сети и веса внимания.

function [dlY,state,attentionWeights] = modelDecoder(dlX,parametersDecoder,features,state)

hiddenState = state.gru.HiddenState;

% Attention
weights1 = parametersDecoder.attention.Weights1;
bias1 = parametersDecoder.attention.Bias1;
weights2 = parametersDecoder.attention.Weights2;
bias2 = parametersDecoder.attention.Bias2;
weightsV = parametersDecoder.attention.WeightsV;
biasV = parametersDecoder.attention.BiasV;
[contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV);

% Embedding
weights = parametersDecoder.emb.Weights;
dlX = embedding(dlX,weights);

% Concatenate
dlY = cat(1,contextVector,dlX);

% GRU
inputWeights = parametersDecoder.gru.InputWeights;
recurrentWeights = parametersDecoder.gru.RecurrentWeights;
bias = parametersDecoder.gru.Bias;
[dlY, hiddenState] = gru(dlY, hiddenState, inputWeights, recurrentWeights, bias, 'DataFormat','CBT');

% Update state
state.gru.HiddenState = hiddenState;

% Fully connect
weights = parametersDecoder.fc1.Weights;
bias = parametersDecoder.fc1.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Fully connect
weights = parametersDecoder.fc2.Weights;
bias = parametersDecoder.fc2.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

end

Моделирование градиентов

The modelGradients функция принимает за вход параметры энкодера и декодера, функции энкодера dlX, и целевой заголовок dlTи возвращает градиенты параметров энкодера и декодера относительно потерь, потерь и предсказаний.

function [gradientsEncoder,gradientsDecoder,loss,dlYPred] = ...
    modelGradients(parametersEncoder,parametersDecoder,dlX,dlT)

miniBatchSize = size(dlX,3);
sequenceLength = size(dlT,2) - 1;
vocabSize = size(parametersDecoder.emb.Weights,2);

% Model encoder
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],'single'));

dlYPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],'like',dlX));
loss = dlarray(single(0));

padToken = vocabSize;

for t = 1:sequenceLength
    decoderInput = dlT(:,t);
    
    dlYReal = dlT(:,t+1);
    
    [dlYPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state);
    
    mask = dlYReal ~= padToken;
    
    loss = loss + sparseCrossEntropyAndSoftmax(dlYPred(:,:,t),dlYReal,mask);
end

% Calculate gradients
[gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder);

end

Разреженная перекрестная энтропия и функция Softmax Loss

The sparseCrossEntropyAndSoftmax принимает за вход предсказания dlY, соответствующие цели dlT, и последовательность заполнения маски, и применяет softmax выполняет функции и возвращает потери перекрестной энтропии.

function loss = sparseCrossEntropyAndSoftmax(dlY, dlT, mask)

miniBatchSize = size(dlY, 2);

% Softmax.
dlY = softmax(dlY,'DataFormat','CB');

% Find rows corresponding to the target words.
idx = sub2ind(size(dlY), dlT', 1:miniBatchSize);
dlY = dlY(idx);

% Bound away from zero.
dlY = max(dlY, single(1e-8));

% Masked loss.
loss = log(dlY) .* mask';
loss = -sum(loss,'all') ./ miniBatchSize;

end

Функция поиска луча

The beamSearch функция принимает как вход изображение функций dlX, индекс луча, параметры для сетей энкодера и декодера, кодирование слова и максимальная длина последовательности, и возвращает слова заголовка для изображения, используя алгоритм поиска луча.

function [words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder, ...
    enc,maxNumWords)

% Model dimensions
numFeatures = size(dlX,1);
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);

% Extract features
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],'like',dlX));

% Initialize candidates
candidates = struct;
candidates.State = state;
candidates.Words = "<start>";
candidates.Score = 0;
candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],'like',dlX));
candidates.StopFlag = false;

t = 0;

% Loop over words
while t < maxNumWords
    t = t + 1;
    
    candidatesNew = [];
    
    % Loop over candidates
    for i = 1:numel(candidates)
        
        % Stop generating when stop token is predicted
        if candidates(i).StopFlag
            continue
        end
        
        % Candidate details
        state = candidates(i).State;
        words = candidates(i).Words;
        score = candidates(i).Score;
        attentionScores = candidates(i).AttentionScores;
        
        % Predict next token
        decoderInput = word2ind(enc,words(end));
        [dlYPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state);
        
        dlYPred = softmax(dlYPred,'DataFormat','CB');
        [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex);
        idxTop = gather(idxTop);
        
        % Loop over top predictions
        for j = 1:beamIndex
            candidate = struct;
            
            candidateWord = ind2word(enc,idxTop(j));
            candidateScore = scoresTop(j);
            
            if candidateWord == "<stop>"
                candidate.StopFlag = true;
                attentionScores(:,t+1:end) = [];
            else
                candidate.StopFlag = false;
            end
            
            candidate.State = state;
            candidate.Words = [words candidateWord];
            candidate.Score = score + log(candidateScore);
            candidate.AttentionScores = attentionScores;
            
            candidatesNew = [candidatesNew candidate];
        end
    end
    
    % Get top candidates
    [~,idx] = maxk([candidatesNew.Score],beamIndex);
    candidates = candidatesNew(idx);
    
    % Stop predicting when all candidates have stop token
    if all([candidates.StopFlag])
        break
    end
end

% Get top candidate
words = candidates(1).Words(2:end-1);
attentionScores = candidates(1).AttentionScores;

end

Функция инициализации веса Глорота

The initializeGlorot функция генерирует массив весов согласно инициализации Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

См. также

| | | | | | | | | (Symbolic Math Toolbox) | (Symbolic Math Toolbox) | (Symbolic Math Toolbox) | (Symbolic Math Toolbox)

Похожие темы