В этом примере показано, как обучить модель глубокого обучения вводу субтитров изображений с помощью внимания.
Большинство предварительно обученных нейронных сетей для глубокого обучения сконфигурировано для классификации одно меток. Например, учитывая изображение типичного офисного стола, сетевая сила предсказывает единый класс "клавиатура" или "мышь". В отличие от этого модель ввода субтитров изображений комбинирует сверточные и текущие операции, чтобы произвести текстовое описание того, что находится в изображении, а не одной метке.
Эта модель, обученная в этом примере, использует архитектуру декодера энкодера. Энкодер является предварительно обученной сетью Inception-v3, используемой в качестве экстрактора функции. Декодер является рекуррентной нейронной сетью (RNN), которая берет извлеченные функции, как введено и генерирует заголовок. Декодер включает механизм внимания, который позволяет декодеру фокусироваться на частях закодированного входа при генерации заголовка.
Модель энкодера является предварительно обученной моделью Inception-v3, которая извлекает функции из "mixed10"
слой, сопровождаемый полностью связанным и операции ReLU.
Модель декодера состоит из встраивания слова, механизма внимания, закрытого текущего модуля (GRU) и двух полностью связанных операций.
Загрузите предварительно обученную сеть Inception-v3. Этот шаг требует Модели Deep Learning Toolbox™ для пакета Сетевой поддержки Inception-v3. Если вам не установили необходимый пакет поддержки, то программное обеспечение обеспечивает ссылку на загрузку.
net = inceptionv3; inputSizeNet = net.Layers(1).InputSize;
Преобразуйте сеть в dlnetwork
объект для извлечения признаков и удаляет последние четыре слоя, оставляя "mixed10"
слой как последний слой.
lgraph = layerGraph(net); lgraph = removeLayers(lgraph,["avg_pool" "predictions" "predictions_softmax" "ClassificationLayer_predictions"]);
Просмотрите входной слой сети. Симметричное использование сети Inception-v3 - перемасштабирует нормализацию с минимальным значением 0 и максимальным значением 255.
lgraph.Layers(1)
ans = ImageInputLayer with properties: Name: 'input_1' InputSize: [299 299 3] Hyperparameters DataAugmentation: 'none' Normalization: 'rescale-symmetric' NormalizationDimension: 'auto' Max: 255 Min: 0
Пользовательское обучение не поддерживает эту нормализацию, таким образом, необходимо отключить нормализацию в сети и выполнить нормализацию в пользовательском учебном цикле вместо этого. Сохраните минимальные и максимальные значения, как удваивается в переменных под названием inputMin
и inputMax
, соответственно, и замена входной слой с изображением ввела слой без нормализации.
inputMin = double(lgraph.Layers(1).Min); inputMax = double(lgraph.Layers(1).Max); layer = imageInputLayer(inputSizeNet,'Normalization','none','Name','input'); lgraph = replaceLayer(lgraph,'input_1',layer);
Определите выходной размер сети. Используйте analyzeNetwork
функция, чтобы видеть размеры активации последнего слоя. Чтобы анализировать сеть для пользовательских учебных рабочих процессов цикла, установите TargetUsage
опция к 'dlnetwork'
.
analyzeNetwork(lgraph,'TargetUsage','dlnetwork')
Создайте переменную под названием outputSizeNet
содержа сетевой выходной размер.
outputSizeNet = [8 8 2048];
Преобразуйте график слоев в dlnetwork
возразите и просмотрите выходной слой. Выходным слоем является "mixed10"
слой сети Inception-v3.
dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [311×1 nnet.cnn.layer.Layer] Connections: [345×2 table] Learnables: [376×3 table] State: [188×3 table] InputNames: {'input'} OutputNames: {'mixed10'}
Образы загрузки и аннотации от наборов данных "2014 Обучают изображения", и "2014 Обучают/val аннотации", соответственно, под эгидой https://cocodataset.org/#download. Извлеките изображения и аннотации в папку под названием "coco"
. Набор данных COCO 2014 был собран Кокосовым Консорциумом.
Извлеките заголовки из файла "captions_train2014.json"
использование jsondecode
функция.
dataFolder = fullfile(tempdir,"coco"); filename = fullfile(dataFolder,"annotations_trainval2014","annotations","captions_train2014.json"); str = fileread(filename); data = jsondecode(str)
data = struct with fields:
info: [1×1 struct]
images: [82783×1 struct]
licenses: [8×1 struct]
annotations: [414113×1 struct]
annotations
поле struct содержит данные, требуемые для ввода субтитров изображений.
data.annotations
ans=414113×1 struct array with fields:
image_id
id
caption
Набор данных содержит несколько заголовков для каждого изображения. Чтобы гарантировать те же изображения не появляются и в наборах обучения и в валидации, идентифицируют уникальные изображения в наборе данных с помощью unique
функция при помощи идентификаторов в image_id
поле поля аннотаций данных, затем просмотрите количество уникальных изображений.
numObservationsAll = numel(data.annotations)
numObservationsAll = 414113
imageIDs = [data.annotations.image_id]; imageIDsUnique = unique(imageIDs); numUniqueImages = numel(imageIDsUnique)
numUniqueImages = 82783
Каждое изображение имеет по крайней мере пять заголовков. Создайте struct annotationsAll
с этими полями:
ImageID
— Отображают ID
Filename
— Имя файла изображения
Captions
— Массив строк необработанных заголовков
CaptionIDs
— Вектор из индексов соответствующих заголовков в data.annotations
Чтобы сделать слияние легче, отсортируйте аннотации по идентификаторам изображений.
[~,idx] = sort([data.annotations.image_id]); data.annotations = data.annotations(idx);
Цикл по аннотациям и слиянию несколько аннотаций при необходимости.
i = 0; j = 0; imageIDPrev = 0; while i < numel(data.annotations) i = i + 1; imageID = data.annotations(i).image_id; caption = string(data.annotations(i).caption); if imageID ~= imageIDPrev % Create new entry j = j + 1; annotationsAll(j).ImageID = imageID; annotationsAll(j).Filename = fullfile(dataFolder,"train2014","COCO_train2014_" + pad(string(imageID),12,'left','0') + ".jpg"); annotationsAll(j).Captions = caption; annotationsAll(j).CaptionIDs = i; else % Append captions annotationsAll(j).Captions = [annotationsAll(j).Captions; caption]; annotationsAll(j).CaptionIDs = [annotationsAll(j).CaptionIDs; i]; end imageIDPrev = imageID; end
Разделите данные в наборы обучения и валидации. Протяните 5% наблюдений для тестирования.
cvp = cvpartition(numel(annotationsAll),'HoldOut',0.05);
idxTrain = training(cvp);
idxTest = test(cvp);
annotationsTrain = annotationsAll(idxTrain);
annotationsTest = annotationsAll(idxTest);
Struct содержит три поля:
id
— Уникальный идентификатор для заголовка
caption
— Отобразите заголовок в виде вектора символов
image_id
— Уникальный идентификатор изображения, соответствующего заголовку
Чтобы просмотреть изображение и соответствующий заголовок, найдите файл изображения с именем файла "train2014\COCO_train2014_XXXXXXXXXXXX.jpg"
, где "XXXXXXXXXXXX"
соответствует ID изображений, лево-дополненному нулями, чтобы иметь длину 12.
imageID = annotationsTrain(1).ImageID; captions = annotationsTrain(1).Captions; filename = annotationsTrain(1).Filename;
Чтобы просмотреть изображение, используйте imread
и imshow
функции.
img = imread(filename); figure imshow(img) title(captions)
Подготовьте заголовки к обучению и тестированию. Извлеките текст из Captions
поле struct, содержащего и обучение и тестовые данные (annotationsAll
), сотрите пунктуацию и преобразуйте текст в нижний регистр.
captionsAll = cat(1,annotationsAll.Captions); captionsAll = erasePunctuation(captionsAll); captionsAll = lower(captionsAll);
Для того, чтобы сгенерировать заголовки, декодер RNN требует, чтобы специальный запуск и лексемы остановки указали, когда запустить и прекратить генерировать текст, соответственно. Добавьте пользовательские лексемы "<start>"
и "<stop>"
к началу и концам заголовков, соответственно.
captionsAll = "<start>" + captionsAll + "<stop>";
Маркируйте заголовки с помощью tokenizedDocument
функционируйте и задайте запуск и лексемы остановки с помощью 'CustomTokens'
опция.
documentsAll = tokenizedDocument(captionsAll,'CustomTokens',["<start>" "<stop>"]);
Создайте wordEncoding
возразите что слова карт против числовых индексов и назад. Уменьшайте требования к памяти путем определения размера словаря 5 000 соответствий наиболее часто наблюдаемым словам в обучающих данных. Чтобы избежать смещения, используйте только документы, соответствующие набору обучающих данных.
enc = wordEncoding(documentsAll(idxTrain),'MaxNumWords',5000,'Order','frequency');
Создайте увеличенный datastore изображений, содержащий изображения, соответствующие заголовкам. Установите выходной размер совпадать с входным размером сверточной сети. Чтобы сохранить изображения синхронизируемыми с заголовками, задайте таблицу имен файлов для datastore путем восстановления имен файлов с помощью ID изображений. Чтобы возвратить полутоновые изображения как RGB, с 3 каналами отображает, установите 'ColorPreprocessing'
опция к 'gray2rgb'
.
tblFilenames = table(cat(1,annotationsTrain.Filename)); augimdsTrain = augmentedImageDatastore(inputSizeNet,tblFilenames,'ColorPreprocessing','gray2rgb')
augimdsTrain = augmentedImageDatastore with properties: NumObservations: 78644 MiniBatchSize: 1 DataAugmentation: 'none' ColorPreprocessing: 'gray2rgb' OutputSize: [299 299] OutputSizeMode: 'resize' DispatchInBackground: 0
Инициализируйте параметры модели. Задайте 512 скрытых модулей с размерностью встраивания слова 256.
embeddingDimension = 256; numHiddenUnits = 512;
Инициализируйте struct, содержащий параметры для модели энкодера.
Инициализируйте веса полностью связанных операций с помощью инициализатора Glorot, заданного initializeGlorot
функция, перечисленная в конце примера. Задайте выходной размер, чтобы совпадать с размерностью встраивания декодера (256) и входной размер, чтобы совпадать с количеством выходных каналов предварительно обученной сети. 'mixed10'
слой сети Inception-v3 выходные данные с 2 048 каналами.
numFeatures = outputSizeNet(1) * outputSizeNet(2); inputSizeEncoder = outputSizeNet(3); parametersEncoder = struct; % Fully connect parametersEncoder.fc.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeEncoder)); parametersEncoder.fc.Bias = dlarray(zeros([embeddingDimension 1],'single'));
Инициализируйте struct, содержащий параметры для модели декодера.
Инициализируйте веса встраивания слова размером, данным размерностью встраивания и размером словаря плюс один, где дополнительная запись соответствует дополнительному значению.
Инициализируйте веса и смещения для механизма внимания Bahdanau с размерами, соответствующими количеству скрытых модулей операции ГРУ.
Инициализируйте веса и смещение операции ГРУ.
Инициализируйте веса и смещения двух полностью связанных операций.
Для параметров декодера модели инициализируйте каждый то, чтобы взвешивать и смещения с инициализатором Glorot и нулями, соответственно.
inputSizeDecoder = enc.NumWords + 1; parametersDecoder = struct; % Word embedding parametersDecoder.emb.Weights = dlarray(initializeGlorot(embeddingDimension,inputSizeDecoder)); % Attention parametersDecoder.attention.Weights1 = dlarray(initializeGlorot(numHiddenUnits,embeddingDimension)); parametersDecoder.attention.Bias1 = dlarray(zeros([numHiddenUnits 1],'single')); parametersDecoder.attention.Weights2 = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits)); parametersDecoder.attention.Bias2 = dlarray(zeros([numHiddenUnits 1],'single')); parametersDecoder.attention.WeightsV = dlarray(initializeGlorot(1,numHiddenUnits)); parametersDecoder.attention.BiasV = dlarray(zeros(1,1,'single')); % GRU parametersDecoder.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,2*embeddingDimension)); parametersDecoder.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits)); parametersDecoder.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single')); % Fully connect parametersDecoder.fc1.Weights = dlarray(initializeGlorot(numHiddenUnits,numHiddenUnits)); parametersDecoder.fc1.Bias = dlarray(zeros([numHiddenUnits 1],'single')); % Fully connect parametersDecoder.fc2.Weights = dlarray(initializeGlorot(enc.NumWords+1,numHiddenUnits)); parametersDecoder.fc2.Bias = dlarray(zeros([enc.NumWords+1 1],'single'));
Создайте функции modelEncoder
и modelDecoder
, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.
modelEncoder
функция, перечисленная в разделе Encoder Model Function примера, берет в качестве входа массив активаций dlX
от выхода предварительно обученной сети и передач это посредством полностью связанной операции и операции ReLU. Поскольку предварительно обученная сеть не должна быть прослежена для автоматического дифференцирования, извлечение функций вне функции модели энкодера более в вычислительном отношении эффективно.
modelDecoder
функция, перечисленная в разделе Decoder Model Function примера, занимает в качестве входа один входной такт, соответствуя входному слову, параметрам модели декодера, функциям от энкодера и сетевому состоянию, и возвращает предсказания для следующего временного шага, обновленного сетевого состояния и весов внимания.
Задайте опции для обучения. Обучайтесь в течение 30 эпох с мини-пакетным размером 128 и отобразите прогресс обучения в графике.
miniBatchSize = 128;
numEpochs = 30;
plots = "training-progress";
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
executionEnvironment = "auto";
Обучите сеть с помощью пользовательского учебного цикла.
В начале каждой эпохи переставьте входные данные. Чтобы сохранить изображения в увеличенном datastore изображений и заголовках синхронизируемыми, создайте массив переставленных индексов, который индексирует в оба набора данных.
Для каждого мини-пакета:
Перемасштабируйте изображения к размеру, который ожидает предварительно обученная сеть.
Для каждого изображения выберите случайный заголовок.
Преобразуйте заголовки в последовательности словарей. Задайте дополнение права последовательностей с дополнительным значением, соответствующим индексу дополнительной лексемы.
Преобразуйте данные в dlarray
объекты. Для изображений укажите, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Для обучения графического процессора преобразуйте данные в gpuArray
объекты.
Извлеките функции изображений с помощью предварительно обученной сети и измените их к размеру, который ожидает энкодер.
Оцените градиенты модели и потерю с помощью dlfeval
и modelGradients
функции.
Обновите параметры модели энкодера и декодера с помощью adamupdate
функция.
Отобразите прогресс обучения в графике.
Инициализируйте параметры для оптимизатора Адама.
trailingAvgEncoder = []; trailingAvgSqEncoder = []; trailingAvgDecoder = []; trailingAvgSqDecoder = [];
Инициализируйте график процесса обучения. Создайте анимированную линию, которая строит потерю против соответствующей итерации.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); xlabel("Iteration") ylabel("Loss") ylim([0 inf]) grid on end
Обучите модель.
iteration = 0; numObservationsTrain = numel(annotationsTrain); numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize); start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. idxShuffle = randperm(numObservationsTrain); % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Determine mini-batch indices. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; idxMiniBatch = idxShuffle(idx); % Read mini-batch of data. tbl = readByIndex(augimdsTrain,idxMiniBatch); X = cat(4,tbl.input{:}); annotations = annotationsTrain(idxMiniBatch); % For each image, select random caption. idx = cellfun(@(captionIDs) randsample(captionIDs,1),{annotations.CaptionIDs}); documents = documentsAll(idx); % Create batch of data. [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment); % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ... parametersDecoder, dlX, dlT); % Update encoder using adamupdate. [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ... gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration); % Update decoder using adamupdate. [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ... gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Процесс генерации заголовка отличается от процесса для обучения. Во время обучения, на каждом временном шаге, декодер использует истинное значение предыдущего временного шага, как введено. Это известно как "учителя, обеспечивающего". При создании предсказаний на новых данных декодер использует предыдущие ожидаемые значения вместо истинных значений.
Предсказание наиболее вероятного слова для каждого шага в последовательности может привести к субоптимальным результатам. Например, если декодер предсказывает, что первое слово заголовка является "a", когда дали изображение слона, то вероятность предсказания "слона" для следующего слова становится намного более маловероятной из-за чрезвычайно низкой вероятности фразы "слон", появляющийся в английском тексте.
Чтобы решить эту проблему, можно использовать алгоритм поиска луча: вместо того, чтобы брать наиболее вероятное предсказание для каждого шага в последовательности, возьмите верхнюю часть k предсказания (индекс луча) и для каждого следующего шага, сохраните верхнюю часть k предсказанными последовательностями до сих пор согласно общей оценке.
Сгенерируйте заголовок нового изображения путем извлечения функций изображений, введения их в энкодер, и затем использования beamSearch
функция, перечисленная в разделе Beam Search Function примера.
img = imread("laika_sitting.jpg");
dlX = extractImageFeatures(dlnet,img,inputMin,inputMax,executionEnvironment);
beamIndex = 3;
maxNumWords = 20;
[words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords);
caption = join(words)
caption = "a dog is standing on a tile floor"
Отобразите изображение с заголовком.
figure imshow(img) title(caption)
Предсказать заголовки для набора изображений, цикла по мини-пакетам данных в datastore и извлечь функции из изображений с помощью extractImageFeatures
функция. Затем цикл по изображениям в мини-пакете и генерирует заголовки с помощью beamSearch
функция.
Создайте увеличенный datastore изображений и установите выходной размер совпадать с входным размером сверточной сети. Чтобы вывести полутоновые изображения как RGB, с 3 каналами отображает, установите 'ColorPreprocessing'
опция к 'gray2rgb'
.
tblFilenamesTest = table(cat(1,annotationsTest.Filename)); augimdsTest = augmentedImageDatastore(inputSizeNet,tblFilenamesTest,'ColorPreprocessing','gray2rgb')
augimdsTest = augmentedImageDatastore with properties: NumObservations: 4139 MiniBatchSize: 1 DataAugmentation: 'none' ColorPreprocessing: 'gray2rgb' OutputSize: [299 299] OutputSizeMode: 'resize' DispatchInBackground: 0
Сгенерируйте заголовки для тестовых данных. Предсказание заголовков на большом наборе данных может занять время. Если у вас есть Parallel Computing Toolbox™, то можно сделать предсказания параллельно путем генерации заголовков в parfor
посмотреть. Если у вас нет Parallel Computing Toolbox. затем parfor
цикл запускается в сериале.
beamIndex = 2; maxNumWords = 20; numObservationsTest = numel(annotationsTest); numIterationsTest = ceil(numObservationsTest/miniBatchSize); captionsTestPred = strings(1,numObservationsTest); documentsTestPred = tokenizedDocument(strings(1,numObservationsTest)); for i = 1:numIterationsTest % Mini-batch indices. idxStart = (i-1)*miniBatchSize+1; idxEnd = min(i*miniBatchSize,numObservationsTest); idx = idxStart:idxEnd; sz = numel(idx); % Read images. tbl = readByIndex(augimdsTest,idx); % Extract image features. X = cat(4,tbl.input{:}); dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment); % Generate captions. captionsPredMiniBatch = strings(1,sz); documentsPredMiniBatch = tokenizedDocument(strings(1,sz)); parfor j = 1:sz words = beamSearch(dlX(:,:,j),beamIndex,parametersEncoder,parametersDecoder,enc,maxNumWords); captionsPredMiniBatch(j) = join(words); documentsPredMiniBatch(j) = tokenizedDocument(words,'TokenizeMethod','none'); end captionsTestPred(idx) = captionsPredMiniBatch; documentsTestPred(idx) = documentsPredMiniBatch; end
Analyzing and transferring files to the workers ...done.
Чтобы просмотреть тестовое изображение с соответствующим заголовком, используйте imshow
функция и набор заголовок на предсказанный заголовок.
idx = 1; tbl = readByIndex(augimdsTest,idx); img = tbl.input{1}; figure imshow(img) title(captionsTestPred(idx))
Чтобы оценить точность заголовков с помощью BLEU score, вычислите BLEU score для каждого заголовка (кандидат) против соответствующих заголовков в наборе тестов (ссылки) использование bleuEvaluationScore
функция. Используя bleuEvaluationScore
функция, можно сравнить один документ кандидата нескольким справочным документам.
bleuEvaluationScore
функция, по умолчанию, подобие баллов с помощью N-грамм длины один - четыре. Когда заголовки коротки, это поведение может привести к неинформативным результатам, как большинство баллов близко к нулю. Установите длину n-граммы на один - два путем установки 'NgramWeights'
опция к двухэлементному вектору с равными весами.
ngramWeights = [0.5 0.5]; for i = 1:numObservationsTest annotation = annotationsTest(i); captionIDs = annotation.CaptionIDs; candidate = documentsTestPred(i); references = documentsAll(captionIDs); score = bleuEvaluationScore(candidate,references,'NgramWeights',ngramWeights); scores(i) = score; end
Просмотрите средний BLEU score.
scoreMean = mean(scores)
scoreMean = 0.4224
Визуализируйте баллы в гистограмме.
figure histogram(scores) xlabel("BLEU Score") ylabel("Frequency")
attention
функция вычисляет вектор контекста и использование весов внимания внимание Bahdanau.
function [contextVector, attentionWeights] = attention(hidden,features,weights1, ... bias1,weights2,bias2,weightsV,biasV) % Model dimensions. [embeddingDimension,numFeatures,miniBatchSize] = size(features); numHiddenUnits = size(weights1,1); % Fully connect. dlY1 = reshape(features,embeddingDimension, numFeatures*miniBatchSize); dlY1 = fullyconnect(dlY1,weights1,bias1,'DataFormat','CB'); dlY1 = reshape(dlY1,numHiddenUnits,numFeatures,miniBatchSize); % Fully connect. dlY2 = fullyconnect(hidden,weights2,bias2,'DataFormat','CB'); dlY2 = reshape(dlY2,numHiddenUnits,1,miniBatchSize); % Addition, tanh. scores = tanh(dlY1 + dlY2); scores = reshape(scores, numHiddenUnits, numFeatures*miniBatchSize); % Fully connect, softmax. attentionWeights = fullyconnect(scores,weightsV,biasV,'DataFormat','CB'); attentionWeights = reshape(attentionWeights,1,numFeatures,miniBatchSize); attentionWeights = softmax(attentionWeights,'DataFormat','SCB'); % Context. contextVector = attentionWeights .* features; contextVector = squeeze(sum(contextVector,2)); end
embedding
функционируйте сопоставляет массив индексов к последовательности встраивания векторов.
function Z = embedding(X, weights) % Reshape inputs into a vector [N, T] = size(X, 1:2); X = reshape(X, N*T, 1); % Index into embedding matrix Z = weights(:, X); % Reshape outputs by separating out batch and sequence dimensions Z = reshape(Z, [], N, T); end
extractImageFeatures
функционируйте берет в качестве входа обученный dlnetwork
объект, входное изображение, статистика для перемасштабирующего изображения, и среда выполнения, и возвращают dlarray
содержание функций извлечено из предварительно обученной сети.
function dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment) % Resize and rescale. inputSize = dlnet.Layers(1).InputSize(1:2); X = imresize(X,inputSize); X = rescale(X,-1,1,'InputMin',inputMin,'InputMax',inputMax); % Convert to dlarray. dlX = dlarray(X,'SSCB'); % Convert to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Extract features and reshape. dlX = predict(dlnet,dlX); sz = size(dlX); numFeatures = sz(1) * sz(2); inputSizeEncoder = sz(3); miniBatchSize = sz(4); dlX = reshape(dlX,[numFeatures inputSizeEncoder miniBatchSize]); end
createBatch
функционируйте берет в качестве входа мини-пакет данных, маркируемых заголовков, предварительно обученной сети, статистики для перемасштабирующего изображения, кодирование слова и среда выполнения, и возвращает мини-пакет данных, соответствующих извлеченным функциям изображений и заголовкам для обучения.
function [dlX, dlT] = createBatch(X,documents,dlnet,inputMin,inputMax,enc,executionEnvironment) dlX = extractImageFeatures(dlnet,X,inputMin,inputMax,executionEnvironment); % Convert documents to sequences of word indices. T = doc2sequence(enc,documents,'PaddingDirection','right','PaddingValue',enc.NumWords+1); T = cat(1,T{:}); % Convert mini-batch of data to dlarray. dlT = dlarray(T); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlT = gpuArray(dlT); end end
modelEncoder
функционируйте берет в качестве входа массив активаций dlX
и передачи это посредством полностью связанной операции и операции ReLU. Для полностью связанной операции работайте с размерностью канала только. Чтобы применить полностью связанную операцию через размерность канала только, сгладьте другие каналы в одну размерность и задайте эту размерность как пакетную размерность с помощью 'DataFormat'
опция fullyconnect
функция.
function dlY = modelEncoder(dlX,parametersEncoder) [numFeatures,inputSizeEncoder,miniBatchSize] = size(dlX); % Fully connect weights = parametersEncoder.fc.Weights; bias = parametersEncoder.fc.Bias; embeddingDimension = size(weights,1); dlX = permute(dlX,[2 1 3]); dlX = reshape(dlX,inputSizeEncoder,numFeatures*miniBatchSize); dlY = fullyconnect(dlX,weights,bias,'DataFormat','CB'); dlY = reshape(dlY,embeddingDimension,numFeatures,miniBatchSize); % ReLU dlY = relu(dlY); end
modelDecoder
функционируйте занимает в качестве входа один такт dlX
, параметры модели декодера, функции от энкодера и сетевое состояние, и возвращают предсказания для следующего временного шага, обновленного сетевого состояния и весов внимания.
function [dlY,state,attentionWeights] = modelDecoder(dlX,parametersDecoder,features,state) hiddenState = state.gru.HiddenState; % Attention weights1 = parametersDecoder.attention.Weights1; bias1 = parametersDecoder.attention.Bias1; weights2 = parametersDecoder.attention.Weights2; bias2 = parametersDecoder.attention.Bias2; weightsV = parametersDecoder.attention.WeightsV; biasV = parametersDecoder.attention.BiasV; [contextVector, attentionWeights] = attention(hiddenState,features,weights1,bias1,weights2,bias2,weightsV,biasV); % Embedding weights = parametersDecoder.emb.Weights; dlX = embedding(dlX,weights); % Concatenate dlY = cat(1,contextVector,dlX); % GRU inputWeights = parametersDecoder.gru.InputWeights; recurrentWeights = parametersDecoder.gru.RecurrentWeights; bias = parametersDecoder.gru.Bias; [dlY, hiddenState] = gru(dlY, hiddenState, inputWeights, recurrentWeights, bias, 'DataFormat','CBT'); % Update state state.gru.HiddenState = hiddenState; % Fully connect weights = parametersDecoder.fc1.Weights; bias = parametersDecoder.fc1.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB'); % Fully connect weights = parametersDecoder.fc2.Weights; bias = parametersDecoder.fc2.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB'); end
modelGradients
функционируйте берет в качестве входа параметры энкодера и декодера, энкодер показывает dlX
, и целевой заголовок dlT
, и возвращает градиенты параметров энкодера и декодера относительно потери, потери и предсказаний.
function [gradientsEncoder,gradientsDecoder,loss,dlYPred] = ... modelGradients(parametersEncoder,parametersDecoder,dlX,dlT) miniBatchSize = size(dlX,3); sequenceLength = size(dlT,2) - 1; vocabSize = size(parametersDecoder.emb.Weights,2); % Model encoder features = modelEncoder(dlX,parametersEncoder); % Initialize state numHiddenUnits = size(parametersDecoder.attention.Weights1,1); state = struct; state.gru.HiddenState = dlarray(zeros([numHiddenUnits miniBatchSize],'single')); dlYPred = dlarray(zeros([vocabSize miniBatchSize sequenceLength],'like',dlX)); loss = dlarray(single(0)); padToken = vocabSize; for t = 1:sequenceLength decoderInput = dlT(:,t); dlYReal = dlT(:,t+1); [dlYPred(:,:,t),state] = modelDecoder(decoderInput,parametersDecoder,features,state); mask = dlYReal ~= padToken; loss = loss + sparseCrossEntropyAndSoftmax(dlYPred(:,:,t),dlYReal,mask); end % Calculate gradients [gradientsEncoder,gradientsDecoder] = dlgradient(loss, parametersEncoder,parametersDecoder); end
sparseCrossEntropyAndSoftmax
берет в качестве входа предсказания dlY
, соответствующие цели dlT
, и дополнительная маска последовательности, и применяет softmax
функции и возвращают потерю перекрестной энтропии.
function loss = sparseCrossEntropyAndSoftmax(dlY, dlT, mask) miniBatchSize = size(dlY, 2); % Softmax. dlY = softmax(dlY,'DataFormat','CB'); % Find rows corresponding to the target words. idx = sub2ind(size(dlY), dlT', 1:miniBatchSize); dlY = dlY(idx); % Bound away from zero. dlY = max(dlY, single(1e-8)); % Masked loss. loss = log(dlY) .* mask'; loss = -sum(loss,'all') ./ miniBatchSize; end
beamSearch
функционируйте берет в качестве входа, изображение показывает dlX
, индекс луча, параметры для сетей энкодера и декодера, кодирования слова и максимальной длины последовательности, и возвращают слова заголовка для изображения с помощью алгоритма поиска луча.
function [words,attentionScores] = beamSearch(dlX,beamIndex,parametersEncoder,parametersDecoder, ... enc,maxNumWords) % Model dimensions numFeatures = size(dlX,1); numHiddenUnits = size(parametersDecoder.attention.Weights1,1); % Extract features features = modelEncoder(dlX,parametersEncoder); % Initialize state state = struct; state.gru.HiddenState = dlarray(zeros([numHiddenUnits 1],'like',dlX)); % Initialize candidates candidates = struct; candidates.State = state; candidates.Words = "<start>"; candidates.Score = 0; candidates.AttentionScores = dlarray(zeros([numFeatures maxNumWords],'like',dlX)); candidates.StopFlag = false; t = 0; % Loop over words while t < maxNumWords t = t + 1; candidatesNew = []; % Loop over candidates for i = 1:numel(candidates) % Stop generating when stop token is predicted if candidates(i).StopFlag continue end % Candidate details state = candidates(i).State; words = candidates(i).Words; score = candidates(i).Score; attentionScores = candidates(i).AttentionScores; % Predict next token decoderInput = word2ind(enc,words(end)); [dlYPred,state,attentionScores(:,t)] = modelDecoder(decoderInput,parametersDecoder,features,state); dlYPred = softmax(dlYPred,'DataFormat','CB'); [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex); idxTop = gather(idxTop); % Loop over top predictions for j = 1:beamIndex candidate = struct; candidateWord = ind2word(enc,idxTop(j)); candidateScore = scoresTop(j); if candidateWord == "<stop>" candidate.StopFlag = true; attentionScores(:,t+1:end) = []; else candidate.StopFlag = false; end candidate.State = state; candidate.Words = [words candidateWord]; candidate.Score = score + log(candidateScore); candidate.AttentionScores = attentionScores; candidatesNew = [candidatesNew candidate]; end end % Get top candidates [~,idx] = maxk([candidatesNew.Score],beamIndex); candidates = candidatesNew(idx); % Stop predicting when all candidates have stop token if all([candidates.StopFlag]) break end end % Get top candidate words = candidates(1).Words(2:end-1); attentionScores = candidates(1).AttentionScores; end
initializeGlorot
функция генерирует массив весов согласно инициализации Glorot.
function weights = initializeGlorot(numOut, numIn) varWeights = sqrt( 6 / (numIn + numOut) ); weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1); end
word2ind
(Text Analytics Toolbox) | tokenizedDocument
(Text Analytics Toolbox) | wordEncoding
(Text Analytics Toolbox) | dlarray
| adamupdate
| dlupdate
| dlfeval
| dlgradient
| crossentropy
| softmax
| lstm
| doc2sequence
(Text Analytics Toolbox) | gru