Отобразите ввод субтитров Используя внимание

В этом примере показано, как обучить модель глубокого обучения вводу субтитров изображений с помощью внимания.

Большинство предварительно обученных нейронных сетей для глубокого обучения сконфигурировано для классификации одно меток. Например, учитывая изображение типичного офисного стола, сетевая сила предсказывает единый класс "клавиатура" или "мышь". В отличие от этого модель ввода субтитров изображений комбинирует сверточные и текущие операции, чтобы произвести текстовое описание того, что находится в изображении, а не одной метке.

Эта модель, обученная в этом примере, использует архитектуру декодера энкодера. Энкодер является предварительно обученной сетью Inception-v3, используемой в качестве экстрактора функции. Декодер является рекуррентной нейронной сетью (RNN), которая берет извлеченные функции, как введено и генерирует заголовок. Декодер включает механизм внимания, который позволяет декодеру фокусироваться на частях закодированного входа при генерации заголовка.

Модель энкодера является предварительно обученной моделью Inception-v3, которая извлекает функции из "mixed10" слой, сопровождаемый полностью связанным и операции ReLU.

Модель декодера состоит из встраивания слова, механизма внимания, закрытого текущего модуля (GRU) и двух полностью связанных операций.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть Inception-v3. Этот шаг требует Модели Deep Learning Toolbox™ для пакета Сетевой поддержки Inception-v3. Если вам не установили необходимый пакет поддержки, то программное обеспечение обеспечивает ссылку на загрузку.

net = inceptionv3;
inputSizeNet = net.Layers(1).InputSize;

Преобразуйте сеть в dlnetwork объект для извлечения признаков и удаляет последние четыре слоя, оставляя "mixed10" слой как последний слой.

lgraph = layerGraph(net);
lgraph = removeLayers(lgraph,["avg_pool" "predictions" "predictions_softmax" "ClassificationLayer_predictions"]);

Просмотрите входной слой сети. Симметричное использование сети Inception-v3 - перемасштабирует нормализацию с минимальным значением 0 и максимальным значением 255.

lgraph.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'input_1'
                 InputSize: [299 299 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'rescale-symmetric'
    NormalizationDimension: 'auto'
                       Max: 255
                       Min: 0

Пользовательское обучение не поддерживает эту нормализацию, таким образом, необходимо отключить нормализацию в сети и выполнить нормализацию в пользовательском учебном цикле вместо этого. Сохраните минимальные и максимальные значения, как удваивается в переменных под названием inputMin и inputMax, соответственно, и замена входной слой с изображением ввела слой без нормализации.

inputMin = double(lgraph.Layers(1).Min);
inputMax = double(lgraph.Layers(1).Max);
layer = imageInputLayer(inputSizeNet,'Normalization','none','Name','input');
lgraph = replaceLayer(lgraph,'input_1',layer);

Определите выходной размер сети. Используйте analyzeNetwork функция, чтобы видеть размеры активации последнего слоя. Чтобы анализировать сеть для пользовательских учебных рабочих процессов цикла, установите TargetUsage опция к 'dlnetwork'.

analyzeNetwork(lgraph,'TargetUsage','dlnetwork')

Создайте переменную под названием outputSizeNet содержа сетевой выходной размер.

outputSizeNet = [8 8 2048];

Преобразуйте график слоев в dlnetwork возразите и просмотрите выходной слой. Выходным слоем является "mixed10" слой сети Inception-v3.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [311×1 nnet.cnn.layer.Layer]
    Connections: [345×2 table]
     Learnables: [376×3 table]
          State: [188×3 table]
     InputNames: {'input'}
    OutputNames: {'mixed10'}

Импортируйте набор данных COCO

Образы загрузки и аннотации от наборов данных "2014 Обучают изображения", и "2014 Обучают/val аннотации", соответственно, под эгидой https://cocodataset.org/#download. Извлеките изображения и аннотации в папку под названием "coco". Набор данных COCO 2014 был собран Кокосовым Консорциумом.

Извлеките заголовки из файла "captions_train2014.json" использование jsondecode функция.

dataFolder = fullfile(tempdir,"coco");
filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json");
str = fileread(filename);
data = jsondecode(str)
data = struct with fields:
           info: [1×1 struct]
         images: [82783×1 struct]
       licenses: [8×1 struct]
    annotations: [414113×1 struct]

annotations поле struct содержит данные, требуемые для ввода субтитров изображений.

data.annotations
ans=414113×1 struct array with fields:
    image_id
    id
    caption

Набор данных содержит несколько заголовков для каждого изображения. Чтобы гарантировать те же изображения не появляются и в наборах обучения и в валидации, идентифицируют уникальные изображения в наборе данных с помощью unique функция при помощи идентификаторов в image_id поле поля аннотаций данных, затем просмотрите количество уникальных изображений.

numObservationsAll = numel(data.annotations)
numObservationsAll = 414113
imageIDs = [data.annotations.image_id];
imageIDsUnique = unique(imageIDs);
numUniqueImages = numel(imageIDsUnique)
numUniqueImages = 82783

Каждое изображение имеет по крайней мере пять заголовков. Создайте struct annotationsAll с этими полями:

  • ImageID ⁠ — Отображают ID

  • Filename ⁠ — Имя файла изображения

  • Captions ⁠ — Массив строк необработанных заголовков

  • CaptionIDs ⁠ — Вектор из индексов соответствующих заголовков в data.annotations

Чтобы сделать слияние легче, отсортируйте аннотации по идентификаторам изображений.

[~,idx] = sort([data.annotations.image_id]);
data.annotations = data.annotations(idx);

Цикл по аннотациям и слиянию несколько аннотаций при необходимости.

i = 0;
j = 0;
imageIDPrev = 0;
while i < numel(data.annotations)
    i = i + 1;
    
    imageID = data.annotations(i).image_id;
    caption = string(data.annotations(i).caption);
    
    if imageID ~= imageIDPrev
        % Create new entry
        j = j + 1;
        annotationsAll(j).ImageID = imageID;
        annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,'left','0') + ".jpg");
        annotationsAll(j).Captions = caption;
        annotationsAll(j).CaptionIDs = i;
    else
        % Append captions
        annotationsAll(j).Captions = [annotationsAll(j).Captions; caption];
        annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i];
    end
    
    imageIDPrev = imageID;
end

Разделите данные в наборы обучения и валидации. Протяните 5% наблюдений для тестирования.

cvp = cvpartition(numel(annotationsAll),'HoldOut',0.05);
idxTrain = training(cvp);
idxTest = test(cvp);
annotationsTrain = annotationsAll(idxTrain);
annotationsTest = annotationsAll(idxTest);

Struct содержит три поля:

  • id — Уникальный идентификатор для заголовка

  • caption — Отобразите заголовок в виде вектора символов

  • image_id — Уникальный идентификатор изображения, соответствующего заголовку

Чтобы просмотреть изображение и соответствующий заголовок, найдите файл изображения с именем файла "train2014\COCO_train2014_XXXXXXXXXXXX.jpg", где "XXXXXXXXXXXX" соответствует ID изображений, лево-дополненному нулями, чтобы иметь длину 12.

imageID = annotationsTrain(1).ImageID;
captions = annotationsTrain(1).Captions;
filename = annotationsTrain(1).Filename;

Чтобы просмотреть изображение, используйте imread и imshow функции.

img = imread(filename);
figure
imshow(img)
title(captions)

Подготовка данных для обучения

Подготовьте заголовки к обучению и тестированию. Извлеките текст из Captions поле struct, содержащего и обучение и тестовые данные (annotationsAll), сотрите пунктуацию и преобразуйте текст в нижний регистр.

captionsAll = cat(1,annotationsAll.Captions);
captionsAll = erasePunctuation(captionsAll);
captionsAll = lower(captionsAll);

Для того, чтобы сгенерировать заголовки, декодер RNN требует, чтобы специальный запуск и лексемы остановки указали, когда запустить и прекратить генерировать текст, соответственно. Добавьте пользовательские лексемы "<start>" и "<stop>" к началу и концам заголовков, соответственно.

captionsAll = "<start>" + captionsAll + "<stop>";

Маркируйте заголовки с помощью tokenizedDocument функционируйте и задайте запуск и лексемы остановки с помощью 'CustomTokens' опция.

documentsAll = tokenizedDocument(captionsAll,'CustomTokens',["<start>" "<stop>"]);

Создайте wordEncoding возразите что слова карт против числовых индексов и назад. Уменьшайте требования к памяти путем определения размера словаря 5 000 соответствий наиболее часто наблюдаемым словам в обучающих данных. Чтобы избежать смещения, используйте только документы, соответствующие набору обучающих данных.

enc = wordEncoding(documentsAll(idxTrain),'MaxNumWords',5000,'Order','frequency');

Создайте увеличенный datastore изображений, содержащий изображения, соответствующие заголовкам. Установите выходной размер совпадать с входным размером сверточной сети. Чтобы сохранить изображения синхронизируемыми с заголовками, задайте таблицу имен файлов для datastore путем восстановления имен файлов с помощью ID изображений. Чтобы возвратить полутоновые изображения как RGB, с 3 каналами отображает, установите 'ColorPreprocessing' опция к 'gray2rgb'.

tblFilenames = table(cat(1,annotationsTrain.Filename));
augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,'ColorPreprocessing','gray2rgb')
augimdsTrain = 
  augmentedImageDatastore with properties:

         NumObservations: 78644
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Инициализируйте параметры модели

Инициализируйте параметры модели. Задайте 512 скрытых модулей с размерностью встраивания слова 256.

embeddingDimension = 256;
numHiddenUnits = 512;

Инициализируйте struct, содержащий параметры для модели энкодера.

  • Инициализируйте веса полностью связанных операций с помощью инициализатора Glorot, заданного initializeGlorot функция, перечисленная в конце примера. Задайте выходной размер, чтобы совпадать с размерностью встраивания декодера (256) и входной размер, чтобы совпадать с количеством выходных каналов предварительно обученной сети. 'mixed10' слой сети Inception-v3 выходные данные с 2 048 каналами.

numFeatures = outputSizeNet(1) * outputSizeNet(2);
inputSizeEncoder = outputSizeNet(3);
parametersEncoder = struct;

% Fully connect
parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder));
parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],'single'));

Инициализируйте struct, содержащий параметры для модели декодера.

  • Инициализируйте веса встраивания слова размером, данным размерностью встраивания и размером словаря плюс один, где дополнительная запись соответствует дополнительному значению.

  • Инициализируйте веса и смещения для механизма внимания Bahdanau с размерами, соответствующими количеству скрытых модулей операции ГРУ.

  • Инициализируйте веса и смещение операции ГРУ.

  • Инициализируйте веса и смещения двух полностью связанных операций.

Для параметров декодера модели инициализируйте каждый то, чтобы взвешивать и смещения с инициализатором Glorot и нулями, соответственно.

inputSizeDecoder = enc.NumWords + 1;
parametersDecoder = struct;

% Word embedding
parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder));

% Attention
parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension));
parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits));
parametersDecoder.attention.BiasV = dlarray(zeros(1,1,'single'));

% GRU
parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension));
parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

% Fully connect
parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],'single'));

% Fully connect
parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits));
parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],'single'));

Функции модели Define

Создайте функции modelEncoder и modelDecoder, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берет в качестве входа массив активаций dlX от выхода предварительно обученной сети и передач это посредством полностью связанной операции и операции ReLU. Поскольку предварительно обученная сеть не должна быть прослежена для автоматического дифференцирования, извлечение функций вне функции модели энкодера более в вычислительном отношении эффективно.

modelDecoder функция, перечисленная в разделе Decoder Model Function примера, занимает в качестве входа один входной такт, соответствуя входному слову, параметрам модели декодера, функциям от энкодера и сетевому состоянию, и возвращает предсказания для следующего временного шага, обновленного сетевого состояния и весов внимания.

Задайте опции обучения

Задайте опции для обучения. Обучайтесь в течение 30 эпох с мини-пакетным размером 128 и отобразите прогресс обучения в графике.

miniBatchSize = 128;
numEpochs = 30;
plots = "training-progress";

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

executionEnvironment = "auto";

Обучение сети

Обучите сеть с помощью пользовательского учебного цикла.

В начале каждой эпохи переставьте входные данные. Чтобы сохранить изображения в увеличенном datastore изображений и заголовках синхронизируемыми, создайте массив переставленных индексов, который индексирует в оба набора данных.

Для каждого мини-пакета:

  • Перемасштабируйте изображения к размеру, который ожидает предварительно обученная сеть.

  • Для каждого изображения выберите случайный заголовок.

  • Преобразуйте заголовки в последовательности словарей. Задайте дополнение права последовательностей с дополнительным значением, соответствующим индексу дополнительной лексемы.

  • Преобразуйте данные в dlarray объекты. Для изображений укажите, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Для обучения графического процессора преобразуйте данные в gpuArray объекты.

  • Извлеките функции изображений с помощью предварительно обученной сети и измените их к размеру, который ожидает энкодер.

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функции.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Отобразите прогресс обучения в графике.

Инициализируйте параметры для оптимизатора Адама.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];

trailingAvgDecoder = [];
trailingAvgSqDecoder = [];

Инициализируйте график процесса обучения. Создайте анимированную линию, которая строит потерю против соответствующей итерации.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    xlabel("Iteration")
    ylabel("Loss")
    ylim([0 inf])
    grid on
end

Обучите модель.

iteration = 0;
numObservationsTrain = numel(annotationsTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    idxShuffle = randperm(numObservationsTrain);
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Determine mini-batch indices.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        idxMiniBatch = idxShuffle(idx);
        
        % Read mini-batch of data.
        tbl = readByIndex(augimdsTrain,idxMiniBatch);
        X = cat(4,tbl.input{:});
        annotations = annotationsTrain(idxMiniBatch);
        
        % For each image, select random caption.
        idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs});
        documents = documentsAll(idx);
        
        % Create batch of data.
        [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment);
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlX, dlT);
        
        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration);
        
        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            drawnow
        end
    end
end

Предскажите новые заголовки

Процесс генерации заголовка отличается от процесса для обучения. Во время обучения, на каждом временном шаге, декодер использует истинное значение предыдущего временного шага, как введено. Это известно как "учителя, обеспечивающего". При создании предсказаний на новых данных декодер использует предыдущие ожидаемые значения вместо истинных значений.

Предсказание наиболее вероятного слова для каждого шага в последовательности может привести к субоптимальным результатам. Например, если декодер предсказывает, что первое слово заголовка является "a", когда дали изображение слона, то вероятность предсказания "слона" для следующего слова становится намного более маловероятной из-за чрезвычайно низкой вероятности фразы "слон", появляющийся в английском тексте.

Чтобы решить эту проблему, можно использовать алгоритм поиска луча: вместо того, чтобы брать наиболее вероятное предсказание для каждого шага в последовательности, возьмите верхнюю часть k предсказания (индекс луча) и для каждого следующего шага, сохраните верхнюю часть k предсказанными последовательностями до сих пор согласно общей оценке.

Сгенерируйте заголовок нового изображения путем извлечения функций изображений, введения их в энкодер, и затем использования beamSearch функция, перечисленная в разделе Beam Search Function примера.

img = imread("laika_sitting.jpg");
dlX = extractImageFeatures(dlnet,img,inputMin,inputMax,executionEnvironment);

beamIndex = 3;
maxNumWords = 20;
[words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
caption = join(words)
caption = 
"a dog is standing on a tile floor"

Отобразите изображение с заголовком.

figure
imshow(img)
title(caption)

Предскажите заголовки для набора данных

Предсказать заголовки для набора изображений, цикла по мини-пакетам данных в datastore и извлечь функции из изображений с помощью extractImageFeatures функция. Затем цикл по изображениям в мини-пакете и генерирует заголовки с помощью beamSearch функция.

Создайте увеличенный datastore изображений и установите выходной размер совпадать с входным размером сверточной сети. Чтобы вывести полутоновые изображения как RGB, с 3 каналами отображает, установите 'ColorPreprocessing' опция к 'gray2rgb'.

tblFilenamesTest = table(cat(1,annotationsTest.Filename));
augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,'ColorPreprocessing','gray2rgb')
augimdsTest = 
  augmentedImageDatastore with properties:

         NumObservations: 4139
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Сгенерируйте заголовки для тестовых данных. Предсказание заголовков на большом наборе данных может занять время. Если у вас есть Parallel Computing Toolbox™, то можно сделать предсказания параллельно путем генерации заголовков в parfor посмотреть. Если у вас нет Parallel Computing Toolbox. затем parfor цикл запускается в сериале.

beamIndex = 2;
maxNumWords = 20;

numObservationsTest = numel(annotationsTest);
numIterationsTest = ceil(numObservationsTest/miniBatchSize);

captionsTestPred = strings(1,numObservationsTest);
documentsTestPred = tokenizedDocument(strings(1,numObservationsTest));

for i = 1:numIterationsTest
    % Mini-batch indices.
    idxStart = (i-1)*miniBatchSize+1;
    idxEnd = min(i*miniBatchSize,numObservationsTest);
    idx = idxStart:idxEnd;
    
    sz = numel(idx);
    
    % Read images.
    tbl = readByIndex(augimdsTest,idx);
    
    % Extract image features.
    X = cat(4,tbl.input{:});
    dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);
    
    % Generate captions.
    captionsPredMiniBatch = strings(1,sz);
    documentsPredMiniBatch = tokenizedDocument(strings(1,sz));
    
    parfor j = 1:sz
        words = beamSearch(dlX(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
        captionsPredMiniBatch(j) = join(words);
        documentsPredMiniBatch(j) = tokenizedDocument(words,'TokenizeMethod','none');
    end
    
    captionsTestPred(idx) = captionsPredMiniBatch;
    documentsTestPred(idx) = documentsPredMiniBatch;
end
Analyzing and transferring files to the workers ...done.

Чтобы просмотреть тестовое изображение с соответствующим заголовком, используйте imshow функция и набор заголовок на предсказанный заголовок.

idx = 1;
tbl = readByIndex(augimdsTest,idx);
img = tbl.input{1};
figure
imshow(img)
title(captionsTestPred(idx))

Оцените точность модели

Чтобы оценить точность заголовков с помощью BLEU score, вычислите BLEU score для каждого заголовка (кандидат) против соответствующих заголовков в наборе тестов (ссылки) использование bleuEvaluationScore функция. Используя bleuEvaluationScore функция, можно сравнить один документ кандидата нескольким справочным документам.

bleuEvaluationScore функция, по умолчанию, подобие баллов с помощью N-грамм длины один - четыре. Когда заголовки коротки, это поведение может привести к неинформативным результатам, как большинство баллов близко к нулю. Установите длину n-граммы на один - два путем установки 'NgramWeights' опция к двухэлементному вектору с равными весами.

ngramWeights = [0.5 0.5];

for i = 1:numObservationsTest
    annotation = annotationsTest(i);
    
    captionIDs = annotation.CaptionIDs;
    candidate = documentsTestPred(i);
    references = documentsAll(captionIDs);
    
    score = bleuEvaluationScore(candidate,references,'NgramWeights',ngramWeights);
    
    scores(i) = score;
end

Просмотрите средний BLEU score.

scoreMean = mean(scores)
scoreMean = 0.4224

Визуализируйте баллы в гистограмме.

figure
histogram(scores)
xlabel("BLEU Score")
ylabel("Frequency")

Функция внимания

attention функция вычисляет вектор контекста и использование весов внимания внимание Bahdanau.

function [contextVector, attentionWeights] = attention(hidden,features,weights1, ...
    bias1,weights2,bias2,weightsV,biasV)

% Model dimensions.
[embeddingDimension,numFeatures,miniBatchSize] = size(features);
numHiddenUnits = size(weights1,1);

% Fully connect.
dlY1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize);
dlY1 = fullyconnect(dlY1,weights1,bias1,'DataFormat','CB');
dlY1 = reshape(dlY1,numHiddenUnits,numFeatures,miniBatchSize);

% Fully connect.
dlY2 = fullyconnect(hidden,weights2,bias2,'DataFormat','CB');
dlY2 = reshape(dlY2,numHiddenUnits,1,miniBatchSize);

% Addition, tanh.
scores = tanh(dlY1 + dlY2);
scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize);

% Fully connect, softmax.
attentionWeights = fullyconnect(scores,weightsV,biasV,'DataFormat','CB');
attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize);
attentionWeights = softmax(attentionWeights,'DataFormat','SCB');

% Context.
contextVector = attentionWeights .* features;
contextVector = squeeze(sum(contextVector,2));

end

Встраивание функции

embedding функционируйте сопоставляет массив индексов к последовательности встраивания векторов.

function Z = embedding(X, weights)

% Reshape inputs into a vector
[N, T] = size(X, 1:2);
X = reshape(X, N*T, 1);

% Index into embedding matrix
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions
Z = reshape(Z, [], N, T);

end

Функция извлечения признаков

extractImageFeatures функционируйте берет в качестве входа обученный dlnetwork объект, входное изображение, статистика для перемасштабирующего изображения, и среда выполнения, и возвращают dlarray содержание функций извлечено из предварительно обученной сети.

function dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment)

% Resize and rescale.
inputSize = dlnet.Layers(1).InputSize(1:2);
X = imresize(X,inputSize);
X = rescale(X,-1,1,'InputMin',inputMin,'InputMax',inputMax);

% Convert to dlarray.
dlX = dlarray(X,'SSCB');

% Convert to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

% Extract features and reshape.
dlX = predict(dlnet,dlX);
sz = size(dlX);
numFeatures = sz(1) * sz(2);
inputSizeEncoder = sz(3);
miniBatchSize = sz(4);
dlX = reshape(dlX,[numFeatures inputSizeEncoder miniBatchSize]);

end

Пакетная функция создания

createBatch функционируйте берет в качестве входа мини-пакет данных, маркируемых заголовков, предварительно обученной сети, статистики для перемасштабирующего изображения, кодирование слова и среда выполнения, и возвращает мини-пакет данных, соответствующих извлеченным функциям изображений и заголовкам для обучения.

function [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment)

dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);

% Convert documents to sequences of word indices.
T = doc2sequence(enc,documents,'PaddingDirection','right','PaddingValue',enc.NumWords+1);
T = cat(1,T{:});

% Convert mini-batch of data to dlarray.
dlT = dlarray(T);

% If training on a GPU, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlT = gpuArray(dlT);
end

end

Функция модели энкодера

modelEncoder функционируйте берет в качестве входа массив активаций dlX и передачи это посредством полностью связанной операции и операции ReLU. Для полностью связанной операции работайте с размерностью канала только. Чтобы применить полностью связанную операцию через размерность канала только, сгладьте другие каналы в одну размерность и задайте эту размерность как пакетную размерность с помощью 'DataFormat' опция fullyconnect функция.

function dlY = modelEncoder(dlX,parametersEncoder)

[numFeatures,inputSizeEncoder,miniBatchSize] = size(dlX);

% Fully connect
weights = parametersEncoder.fc.Weights;
bias = parametersEncoder.fc.Bias;
embeddingDimension = size(weights,1);

dlX = permute(dlX,[2 1 3]);
dlX = reshape(dlX,inputSizeEncoder,numFeatures*miniBatchSize);
dlY = fullyconnect(dlX,weights,bias,'DataFormat','CB');
dlY = reshape(dlY,embeddingDimension,numFeatures,miniBatchSize);

% ReLU
dlY = relu(dlY);

end

Функция модели декодера

modelDecoder функционируйте занимает в качестве входа один такт dlX, параметры модели декодера, функции от энкодера и сетевое состояние, и возвращают предсказания для следующего временного шага, обновленного сетевого состояния и весов внимания.

function [dlY,state,attentionWeights] = modelDecoder(dlX,parametersDecoder,features,state)

hiddenState = state.gru.HiddenState;

% Attention
weights1 = parametersDecoder.attention.Weights1;
bias1 = parametersDecoder.attention.Bias1;
weights2 = parametersDecoder.attention.Weights2;
bias2 = parametersDecoder.attention.Bias2;
weightsV = parametersDecoder.attention.WeightsV;
biasV = parametersDecoder.attention.BiasV;
[contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV);

% Embedding
weights = parametersDecoder.emb.Weights;
dlX = embedding(dlX,weights);

% Concatenate
dlY = cat(1,contextVector,dlX);

% GRU
inputWeights = parametersDecoder.gru.InputWeights;
recurrentWeights = parametersDecoder.gru.RecurrentWeights;
bias = parametersDecoder.gru.Bias;
[dlY, hiddenState] = gru(dlY, hiddenState, inputWeights, recurrentWeights, bias, 'DataFormat','CBT');

% Update state
state.gru.HiddenState = hiddenState;

% Fully connect
weights = parametersDecoder.fc1.Weights;
bias = parametersDecoder.fc1.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Fully connect
weights = parametersDecoder.fc2.Weights;
bias = parametersDecoder.fc2.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

end

Градиенты модели

modelGradients функционируйте берет в качестве входа параметры энкодера и декодера, энкодер показывает dlX, и целевой заголовок dlT, и возвращает градиенты параметров энкодера и декодера относительно потери, потери и предсказаний.

function [gradientsEncoder,gradientsDecoder,loss,dlYPred] = ...
    modelGradients(parametersEncoder,parametersDecoder,dlX,dlT)

miniBatchSize = size(dlX,3);
sequenceLength = size(dlT,2) - 1;
vocabSize = size(parametersDecoder.emb.Weights,2);

% Model encoder
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],'single'));

dlYPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],'like',dlX));
loss = dlarray(single(0));

padToken = vocabSize;

for t = 1:sequenceLength
    decoderInput = dlT(:,t);
    
    dlYReal = dlT(:,t+1);
    
    [dlYPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state);
    
    mask = dlYReal ~= padToken;
    
    loss = loss + sparseCrossEntropyAndSoftmax(dlYPred(:,:,t),dlYReal,mask);
end

% Calculate gradients
[gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder);

end

Разреженная перекрестная энтропия и функция потерь Softmax

sparseCrossEntropyAndSoftmax берет в качестве входа предсказания dlY, соответствующие цели dlT, и дополнительная маска последовательности, и применяет softmax функции и возвращают потерю перекрестной энтропии.

function loss = sparseCrossEntropyAndSoftmax(dlY, dlT, mask)

miniBatchSize = size(dlY, 2);

% Softmax.
dlY = softmax(dlY,'DataFormat','CB');

% Find rows corresponding to the target words.
idx = sub2ind(size(dlY), dlT', 1:miniBatchSize);
dlY = dlY(idx);

% Bound away from zero.
dlY = max(dlY, single(1e-8));

% Masked loss.
loss = log(dlY) .* mask';
loss = -sum(loss,'all') ./ miniBatchSize;

end

Излучите поисковую функцию

beamSearch функционируйте берет в качестве входа, изображение показывает dlX, индекс луча, параметры для сетей энкодера и декодера, кодирования слова и максимальной длины последовательности, и возвращают слова заголовка для изображения с помощью алгоритма поиска луча.

function [words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder, ...
    enc,maxNumWords)

% Model dimensions
numFeatures = size(dlX,1);
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);

% Extract features
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],'like',dlX));

% Initialize candidates
candidates = struct;
candidates.State = state;
candidates.Words = "<start>";
candidates.Score = 0;
candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],'like',dlX));
candidates.StopFlag = false;

t = 0;

% Loop over words
while t < maxNumWords
    t = t + 1;
    
    candidatesNew = [];
    
    % Loop over candidates
    for i = 1:numel(candidates)
        
        % Stop generating when stop token is predicted
        if candidates(i).StopFlag
            continue
        end
        
        % Candidate details
        state = candidates(i).State;
        words = candidates(i).Words;
        score = candidates(i).Score;
        attentionScores = candidates(i).AttentionScores;
        
        % Predict next token
        decoderInput = word2ind(enc,words(end));
        [dlYPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state);
        
        dlYPred = softmax(dlYPred,'DataFormat','CB');
        [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex);
        idxTop = gather(idxTop);
        
        % Loop over top predictions
        for j = 1:beamIndex
            candidate = struct;
            
            candidateWord = ind2word(enc,idxTop(j));
            candidateScore = scoresTop(j);
            
            if candidateWord == "<stop>"
                candidate.StopFlag = true;
                attentionScores(:,t+1:end) = [];
            else
                candidate.StopFlag = false;
            end
            
            candidate.State = state;
            candidate.Words = [words candidateWord];
            candidate.Score = score + log(candidateScore);
            candidate.AttentionScores = attentionScores;
            
            candidatesNew = [candidatesNew candidate];
        end
    end
    
    % Get top candidates
    [~,idx] = maxk([candidatesNew.Score],beamIndex);
    candidates = candidatesNew(idx);
    
    % Stop predicting when all candidates have stop token
    if all([candidates.StopFlag])
        break
    end
end

% Get top candidate
words = candidates(1).Words(2:end-1);
attentionScores = candidates(1).AttentionScores;

end

Функция инициализации веса Glorot

initializeGlorot функция генерирует массив весов согласно инициализации Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

Смотрите также

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | | | | (Text Analytics Toolbox) |

Похожие темы