exponenta event banner

Субтитры изображения с помощью внимания

В этом примере показано, как обучить модель глубокого обучения субтитрам изображений с использованием внимания.

Большинство предварительно подготовленных сетей глубокого обучения сконфигурированы для классификации с одной меткой. Например, учитывая изображение типичного офисного стола, сеть может предсказать один класс «клавиатура» или «мышь». Напротив, модель субтитров изображения объединяет сверточные и повторяющиеся операции для получения текстового описания того, что находится в изображении, а не одной метки.

Эта модель, обученная в этом примере, использует архитектуру кодера-декодера. Кодер представляет собой предварительно подготовленную сеть Inception-v3, используемую в качестве экстрактора признаков. Декодер представляет собой повторяющуюся нейронную сеть (RNN), которая принимает извлеченные признаки в качестве входных данных и генерирует заголовок. Декодер включает в себя механизм внимания, который позволяет декодеру фокусироваться на частях кодированного входа при генерации подписи.

Модель кодера является предварительно подготовленной моделью Inception-v3, которая извлекает элементы из "mixed10" уровень с последующим полным подключением и операциями ReLU.

Модель декодера состоит из встраивания слова, механизма внимания, стробируемого повторяющегося блока (ГРУ) и двух полностью соединенных операций.

Загрузить предварительно обученную сеть

Загрузите предварительно подготовленную сеть Inception-v3. Для выполнения этого шага необходим пакет поддержки Deep Learning Toolbox™ Model for Inception-v3 Network. Если необходимый пакет поддержки не установлен, программа предоставляет ссылку для загрузки.

net = inceptionv3;
inputSizeNet = net.Layers(1).InputSize;

Преобразование сети в dlnetwork для извлечения элемента и удаления последних четырех слоев, оставив "mixed10" слой как последний слой.

lgraph = layerGraph(net);
lgraph = removeLayers(lgraph,["avg_pool" "predictions" "predictions_softmax" "ClassificationLayer_predictions"]);

Просмотр входного уровня сети. Сеть Inception-v3 использует нормализацию симметричного масштаба с минимальным значением 0 и максимальным значением 255.

lgraph.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'input_1'
                 InputSize: [299 299 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'rescale-symmetric'
    NormalizationDimension: 'auto'
                       Max: 255
                       Min: 0

Пользовательское обучение не поддерживает эту нормализацию, поэтому необходимо отключить нормализацию в сети и выполнить нормализацию в пользовательском цикле обучения. Сохранить минимальное и максимальное значения как двойные в переменных с именем inputMin и inputMax, соответственно, и заменить входной слой на входной слой изображения без нормализации.

inputMin = double(lgraph.Layers(1).Min);
inputMax = double(lgraph.Layers(1).Max);
layer = imageInputLayer(inputSizeNet,"Normalization","none",'Name','input');
lgraph = replaceLayer(lgraph,'input_1',layer);

Определите размер выходного сигнала сети. Используйте analyzeNetwork для просмотра размеров активации последнего слоя. Анализатор сети Deep Learning Network Analyzer показывает некоторые проблемы, связанные с сетью, которые можно безопасно игнорировать для пользовательских рабочих процессов обучения.

analyzeNetwork(lgraph)

Создание переменной с именем outputSizeNet содержит размер сетевого вывода.

outputSizeNet = [8 8 2048];

Преобразование графика слоев в dlnetwork и просмотрите выходной слой. Выходной уровень - "mixed10" уровень сети Inception-v3.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [311×1 nnet.cnn.layer.Layer]
    Connections: [345×2 table]
     Learnables: [376×3 table]
          State: [188×3 table]
     InputNames: {'input'}
    OutputNames: {'mixed10'}

Импорт набора данных COCO

Загружайте изображения и аннотации из наборов данных «Изображения поезда 2014 года» и «Аннотации поезда/вала 2014 года» соответственно из https://cocodataset.org/#download. Извлечение изображений и аннотаций в папку с именем "coco". Набор данных COCO 2014 был собран Coco Consortium.

Извлеките подписи из файла "captions_train2014.json" с использованием jsondecode функция.

dataFolder = fullfile(tempdir,"coco");
filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json");
str = fileread(filename);
data = jsondecode(str)
data = struct with fields:
           info: [1×1 struct]
         images: [82783×1 struct]
       licenses: [8×1 struct]
    annotations: [414113×1 struct]

annotations поле структуры содержит данные, необходимые для субтитров изображения.

data.annotations
ans=414113×1 struct array with fields:
    image_id
    id
    caption

Набор данных содержит несколько заголовков для каждого изображения. Чтобы одни и те же изображения не появлялись как в обучающих наборах, так и в наборах проверки, определите уникальные изображения в наборе данных с помощью unique используя идентификаторы в image_id поле поля аннотаций данных, затем просмотрите количество уникальных изображений.

numObservationsAll = numel(data.annotations)
numObservationsAll = 414113
imageIDs = [data.annotations.image_id];
imageIDsUnique = unique(imageIDs);
numUniqueImages = numel(imageIDsUnique)
numUniqueImages = 82783

Каждое изображение имеет не менее пяти титров. Создание структуры annotationsAll с этими полями:

  • ImageID ⁠ - Идентификатор изображения

  • Filename ⁠ - Имя файла изображения

  • Captions ⁠ - Строковый массив необработанных субтитров

  • CaptionIDs ⁠ - Вектор индексов соответствующих заголовков вdata.annotations

Чтобы упростить объединение, сортируйте аннотации по идентификаторам изображений.

[~,idx] = sort([data.annotations.image_id]);
data.annotations = data.annotations(idx);

Закольцовывайте аннотации и при необходимости объединяйте несколько аннотаций.

i = 0;
j = 0;
imageIDPrev = 0;
while i < numel(data.annotations)
    i = i + 1;
    
    imageID = data.annotations(i).image_id;
    caption = string(data.annotations(i).caption);
    
    if imageID ~= imageIDPrev
        % Create new entry
        j = j + 1;
        annotationsAll(j).ImageID = imageID;
        annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,'left','0') + ".jpg");
        annotationsAll(j).Captions = caption;
        annotationsAll(j).CaptionIDs = i;
    else
        % Append captions
        annotationsAll(j).Captions = [annotationsAll(j).Captions; caption];
        annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i];
    end
    
    imageIDPrev = imageID;
end

Разбейте данные на обучающие и проверочные наборы. Выдержать 5% наблюдений для тестирования.

cvp = cvpartition(numel(annotationsAll),'HoldOut',0.05);
idxTrain = training(cvp);
idxTest = test(cvp);
annotationsTrain = annotationsAll(idxTrain);
annotationsTest = annotationsAll(idxTest);

Структура содержит три поля:

  • id - Уникальный идентификатор подписи

  • caption - Подпись к изображению, заданная как вектор символа

  • image_id - Уникальный идентификатор изображения, соответствующего подписи

Для просмотра изображения и соответствующей подписи найдите файл изображения с именем файла "train2014\COCO_train2014_XXXXXXXXXXXX.jpg", где "XXXXXXXXXXXX" соответствует идентификатору изображения, заполненному влево нулями и имеющему длину 12.

imageID = annotationsTrain(1).ImageID;
captions = annotationsTrain(1).Captions;
filename = annotationsTrain(1).Filename;

Для просмотра изображения используйте imread и imshow функции.

img = imread(filename);
figure
imshow(img)
title(captions)

Подготовка данных для обучения

Подготовьте подписи для обучения и тестирования. Извлечь текст из Captions поле структуры, содержащее как учебные, так и тестовые данные (annotationsAll), стереть пунктуацию и преобразовать текст в строчный.

captionsAll = cat(1,annotationsAll.Captions);
captionsAll = erasePunctuation(captionsAll);
captionsAll = lower(captionsAll);

Чтобы генерировать подписи, декодер RNN требует специальных маркеров запуска и остановки, чтобы указать, когда начинать и прекращать генерировать текст, соответственно. Добавить пользовательские маркеры "<start>" и "<stop>" до начала и конца титров соответственно.

captionsAll = "<start>" + captionsAll + "<stop>";

Маркировать подписи с помощью tokenizedDocument и укажите маркеры запуска и остановки с помощью 'CustomTokens' вариант.

documentsAll = tokenizedDocument(captionsAll,'CustomTokens',["<start>" "<stop>"]);

Создать wordEncoding объект, сопоставляющий слова числовым индексам и обратно. Уменьшите требования к памяти, указав размер словаря 5000, соответствующий наиболее часто встречающимся словам в данных обучения. Во избежание предвзятости используйте только документы, соответствующие обучающему набору.

enc = wordEncoding(documentsAll(idxTrain),'MaxNumWords',5000,'Order','frequency');

Создайте хранилище данных дополненного изображения, содержащее изображения, соответствующие заголовкам. Задайте размер выходного сигнала, соответствующий размеру входного сигнала сверточной сети. Чтобы синхронизировать изображения с подписями, укажите таблицу имен файлов для хранилища данных, восстановив имена файлов с помощью идентификатора изображения. Для возврата изображений в оттенках серого в виде 3-канальных изображений RGB установите 'ColorPreprocessing' опция для 'gray2rgb'.

tblFilenames = table(cat(1,annotationsTrain.Filename));
augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,'ColorPreprocessing','gray2rgb')
augimdsTrain = 
  augmentedImageDatastore with properties:

         NumObservations: 78644
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Инициализация параметров модели

Инициализируйте параметры модели. Укажите 512 скрытых единиц с размером вставки слова 256.

embeddingDimension = 256;
numHiddenUnits = 512;

Инициализируйте структуру, содержащую параметры для модели кодировщика.

  • Инициализируйте веса полностью подключенных операций с помощью инициализатора Glorot, указанного initializeGlorot функция, перечисленная в конце примера. Укажите размер выходного сигнала, соответствующий размеру внедрения декодера (256), и размер входного сигнала, соответствующий количеству выходных каналов предварительно обученной сети. 'mixed10' уровень сети Inception-v3 выводит данные по 2048 каналам.

numFeatures = outputSizeNet(1) * outputSizeNet(2);
inputSizeEncoder = outputSizeNet(3);
parametersEncoder = struct;

% Fully connect
parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder));
parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],'single'));

Инициализируйте структуру, содержащую параметры для модели декодера.

  • Инициализируйте весы встраивания слов с размером, заданным размерностью встраивания, и словарным размером плюс один, где дополнительная запись соответствует значению заполнения.

  • Инициализируйте веса и смещения для механизма внимания Бахданау с размерами, соответствующими количеству скрытых единиц операции ГРУ.

  • Инициализируйте вес и смещение операции GRU.

  • Инициализируйте веса и смещения двух полностью связанных операций.

Для параметров декодера модели инициализируйте каждый из весов и смещений с помощью инициализатора Глорота и нулей соответственно.

inputSizeDecoder = enc.NumWords + 1;
parametersDecoder = struct;

% Word embedding
parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder));

% Attention
parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension));
parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],'single'));
parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits));
parametersDecoder.attention.BiasV = dlarray(zeros(1,1,'single'));

% GRU
parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension));
parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

% Fully connect
parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits));
parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],'single'));

% Fully connect
parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits));
parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],'single'));

Определение функций модели

Создание функций modelEncoder и modelDecoder, перечисленных в конце примера, которые вычисляют выходные сигналы моделей кодера и декодера соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, принимает в качестве входных данных массив активаций dlX с выхода предварительно обученной сети и пропускает ее через полностью подключенную операцию и операцию ReLU. Поскольку предварительно обученную сеть не нужно отслеживать для автоматического дифференцирования, извлечение признаков вне функции модели кодера является более эффективным с точки зрения вычислений.

modelDecoder функция, перечисленная в разделе «Функция модели декодера» примера, принимает в качестве входных данных один входной временной шаг, соответствующий входному слову, параметрам модели декодера, признакам от кодера и состоянию сети, и возвращает прогнозы для следующего временного шага, обновленного состояния сети и весов внимания.

Укажите параметры обучения

Укажите параметры обучения. Потренироваться на 30 эпох с размером мини-партии 128 и отобразить ход обучения на графике.

miniBatchSize = 128;
numEpochs = 30;
plots = "training-progress";

Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

executionEnvironment = "auto";

Железнодорожная сеть

Обучение сети с использованием пользовательского цикла обучения.

В начале каждой эпохи перетасуйте входные данные. Чтобы синхронизировать изображения в хранилище данных дополненного изображения и подписи, создайте массив перетасованных индексов, который индексирует в оба набора данных.

Для каждой мини-партии:

  • Масштабируйте изображения до размера, ожидаемого предварительно подготовленной сетью.

  • Для каждого изображения выберите случайную подпись.

  • Преобразуйте подписи в последовательности индексов слов. Укажите правое заполнение последовательностей со значением заполнения, соответствующим индексу маркера заполнения.

  • Преобразовать данные в dlarray объекты. Для изображений укажите метки размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Для обучения GPU преобразуйте данные в gpuArray объекты.

  • Извлеките элементы изображения, используя предварительно подготовленную сеть, и измените их размер до ожидаемого кодировщиком размера.

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функции.

  • Обновление параметров модели кодера и декодера с помощью adamupdate функция.

  • Отображение хода обучения на графике.

Инициализируйте параметры оптимизатора Adam.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];

trailingAvgDecoder = [];
trailingAvgSqDecoder = [];

Инициализируйте график хода обучения. Создайте анимированную линию, отображающую потери в соответствии с соответствующей итерацией.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    xlabel("Iteration")
    ylabel("Loss")
    ylim([0 inf])
    grid on
end

Тренируйте модель.

iteration = 0;
numObservationsTrain = numel(annotationsTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    idxShuffle = randperm(numObservationsTrain);
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Determine mini-batch indices.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        idxMiniBatch = idxShuffle(idx);
        
        % Read mini-batch of data.
        tbl = readByIndex(augimdsTrain,idxMiniBatch);
        X = cat(4,tbl.input{:});
        annotations = annotationsTrain(idxMiniBatch);
        
        % For each image, select random caption.
        idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs});
        documents = documentsAll(idx);
        
        % Create batch of data.
        [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment);
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlX, dlT);
        
        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration);
        
        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            drawnow
        end
    end
end

Прогнозировать новые подписи

Процесс создания подписи отличается от процесса обучения. Во время обучения на каждом временном шаге декодер использует истинное значение предыдущего временного шага в качестве входного сигнала. Это известно как «принуждение учителя». При прогнозировании новых данных декодер использует предыдущие предсказанные значения вместо истинных значений.

Прогнозирование наиболее вероятного слова для каждого шага в последовательности может привести к неоптимальным результатам. Например, если декодер предсказывает, что первое слово подписи является «a», когда дано изображение слона, то вероятность предсказания «elephant» для следующего слова становится гораздо более маловероятной из-за крайне низкой вероятности появления фразы «a elephant» в английском тексте.

Чтобы решить эту проблему, можно использовать алгоритм поиска луча: вместо того, чтобы делать наиболее вероятное предсказание для каждого шага в последовательности, делайте верхние k предсказаний (индекс луча) и для каждого последующего шага держите верхние k предсказанных последовательностей до сих пор в соответствии с общей оценкой.

Создайте подпись нового изображения, извлекая элементы изображения, вводя их в кодировщик, а затем используя beamSearch функция, перечисленная в разделе «Функция поиска балки» примера.

img = imread("laika_sitting.jpg");
dlX = extractImageFeatures(dlnet,img,inputMin,inputMax,executionEnvironment);

beamIndex = 3;
maxNumWords = 20;
[words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
caption = join(words)
caption = 
"a dog is standing on a tile floor"

Отображение изображения с подписью.

figure
imshow(img)
title(caption)

Прогнозировать подписи для набора данных

Чтобы предсказать подписи для коллекции изображений, выполните цикл над мини-пакетами данных в хранилище данных и извлеките элементы из изображений с помощью extractImageFeatures функция. Затем закольцовывайте изображения в мини-пакете и создавайте подписи с помощью beamSearch функция.

Создайте хранилище данных дополненного изображения и задайте размер вывода в соответствии с размером ввода сверточной сети. Для вывода изображений в оттенках серого в виде 3-канальных изображений RGB установите 'ColorPreprocessing' опция для 'gray2rgb'.

tblFilenamesTest = table(cat(1,annotationsTest.Filename));
augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,'ColorPreprocessing','gray2rgb')
augimdsTest = 
  augmentedImageDatastore with properties:

         NumObservations: 4139
           MiniBatchSize: 1
        DataAugmentation: 'none'
      ColorPreprocessing: 'gray2rgb'
              OutputSize: [299 299]
          OutputSizeMode: 'resize'
    DispatchInBackground: 0

Создание титров для тестовых данных. Прогнозирование титров на большом наборе данных может занять некоторое время. Если у вас есть Toolbox™ Parallel Computing, то вы можете делать прогнозы параллельно, создавая подписи внутри parfor смотри. Если у вас нет панели инструментов параллельных вычислений. затем parfor цикл выполняется в последовательном формате.

beamIndex = 2;
maxNumWords = 20;

numObservationsTest = numel(annotationsTest);
numIterationsTest = ceil(numObservationsTest/miniBatchSize);

captionsTestPred = strings(1,numObservationsTest);
documentsTestPred = tokenizedDocument(strings(1,numObservationsTest));

for i = 1:numIterationsTest
    % Mini-batch indices.
    idxStart = (i-1)*miniBatchSize+1;
    idxEnd = min(i*miniBatchSize,numObservationsTest);
    idx = idxStart:idxEnd;
    
    sz = numel(idx);
    
    % Read images.
    tbl = readByIndex(augimdsTest,idx);
    
    % Extract image features.
    X = cat(4,tbl.input{:});
    dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);
    
    % Generate captions.
    captionsPredMiniBatch = strings(1,sz);
    documentsPredMiniBatch = tokenizedDocument(strings(1,sz));
    
    parfor j = 1:sz
        words = beamSearch(dlX(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
        captionsPredMiniBatch(j) = join(words);
        documentsPredMiniBatch(j) = tokenizedDocument(words,'TokenizeMethod','none');
    end
    
    captionsTestPred(idx) = captionsPredMiniBatch;
    documentsTestPred(idx) = documentsPredMiniBatch;
end
Analyzing and transferring files to the workers ...done.

Для просмотра тестового изображения с соответствующей подписью используйте imshow и установите заголовок на прогнозируемый заголовок.

idx = 1;
tbl = readByIndex(augimdsTest,idx);
img = tbl.input{1};
figure
imshow(img)
title(captionsTestPred(idx))

Оценка точности модели

Чтобы оценить точность титров с помощью показателя BLEU, вычислите балл BLEU для каждого титра (кандидата) по сравнению с соответствующими титрами в тестовом наборе (ссылки) с помощью bleuEvaluationScore функция. Использование bleuEvaluationScore можно сравнить один документ-кандидат с несколькими ссылочными документами.

bleuEvaluationScore функция по умолчанию оценивает сходство с использованием n-граммов длины от одного до четырех. Поскольку подписи короткие, такое поведение может привести к неинформативным результатам, поскольку большинство баллов близки к нулю. Задайте длину n-грамма от одного до двух, установив значение 'NgramWeights' опция для двухэлементного вектора с равными весами.

ngramWeights = [0.5 0.5];

for i = 1:numObservationsTest
    annotation = annotationsTest(i);
    
    captionIDs = annotation.CaptionIDs;
    candidate = documentsTestPred(i);
    references = documentsAll(captionIDs);
    
    score = bleuEvaluationScore(candidate,references,'NgramWeights',ngramWeights);
    
    scores(i) = score;
end

Просмотрите средний балл BLEU.

scoreMean = mean(scores)
scoreMean = 0.4224

Визуализируйте оценки в гистограмме.

figure
histogram(scores)
xlabel("BLEU Score")
ylabel("Frequency")

Функция внимания

attention функция вычисляет вектор контекста и веса внимания, используя внимание Бахданау.

function [contextVector, attentionWeights] = attention(hidden,features,weights1, ...
    bias1,weights2,bias2,weightsV,biasV)

% Model dimensions.
[embeddingDimension,numFeatures,miniBatchSize] = size(features);
numHiddenUnits = size(weights1,1);

% Fully connect.
dlY1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize);
dlY1 = fullyconnect(dlY1,weights1,bias1,'DataFormat','CB');
dlY1 = reshape(dlY1,numHiddenUnits,numFeatures,miniBatchSize);

% Fully connect.
dlY2 = fullyconnect(hidden,weights2,bias2,'DataFormat','CB');
dlY2 = reshape(dlY2,numHiddenUnits,1,miniBatchSize);

% Addition, tanh.
scores = tanh(dlY1 + dlY2);
scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize);

% Fully connect, softmax.
attentionWeights = fullyconnect(scores,weightsV,biasV,'DataFormat','CB');
attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize);
attentionWeights = softmax(attentionWeights,'DataFormat','SCB');

% Context.
contextVector = attentionWeights .* features;
contextVector = squeeze(sum(contextVector,2));

end

Функция встраивания

embedding функция отображает массив индексов в последовательность вложенных векторов.

function Z = embedding(X, weights)

% Reshape inputs into a vector
[N, T] = size(X, 1:2);
X = reshape(X, N*T, 1);

% Index into embedding matrix
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions
Z = reshape(Z, [], N, T);

end

Функция извлечения элементов

extractImageFeatures функция принимает в качестве входных данных обученный dlnetwork объект, входное изображение, статистика для масштабирования изображения и среда выполнения, и возвращает dlarray содержит элементы, извлеченные из предварительно обученной сети.

function dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment)

% Resize and rescale.
inputSize = dlnet.Layers(1).InputSize(1:2);
X = imresize(X,inputSize);
X = rescale(X,-1,1,'InputMin',inputMin,'InputMax',inputMax);

% Convert to dlarray.
dlX = dlarray(X,'SSCB');

% Convert to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

% Extract features and reshape.
dlX = predict(dlnet,dlX);
sz = size(dlX);
numFeatures = sz(1) * sz(2);
inputSizeEncoder = sz(3);
miniBatchSize = sz(4);
dlX = reshape(dlX,[numFeatures inputSizeEncoder miniBatchSize]);

end

Функция создания пакета

createBatch функция принимает в качестве входных данных мини-пакет данных, маркированные подписи, предварительно обученную сеть, статистику для масштабирования изображения, кодирование слова и среду выполнения, и возвращает мини-пакет данных, соответствующий извлеченным функциям изображения и подписям для обучения.

function [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment)

dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment);

% Convert documents to sequences of word indices.
T = doc2sequence(enc,documents,'PaddingDirection','right','PaddingValue',enc.NumWords+1);
T = cat(1,T{:});

% Convert mini-batch of data to dlarray.
dlT = dlarray(T);

% If training on a GPU, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlT = gpuArray(dlT);
end

end

Функция модели кодировщика

modelEncoder функция принимает в качестве входных данных массив активаций dlX и пропускает его через полностью подключенную операцию и операцию ReLU. Для полностью подключенной операции используйте только размер канала. Чтобы применить полностью подключенную операцию только к размеру канала, выровняйте другие каналы в один размер и укажите этот размер как пакетное измерение с помощью 'DataFormat' вариант fullyconnect функция.

function dlY = modelEncoder(dlX,parametersEncoder)

[numFeatures,inputSizeEncoder,miniBatchSize] = size(dlX);

% Fully connect
weights = parametersEncoder.fc.Weights;
bias = parametersEncoder.fc.Bias;
embeddingDimension = size(weights,1);

dlX = permute(dlX,[2 1 3]);
dlX = reshape(dlX,inputSizeEncoder,numFeatures*miniBatchSize);
dlY = fullyconnect(dlX,weights,bias,'DataFormat','CB');
dlY = reshape(dlY,embeddingDimension,numFeatures,miniBatchSize);

% ReLU
dlY = relu(dlY);

end

Функция модели декодера

modelDecoder функция принимает в качестве ввода один шаг времени dlX, параметры модели декодера, признаки от кодера и состояние сети и возвращает прогнозы для следующего временного шага, обновленного состояния сети и весов внимания.

function [dlY,state,attentionWeights] = modelDecoder(dlX,parametersDecoder,features,state)

hiddenState = state.gru.HiddenState;

% Attention
weights1 = parametersDecoder.attention.Weights1;
bias1 = parametersDecoder.attention.Bias1;
weights2 = parametersDecoder.attention.Weights2;
bias2 = parametersDecoder.attention.Bias2;
weightsV = parametersDecoder.attention.WeightsV;
biasV = parametersDecoder.attention.BiasV;
[contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV);

% Embedding
weights = parametersDecoder.emb.Weights;
dlX = embedding(dlX,weights);

% Concatenate
dlY = cat(1,contextVector,dlX);

% GRU
inputWeights = parametersDecoder.gru.InputWeights;
recurrentWeights = parametersDecoder.gru.RecurrentWeights;
bias = parametersDecoder.gru.Bias;
[dlY, hiddenState] = gru(dlY, hiddenState, inputWeights, recurrentWeights, bias, 'DataFormat','CBT');

% Update state
state.gru.HiddenState = hiddenState;

% Fully connect
weights = parametersDecoder.fc1.Weights;
bias = parametersDecoder.fc1.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Fully connect
weights = parametersDecoder.fc2.Weights;
bias = parametersDecoder.fc2.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

end

Градиенты модели

modelGradients функция принимает в качестве входных параметры кодера и декодера, кодер имеет dlXи целевой заголовок dlTи возвращает градиенты параметров кодера и декодера относительно потерь, потерь и прогнозов.

function [gradientsEncoder,gradientsDecoder,loss,dlYPred] = ...
    modelGradients(parametersEncoder,parametersDecoder,dlX,dlT)

miniBatchSize = size(dlX,3);
sequenceLength = size(dlT,2) - 1;
vocabSize = size(parametersDecoder.emb.Weights,2);

% Model encoder
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],'single'));

dlYPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],'like',dlX));
loss = dlarray(single(0));

padToken = vocabSize;

for t = 1:sequenceLength
    decoderInput = dlT(:,t);
    
    dlYReal = dlT(:,t+1);
    
    [dlYPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state);
    
    mask = dlYReal ~= padToken;
    
    loss = loss + sparseCrossEntropyAndSoftmax(dlYPred(:,:,t),dlYReal,mask);
end

% Calculate gradients
[gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder);

end

Функция разреженной перекрестной энтропии и потерь Softmax

sparseCrossEntropyAndSoftmax принимает за входные данные прогнозы dlY, соответствующие цели dlTи маску заполнения последовательности, и применяет softmax и возвращает потери перекрестной энтропии.

function loss = sparseCrossEntropyAndSoftmax(dlY, dlT, mask)

miniBatchSize = size(dlY, 2);

% Softmax.
dlY = softmax(dlY,'DataFormat','CB');

% Find rows corresponding to the target words.
idx = sub2ind(size(dlY), dlT', 1:miniBatchSize);
dlY = dlY(idx);

% Bound away from zero.
dlY = max(dlY, single(1e-8));

% Masked loss.
loss = log(dlY) .* mask';
loss = -sum(loss,'all') ./ miniBatchSize;

end

Функция поиска балки

beamSearch функция принимает в качестве входных признаков изображения dlX, индекс луча, параметры для сетей кодера и декодера, кодирование слова и максимальная длина последовательности, и возвращает титульные слова для изображения с использованием алгоритма поиска луча.

function [words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder, ...
    enc,maxNumWords)

% Model dimensions
numFeatures = size(dlX,1);
numHiddenUnits = size(parametersDecoder.attention.Weights1,1);

% Extract features
features = modelEncoder(dlX,parametersEncoder);

% Initialize state
state = struct;
state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],'like',dlX));

% Initialize candidates
candidates = struct;
candidates.State = state;
candidates.Words = "<start>";
candidates.Score = 0;
candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],'like',dlX));
candidates.StopFlag = false;

t = 0;

% Loop over words
while t < maxNumWords
    t = t + 1;
    
    candidatesNew = [];
    
    % Loop over candidates
    for i = 1:numel(candidates)
        
        % Stop generating when stop token is predicted
        if candidates(i).StopFlag
            continue
        end
        
        % Candidate details
        state = candidates(i).State;
        words = candidates(i).Words;
        score = candidates(i).Score;
        attentionScores = candidates(i).AttentionScores;
        
        % Predict next token
        decoderInput = word2ind(enc,words(end));
        [dlYPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state);
        
        dlYPred = softmax(dlYPred,'DataFormat','CB');
        [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex);
        idxTop = gather(idxTop);
        
        % Loop over top predictions
        for j = 1:beamIndex
            candidate = struct;
            
            candidateWord = ind2word(enc,idxTop(j));
            candidateScore = scoresTop(j);
            
            if candidateWord == "<stop>"
                candidate.StopFlag = true;
                attentionScores(:,t+1:end) = [];
            else
                candidate.StopFlag = false;
            end
            
            candidate.State = state;
            candidate.Words = [words candidateWord];
            candidate.Score = score + log(candidateScore);
            candidate.AttentionScores = attentionScores;
            
            candidatesNew = [candidatesNew candidate];
        end
    end
    
    % Get top candidates
    [~,idx] = maxk([candidatesNew.Score],beamIndex);
    candidates = candidatesNew(idx);
    
    % Stop predicting when all candidates have stop token
    if all([candidates.StopFlag])
        break
    end
end

% Get top candidate
words = candidates(1).Words(2:end-1);
attentionScores = candidates(1).AttentionScores;

end

Функция инициализации веса Glorot

initializeGlorot генерирует массив весов в соответствии с инициализацией Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

См. также

| | | | | | | | | (инструментарий для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста)

Связанные темы