В этом примере показано, как обучить модель глубокого обучения субтитрам изображений с использованием внимания.
Большинство предварительно подготовленных сетей глубокого обучения сконфигурированы для классификации с одной меткой. Например, учитывая изображение типичного офисного стола, сеть может предсказать один класс «клавиатура» или «мышь». Напротив, модель субтитров изображения объединяет сверточные и повторяющиеся операции для получения текстового описания того, что находится в изображении, а не одной метки.
Эта модель, обученная в этом примере, использует архитектуру кодера-декодера. Кодер представляет собой предварительно подготовленную сеть Inception-v3, используемую в качестве экстрактора признаков. Декодер представляет собой повторяющуюся нейронную сеть (RNN), которая принимает извлеченные признаки в качестве входных данных и генерирует заголовок. Декодер включает в себя механизм внимания, который позволяет декодеру фокусироваться на частях кодированного входа при генерации подписи.

Модель кодера является предварительно подготовленной моделью Inception-v3, которая извлекает элементы из "mixed10" уровень с последующим полным подключением и операциями ReLU.

Модель декодера состоит из встраивания слова, механизма внимания, стробируемого повторяющегося блока (ГРУ) и двух полностью соединенных операций.

Загрузите предварительно подготовленную сеть Inception-v3. Для выполнения этого шага необходим пакет поддержки Deep Learning Toolbox™ Model for Inception-v3 Network. Если необходимый пакет поддержки не установлен, программа предоставляет ссылку для загрузки.
net = inceptionv3; inputSizeNet = net.Layers(1).InputSize;
Преобразование сети в dlnetwork для извлечения элемента и удаления последних четырех слоев, оставив "mixed10" слой как последний слой.
lgraph = layerGraph(net); lgraph = removeLayers(lgraph,["avg_pool" "predictions" "predictions_softmax" "ClassificationLayer_predictions"]);
Просмотр входного уровня сети. Сеть Inception-v3 использует нормализацию симметричного масштаба с минимальным значением 0 и максимальным значением 255.
lgraph.Layers(1)
ans =
ImageInputLayer with properties:
Name: 'input_1'
InputSize: [299 299 3]
Hyperparameters
DataAugmentation: 'none'
Normalization: 'rescale-symmetric'
NormalizationDimension: 'auto'
Max: 255
Min: 0
Пользовательское обучение не поддерживает эту нормализацию, поэтому необходимо отключить нормализацию в сети и выполнить нормализацию в пользовательском цикле обучения. Сохранить минимальное и максимальное значения как двойные в переменных с именем inputMin и inputMax, соответственно, и заменить входной слой на входной слой изображения без нормализации.
inputMin = double(lgraph.Layers(1).Min); inputMax = double(lgraph.Layers(1).Max); layer = imageInputLayer(inputSizeNet,"Normalization","none",'Name','input'); lgraph = replaceLayer(lgraph,'input_1',layer);
Определите размер выходного сигнала сети. Используйте analyzeNetwork для просмотра размеров активации последнего слоя. Анализатор сети Deep Learning Network Analyzer показывает некоторые проблемы, связанные с сетью, которые можно безопасно игнорировать для пользовательских рабочих процессов обучения.
analyzeNetwork(lgraph)

Создание переменной с именем outputSizeNet содержит размер сетевого вывода.
outputSizeNet = [8 8 2048];
Преобразование графика слоев в dlnetwork и просмотрите выходной слой. Выходной уровень - "mixed10" уровень сети Inception-v3.
dlnet = dlnetwork(lgraph)
dlnet =
dlnetwork with properties:
Layers: [311×1 nnet.cnn.layer.Layer]
Connections: [345×2 table]
Learnables: [376×3 table]
State: [188×3 table]
InputNames: {'input'}
OutputNames: {'mixed10'}
Загружайте изображения и аннотации из наборов данных «Изображения поезда 2014 года» и «Аннотации поезда/вала 2014 года» соответственно из https://cocodataset.org/#download. Извлечение изображений и аннотаций в папку с именем "coco". Набор данных COCO 2014 был собран Coco Consortium.
Извлеките подписи из файла "captions_train2014.json" с использованием jsondecode функция.
dataFolder = fullfile(tempdir,"coco"); filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json"); str = fileread(filename); data = jsondecode(str)
data = struct with fields:
info: [1×1 struct]
images: [82783×1 struct]
licenses: [8×1 struct]
annotations: [414113×1 struct]
annotations поле структуры содержит данные, необходимые для субтитров изображения.
data.annotations
ans=414113×1 struct array with fields:
image_id
id
caption
Набор данных содержит несколько заголовков для каждого изображения. Чтобы одни и те же изображения не появлялись как в обучающих наборах, так и в наборах проверки, определите уникальные изображения в наборе данных с помощью unique используя идентификаторы в image_id поле поля аннотаций данных, затем просмотрите количество уникальных изображений.
numObservationsAll = numel(data.annotations)
numObservationsAll = 414113
imageIDs = [data.annotations.image_id]; imageIDsUnique = unique(imageIDs); numUniqueImages = numel(imageIDsUnique)
numUniqueImages = 82783
Каждое изображение имеет не менее пяти титров. Создание структуры annotationsAll с этими полями:
ImageID - Идентификатор изображения
Filename - Имя файла изображения
Captions - Строковый массив необработанных субтитров
CaptionIDs - Вектор индексов соответствующих заголовков вdata.annotations
Чтобы упростить объединение, сортируйте аннотации по идентификаторам изображений.
[~,idx] = sort([data.annotations.image_id]); data.annotations = data.annotations(idx);
Закольцовывайте аннотации и при необходимости объединяйте несколько аннотаций.
i = 0; j = 0; imageIDPrev = 0; while i < numel(data.annotations) i = i + 1; imageID = data.annotations(i).image_id; caption = string(data.annotations(i).caption); if imageID ~= imageIDPrev % Create new entry j = j + 1; annotationsAll(j).ImageID = imageID; annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,'left','0') + ".jpg"); annotationsAll(j).Captions = caption; annotationsAll(j).CaptionIDs = i; else % Append captions annotationsAll(j).Captions = [annotationsAll(j).Captions; caption]; annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i]; end imageIDPrev = imageID; end
Разбейте данные на обучающие и проверочные наборы. Выдержать 5% наблюдений для тестирования.
cvp = cvpartition(numel(annotationsAll),'HoldOut',0.05);
idxTrain = training(cvp);
idxTest = test(cvp);
annotationsTrain = annotationsAll(idxTrain);
annotationsTest = annotationsAll(idxTest);Структура содержит три поля:
id - Уникальный идентификатор подписи
caption - Подпись к изображению, заданная как вектор символа
image_id - Уникальный идентификатор изображения, соответствующего подписи
Для просмотра изображения и соответствующей подписи найдите файл изображения с именем файла "train2014\COCO_train2014_XXXXXXXXXXXX.jpg", где "XXXXXXXXXXXX" соответствует идентификатору изображения, заполненному влево нулями и имеющему длину 12.
imageID = annotationsTrain(1).ImageID; captions = annotationsTrain(1).Captions; filename = annotationsTrain(1).Filename;
Для просмотра изображения используйте imread и imshow функции.
img = imread(filename); figure imshow(img) title(captions)
Подготовьте подписи для обучения и тестирования. Извлечь текст из Captions поле структуры, содержащее как учебные, так и тестовые данные (annotationsAll), стереть пунктуацию и преобразовать текст в строчный.
captionsAll = cat(1,annotationsAll.Captions); captionsAll = erasePunctuation(captionsAll); captionsAll = lower(captionsAll);
Чтобы генерировать подписи, декодер RNN требует специальных маркеров запуска и остановки, чтобы указать, когда начинать и прекращать генерировать текст, соответственно. Добавить пользовательские маркеры "<start>" и "<stop>" до начала и конца титров соответственно.
captionsAll = "<start>" + captionsAll + "<stop>";
Маркировать подписи с помощью tokenizedDocument и укажите маркеры запуска и остановки с помощью 'CustomTokens' вариант.
documentsAll = tokenizedDocument(captionsAll,'CustomTokens',["<start>" "<stop>"]);
Создать wordEncoding объект, сопоставляющий слова числовым индексам и обратно. Уменьшите требования к памяти, указав размер словаря 5000, соответствующий наиболее часто встречающимся словам в данных обучения. Во избежание предвзятости используйте только документы, соответствующие обучающему набору.
enc = wordEncoding(documentsAll(idxTrain),'MaxNumWords',5000,'Order','frequency');
Создайте хранилище данных дополненного изображения, содержащее изображения, соответствующие заголовкам. Задайте размер выходного сигнала, соответствующий размеру входного сигнала сверточной сети. Чтобы синхронизировать изображения с подписями, укажите таблицу имен файлов для хранилища данных, восстановив имена файлов с помощью идентификатора изображения. Для возврата изображений в оттенках серого в виде 3-канальных изображений RGB установите 'ColorPreprocessing' опция для 'gray2rgb'.
tblFilenames = table(cat(1,annotationsTrain.Filename)); augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,'ColorPreprocessing','gray2rgb')
augimdsTrain =
augmentedImageDatastore with properties:
NumObservations: 78644
MiniBatchSize: 1
DataAugmentation: 'none'
ColorPreprocessing: 'gray2rgb'
OutputSize: [299 299]
OutputSizeMode: 'resize'
DispatchInBackground: 0
Инициализируйте параметры модели. Укажите 512 скрытых единиц с размером вставки слова 256.
embeddingDimension = 256; numHiddenUnits = 512;
Инициализируйте структуру, содержащую параметры для модели кодировщика.
Инициализируйте веса полностью подключенных операций с помощью инициализатора Glorot, указанного initializeGlorot функция, перечисленная в конце примера. Укажите размер выходного сигнала, соответствующий размеру внедрения декодера (256), и размер входного сигнала, соответствующий количеству выходных каналов предварительно обученной сети. 'mixed10' уровень сети Inception-v3 выводит данные по 2048 каналам.
numFeatures = outputSizeNet(1) * outputSizeNet(2); inputSizeEncoder = outputSizeNet(3); parametersEncoder = struct; % Fully connect parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder)); parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],'single'));
Инициализируйте структуру, содержащую параметры для модели декодера.
Инициализируйте весы встраивания слов с размером, заданным размерностью встраивания, и словарным размером плюс один, где дополнительная запись соответствует значению заполнения.
Инициализируйте веса и смещения для механизма внимания Бахданау с размерами, соответствующими количеству скрытых единиц операции ГРУ.
Инициализируйте вес и смещение операции GRU.
Инициализируйте веса и смещения двух полностью связанных операций.
Для параметров декодера модели инициализируйте каждый из весов и смещений с помощью инициализатора Глорота и нулей соответственно.
inputSizeDecoder = enc.NumWords + 1; parametersDecoder = struct; % Word embedding parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder)); % Attention parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension)); parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],'single')); parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits)); parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],'single')); parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits)); parametersDecoder.attention.BiasV = dlarray(zeros(1,1,'single')); % GRU parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension)); parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits)); parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single')); % Fully connect parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits)); parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],'single')); % Fully connect parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits)); parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],'single'));
Создание функций modelEncoder и modelDecoder, перечисленных в конце примера, которые вычисляют выходные сигналы моделей кодера и декодера соответственно.
modelEncoder функция, перечисленная в разделе Encoder Model Function примера, принимает в качестве входных данных массив активаций dlX с выхода предварительно обученной сети и пропускает ее через полностью подключенную операцию и операцию ReLU. Поскольку предварительно обученную сеть не нужно отслеживать для автоматического дифференцирования, извлечение признаков вне функции модели кодера является более эффективным с точки зрения вычислений.
modelDecoder функция, перечисленная в разделе «Функция модели декодера» примера, принимает в качестве входных данных один входной временной шаг, соответствующий входному слову, параметрам модели декодера, признакам от кодера и состоянию сети, и возвращает прогнозы для следующего временного шага, обновленного состояния сети и весов внимания.
Укажите параметры обучения. Потренироваться на 30 эпох с размером мини-партии 128 и отобразить ход обучения на графике.
miniBatchSize = 128;
numEpochs = 30;
plots = "training-progress";Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
executionEnvironment = "auto";Обучение сети с использованием пользовательского цикла обучения.
В начале каждой эпохи перетасуйте входные данные. Чтобы синхронизировать изображения в хранилище данных дополненного изображения и подписи, создайте массив перетасованных индексов, который индексирует в оба набора данных.
Для каждой мини-партии:
Масштабируйте изображения до размера, ожидаемого предварительно подготовленной сетью.
Для каждого изображения выберите случайную подпись.
Преобразуйте подписи в последовательности индексов слов. Укажите правое заполнение последовательностей со значением заполнения, соответствующим индексу маркера заполнения.
Преобразовать данные в dlarray объекты. Для изображений укажите метки размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).
Для обучения GPU преобразуйте данные в gpuArray объекты.
Извлеките элементы изображения, используя предварительно подготовленную сеть, и измените их размер до ожидаемого кодировщиком размера.
Оцените градиенты модели и потери с помощью dlfeval и modelGradients функции.
Обновление параметров модели кодера и декодера с помощью adamupdate функция.
Отображение хода обучения на графике.
Инициализируйте параметры оптимизатора Adam.
trailingAvgEncoder = []; trailingAvgSqEncoder = []; trailingAvgDecoder = []; trailingAvgSqDecoder = [];
Инициализируйте график хода обучения. Создайте анимированную линию, отображающую потери в соответствии с соответствующей итерацией.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); xlabel("Iteration") ylabel("Loss") ylim([0 inf]) grid on end
Тренируйте модель.
iteration = 0; numObservationsTrain = numel(annotationsTrain); numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize); start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. idxShuffle = randperm(numObservationsTrain); % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Determine mini-batch indices. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; idxMiniBatch = idxShuffle(idx); % Read mini-batch of data. tbl = readByIndex(augimdsTrain,idxMiniBatch); X = cat(4,tbl.input{:}); annotations = annotationsTrain(idxMiniBatch); % For each image, select random caption. idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs}); documents = documentsAll(idx); % Create batch of data. [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment); % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ... parametersDecoder, dlX, dlT); % Update encoder using adamupdate. [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ... gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration); % Update decoder using adamupdate. [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ... gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end

Процесс создания подписи отличается от процесса обучения. Во время обучения на каждом временном шаге декодер использует истинное значение предыдущего временного шага в качестве входного сигнала. Это известно как «принуждение учителя». При прогнозировании новых данных декодер использует предыдущие предсказанные значения вместо истинных значений.
Прогнозирование наиболее вероятного слова для каждого шага в последовательности может привести к неоптимальным результатам. Например, если декодер предсказывает, что первое слово подписи является «a», когда дано изображение слона, то вероятность предсказания «elephant» для следующего слова становится гораздо более маловероятной из-за крайне низкой вероятности появления фразы «a elephant» в английском тексте.
Чтобы решить эту проблему, можно использовать алгоритм поиска луча: вместо того, чтобы делать наиболее вероятное предсказание для каждого шага в последовательности, делайте верхние k предсказаний (индекс луча) и для каждого последующего шага держите верхние k предсказанных последовательностей до сих пор в соответствии с общей оценкой.
Создайте подпись нового изображения, извлекая элементы изображения, вводя их в кодировщик, а затем используя beamSearch функция, перечисленная в разделе «Функция поиска балки» примера.
img = imread("laika_sitting.jpg");
dlX = extractImageFeatures(dlnet,img,inputMin,inputMax,executionEnvironment);
beamIndex = 3;
maxNumWords = 20;
[words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
caption = join(words)caption = "a dog is standing on a tile floor"
Отображение изображения с подписью.
figure imshow(img) title(caption)

Чтобы предсказать подписи для коллекции изображений, выполните цикл над мини-пакетами данных в хранилище данных и извлеките элементы из изображений с помощью extractImageFeatures функция. Затем закольцовывайте изображения в мини-пакете и создавайте подписи с помощью beamSearch функция.
Создайте хранилище данных дополненного изображения и задайте размер вывода в соответствии с размером ввода сверточной сети. Для вывода изображений в оттенках серого в виде 3-канальных изображений RGB установите 'ColorPreprocessing' опция для 'gray2rgb'.
tblFilenamesTest = table(cat(1,annotationsTest.Filename)); augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,'ColorPreprocessing','gray2rgb')
augimdsTest =
augmentedImageDatastore with properties:
NumObservations: 4139
MiniBatchSize: 1
DataAugmentation: 'none'
ColorPreprocessing: 'gray2rgb'
OutputSize: [299 299]
OutputSizeMode: 'resize'
DispatchInBackground: 0
Создание титров для тестовых данных. Прогнозирование титров на большом наборе данных может занять некоторое время. Если у вас есть Toolbox™ Parallel Computing, то вы можете делать прогнозы параллельно, создавая подписи внутри parfor смотри. Если у вас нет панели инструментов параллельных вычислений. затем parfor цикл выполняется в последовательном формате.
beamIndex = 2; maxNumWords = 20; numObservationsTest = numel(annotationsTest); numIterationsTest = ceil(numObservationsTest/miniBatchSize); captionsTestPred = strings(1,numObservationsTest); documentsTestPred = tokenizedDocument(strings(1,numObservationsTest)); for i = 1:numIterationsTest % Mini-batch indices. idxStart = (i-1)*miniBatchSize+1; idxEnd = min(i*miniBatchSize,numObservationsTest); idx = idxStart:idxEnd; sz = numel(idx); % Read images. tbl = readByIndex(augimdsTest,idx); % Extract image features. X = cat(4,tbl.input{:}); dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment); % Generate captions. captionsPredMiniBatch = strings(1,sz); documentsPredMiniBatch = tokenizedDocument(strings(1,sz)); parfor j = 1:sz words = beamSearch(dlX(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords); captionsPredMiniBatch(j) = join(words); documentsPredMiniBatch(j) = tokenizedDocument(words,'TokenizeMethod','none'); end captionsTestPred(idx) = captionsPredMiniBatch; documentsTestPred(idx) = documentsPredMiniBatch; end
Analyzing and transferring files to the workers ...done.
Для просмотра тестового изображения с соответствующей подписью используйте imshow и установите заголовок на прогнозируемый заголовок.
idx = 1;
tbl = readByIndex(augimdsTest,idx);
img = tbl.input{1};
figure
imshow(img)
title(captionsTestPred(idx))
Чтобы оценить точность титров с помощью показателя BLEU, вычислите балл BLEU для каждого титра (кандидата) по сравнению с соответствующими титрами в тестовом наборе (ссылки) с помощью bleuEvaluationScore функция. Использование bleuEvaluationScore можно сравнить один документ-кандидат с несколькими ссылочными документами.
bleuEvaluationScore функция по умолчанию оценивает сходство с использованием n-граммов длины от одного до четырех. Поскольку подписи короткие, такое поведение может привести к неинформативным результатам, поскольку большинство баллов близки к нулю. Задайте длину n-грамма от одного до двух, установив значение 'NgramWeights' опция для двухэлементного вектора с равными весами.
ngramWeights = [0.5 0.5]; for i = 1:numObservationsTest annotation = annotationsTest(i); captionIDs = annotation.CaptionIDs; candidate = documentsTestPred(i); references = documentsAll(captionIDs); score = bleuEvaluationScore(candidate,references,'NgramWeights',ngramWeights); scores(i) = score; end
Просмотрите средний балл BLEU.
scoreMean = mean(scores)
scoreMean = 0.4224
Визуализируйте оценки в гистограмме.
figure histogram(scores) xlabel("BLEU Score") ylabel("Frequency")

attention функция вычисляет вектор контекста и веса внимания, используя внимание Бахданау.
function [contextVector, attentionWeights] = attention(hidden,features,weights1, ... bias1,weights2,bias2,weightsV,biasV) % Model dimensions. [embeddingDimension,numFeatures,miniBatchSize] = size(features); numHiddenUnits = size(weights1,1); % Fully connect. dlY1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize); dlY1 = fullyconnect(dlY1,weights1,bias1,'DataFormat','CB'); dlY1 = reshape(dlY1,numHiddenUnits,numFeatures,miniBatchSize); % Fully connect. dlY2 = fullyconnect(hidden,weights2,bias2,'DataFormat','CB'); dlY2 = reshape(dlY2,numHiddenUnits,1,miniBatchSize); % Addition, tanh. scores = tanh(dlY1 + dlY2); scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize); % Fully connect, softmax. attentionWeights = fullyconnect(scores,weightsV,biasV,'DataFormat','CB'); attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize); attentionWeights = softmax(attentionWeights,'DataFormat','SCB'); % Context. contextVector = attentionWeights .* features; contextVector = squeeze(sum(contextVector,2)); end
embedding функция отображает массив индексов в последовательность вложенных векторов.
function Z = embedding(X, weights) % Reshape inputs into a vector [N, T] = size(X, 1:2); X = reshape(X, N*T, 1); % Index into embedding matrix Z = weights(:, X); % Reshape outputs by separating out batch and sequence dimensions Z = reshape(Z, [], N, T); end
extractImageFeatures функция принимает в качестве входных данных обученный dlnetwork объект, входное изображение, статистика для масштабирования изображения и среда выполнения, и возвращает dlarray содержит элементы, извлеченные из предварительно обученной сети.
function dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment) % Resize and rescale. inputSize = dlnet.Layers(1).InputSize(1:2); X = imresize(X,inputSize); X = rescale(X,-1,1,'InputMin',inputMin,'InputMax',inputMax); % Convert to dlarray. dlX = dlarray(X,'SSCB'); % Convert to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Extract features and reshape. dlX = predict(dlnet,dlX); sz = size(dlX); numFeatures = sz(1) * sz(2); inputSizeEncoder = sz(3); miniBatchSize = sz(4); dlX = reshape(dlX,[numFeatures inputSizeEncoder miniBatchSize]); end
createBatch функция принимает в качестве входных данных мини-пакет данных, маркированные подписи, предварительно обученную сеть, статистику для масштабирования изображения, кодирование слова и среду выполнения, и возвращает мини-пакет данных, соответствующий извлеченным функциям изображения и подписям для обучения.
function [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment) dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment); % Convert documents to sequences of word indices. T = doc2sequence(enc,documents,'PaddingDirection','right','PaddingValue',enc.NumWords+1); T = cat(1,T{:}); % Convert mini-batch of data to dlarray. dlT = dlarray(T); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlT = gpuArray(dlT); end end
modelEncoder функция принимает в качестве входных данных массив активаций dlX и пропускает его через полностью подключенную операцию и операцию ReLU. Для полностью подключенной операции используйте только размер канала. Чтобы применить полностью подключенную операцию только к размеру канала, выровняйте другие каналы в один размер и укажите этот размер как пакетное измерение с помощью 'DataFormat' вариант fullyconnect функция.
function dlY = modelEncoder(dlX,parametersEncoder) [numFeatures,inputSizeEncoder,miniBatchSize] = size(dlX); % Fully connect weights = parametersEncoder.fc.Weights; bias = parametersEncoder.fc.Bias; embeddingDimension = size(weights,1); dlX = permute(dlX,[2 1 3]); dlX = reshape(dlX,inputSizeEncoder,numFeatures*miniBatchSize); dlY = fullyconnect(dlX,weights,bias,'DataFormat','CB'); dlY = reshape(dlY,embeddingDimension,numFeatures,miniBatchSize); % ReLU dlY = relu(dlY); end
modelDecoder функция принимает в качестве ввода один шаг времени dlX, параметры модели декодера, признаки от кодера и состояние сети и возвращает прогнозы для следующего временного шага, обновленного состояния сети и весов внимания.
function [dlY,state,attentionWeights] = modelDecoder(dlX,parametersDecoder,features,state) hiddenState = state.gru.HiddenState; % Attention weights1 = parametersDecoder.attention.Weights1; bias1 = parametersDecoder.attention.Bias1; weights2 = parametersDecoder.attention.Weights2; bias2 = parametersDecoder.attention.Bias2; weightsV = parametersDecoder.attention.WeightsV; biasV = parametersDecoder.attention.BiasV; [contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV); % Embedding weights = parametersDecoder.emb.Weights; dlX = embedding(dlX,weights); % Concatenate dlY = cat(1,contextVector,dlX); % GRU inputWeights = parametersDecoder.gru.InputWeights; recurrentWeights = parametersDecoder.gru.RecurrentWeights; bias = parametersDecoder.gru.Bias; [dlY, hiddenState] = gru(dlY, hiddenState, inputWeights, recurrentWeights, bias, 'DataFormat','CBT'); % Update state state.gru.HiddenState = hiddenState; % Fully connect weights = parametersDecoder.fc1.Weights; bias = parametersDecoder.fc1.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB'); % Fully connect weights = parametersDecoder.fc2.Weights; bias = parametersDecoder.fc2.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB'); end
modelGradients функция принимает в качестве входных параметры кодера и декодера, кодер имеет dlXи целевой заголовок dlTи возвращает градиенты параметров кодера и декодера относительно потерь, потерь и прогнозов.
function [gradientsEncoder,gradientsDecoder,loss,dlYPred] = ... modelGradients(parametersEncoder,parametersDecoder,dlX,dlT) miniBatchSize = size(dlX,3); sequenceLength = size(dlT,2) - 1; vocabSize = size(parametersDecoder.emb.Weights,2); % Model encoder features = modelEncoder(dlX,parametersEncoder); % Initialize state numHiddenUnits = size(parametersDecoder.attention.Weights1,1); state = struct; state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],'single')); dlYPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],'like',dlX)); loss = dlarray(single(0)); padToken = vocabSize; for t = 1:sequenceLength decoderInput = dlT(:,t); dlYReal = dlT(:,t+1); [dlYPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state); mask = dlYReal ~= padToken; loss = loss + sparseCrossEntropyAndSoftmax(dlYPred(:,:,t),dlYReal,mask); end % Calculate gradients [gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder); end
sparseCrossEntropyAndSoftmax принимает за входные данные прогнозы dlY, соответствующие цели dlTи маску заполнения последовательности, и применяет softmax и возвращает потери перекрестной энтропии.
function loss = sparseCrossEntropyAndSoftmax(dlY, dlT, mask) miniBatchSize = size(dlY, 2); % Softmax. dlY = softmax(dlY,'DataFormat','CB'); % Find rows corresponding to the target words. idx = sub2ind(size(dlY), dlT', 1:miniBatchSize); dlY = dlY(idx); % Bound away from zero. dlY = max(dlY, single(1e-8)); % Masked loss. loss = log(dlY) .* mask'; loss = -sum(loss,'all') ./ miniBatchSize; end
beamSearch функция принимает в качестве входных признаков изображения dlX, индекс луча, параметры для сетей кодера и декодера, кодирование слова и максимальная длина последовательности, и возвращает титульные слова для изображения с использованием алгоритма поиска луча.
function [words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder, ... enc,maxNumWords) % Model dimensions numFeatures = size(dlX,1); numHiddenUnits = size(parametersDecoder.attention.Weights1,1); % Extract features features = modelEncoder(dlX,parametersEncoder); % Initialize state state = struct; state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],'like',dlX)); % Initialize candidates candidates = struct; candidates.State = state; candidates.Words = "<start>"; candidates.Score = 0; candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],'like',dlX)); candidates.StopFlag = false; t = 0; % Loop over words while t < maxNumWords t = t + 1; candidatesNew = []; % Loop over candidates for i = 1:numel(candidates) % Stop generating when stop token is predicted if candidates(i).StopFlag continue end % Candidate details state = candidates(i).State; words = candidates(i).Words; score = candidates(i).Score; attentionScores = candidates(i).AttentionScores; % Predict next token decoderInput = word2ind(enc,words(end)); [dlYPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state); dlYPred = softmax(dlYPred,'DataFormat','CB'); [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex); idxTop = gather(idxTop); % Loop over top predictions for j = 1:beamIndex candidate = struct; candidateWord = ind2word(enc,idxTop(j)); candidateScore = scoresTop(j); if candidateWord == "<stop>" candidate.StopFlag = true; attentionScores(:,t+1:end) = []; else candidate.StopFlag = false; end candidate.State = state; candidate.Words = [words candidateWord]; candidate.Score = score + log(candidateScore); candidate.AttentionScores = attentionScores; candidatesNew = [candidatesNew candidate]; end end % Get top candidates [~,idx] = maxk([candidatesNew.Score],beamIndex); candidates = candidatesNew(idx); % Stop predicting when all candidates have stop token if all([candidates.StopFlag]) break end end % Get top candidate words = candidates(1).Words(2:end-1); attentionScores = candidates(1).AttentionScores; end
initializeGlorot генерирует массив весов в соответствии с инициализацией Glorot.
function weights = initializeGlorot(numOut, numIn) varWeights = sqrt( 6 / (numIn + numOut) ); weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1); end
adamupdate | crossentropy | dlarray | dlfeval | dlgradient | dlupdate | gru | lstm | softmax | doc2sequence (инструментарий для анализа текста) | tokenizedDocument(Панель инструментов для анализа текста) | word2ind(Панель инструментов для анализа текста) | wordEncoding(Панель инструментов для анализа текста)