В этом примере показано, как преобразовать десятичные строки в Римские цифры с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.
Текущие модели декодера энкодера оказались успешными в задачах как абстрактное текстовое резюмирование и нейронный машинный перевод. Модель состоит из энкодера, который обычно входные данные процессов с текущим слоем, такие как LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно со вторым текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода.
Для модели энкодера этот пример использует простую сеть, состоящую из встраивания, сопровождаемого двумя операциями LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.
Для модели декодера этот пример использует сеть, очень похожую на энкодер, который содержит два LSTMs. Однако важное различие - то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру проявлять внимание к определенным частям энкодера выход.
Загрузите пары десятичной Римской цифры с "romanNumerals.csv"
filename = fullfile("romanNumerals.csv"); options = detectImportOptions(filename, ... 'TextType','string', ... 'ReadVariableNames',false); options.VariableNames = ["Source" "Target"]; options.VariableTypes = ["string" "string"]; data = readtable(filename,options);
Разделите данные в обучение и протестируйте разделы, содержащие 50% данных каждый.
idx = randperm(size(data,1),500); dataTrain = data(idx,:); dataTest = data; dataTest(idx,:) = [];
Просмотрите некоторые пары десятичной Римской цифры.
head(dataTrain)
ans=8×2 table
Source Target
______ ____________
"437" "CDXXXVII"
"431" "CDXXXI"
"102" "CII"
"862" "DCCCLXII"
"738" "DCCXXXVIII"
"527" "DXXVII"
"401" "CDI"
"184" "CLXXXIV"
Предварительно обработайте текстовые данные с помощью transformText
функция, перечисленная в конце примера. transformText
функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.
startToken = "<start>"; stopToken = "<stop>"; strSource = dataTrain{:,1}; documentsSource = transformText(strSource,startToken,stopToken);
Создайте wordEncoding
возразите что лексемы карт против числового индекса и наоборот использования словаря.
encSource = wordEncoding(documentsSource);
Используя кодирование слова, преобразуйте данные об исходном тексте в числовые последовательности.
sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');
Преобразуйте целевые данные в последовательности с помощью тех же шагов.
strTarget = dataTrain{:,2}; documentsTarget = transformText(strTarget,startToken,stopToken); encTarget = wordEncoding(documentsTarget); sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');
Сортировка последовательностей длиной. Обучение с последовательностями, отсортированными путем увеличения длины последовательности, приводит к пакетам с последовательностями приблизительно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются, чтобы обновить модель перед более длинными пакетами последовательности.
sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource); [~,idx] = sort(sequenceLengths); sequencesSource = sequencesSource(idx); sequencesTarget = sequencesTarget(idx);
Создайте arrayDatastore
объекты, содержащие входные и выходные данные и, комбинируют их использующий combine
функция.
sequencesSourceDs = arrayDatastore(sequencesSource,'OutputType','same'); sequencesTargetDs = arrayDatastore(sequencesTarget,'OutputType','same'); sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);
Инициализируйте параметры модели. и для энкодера и для декодера, задайте размерность встраивания 128, два слоя LSTM с 200 скрытыми модулями и слои уволенного со случайным уволенным с вероятностью 0.05.
embeddingDimension = 128; numHiddenUnits = 200; dropout = 0.05;
Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian
функция, которая присоединена к этому примеру как к вспомогательному файлу. Задайте среднее значение 0 и стандартное отклонение 0,01. Чтобы узнать больше, смотрите Гауссову Инициализацию (Deep Learning Toolbox).
inputSize = encSource.NumWords + 1; sz = [embeddingDimension inputSize]; mu = 0; sigma = 0.01; parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);
Инициализируйте настраиваемые параметры для операций LSTM энкодера:
Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot
функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, см. Инициализацию Glorot (Deep Learning Toolbox).
Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal
функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Ортогональную Инициализацию (Deep Learning Toolbox).
Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate
функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите, что Модуль Забывает Инициализацию Логического элемента (Deep Learning Toolbox).
Инициализируйте настраиваемые параметры для первой операции LSTM энкодера.
sz = [4*numHiddenUnits embeddingDimension]; numOut = 4*numHiddenUnits; numIn = embeddingDimension; parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для второй операции LSTM энкодера.
sz = [4*numHiddenUnits numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = numHiddenUnits; parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian
функция. Задайте среднее значение 0 и стандартное отклонение 0,01.
outputSize = encTarget.NumWords + 1; sz = [embeddingDimension outputSize]; mu = 0; sigma = 0.01; parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);
Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot
функция.
sz = [numHiddenUnits numHiddenUnits]; numOut = numHiddenUnits; numIn = numHiddenUnits; parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);
Инициализируйте настраиваемые параметры для операций LSTM декодера:
Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot
функция.
Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal
функция.
Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate
функция.
Инициализируйте настраиваемые параметры для первой операции LSTM декодера.
sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = embeddingDimension + numHiddenUnits; parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для второй операции LSTM декодера.
sz = [4*numHiddenUnits numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = numHiddenUnits; parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для декодера, полностью соединил операцию:
Инициализируйте веса инициализатором Glorot.
Инициализируйте смещение нулями с помощью initializeZeros
функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Нулевую Инициализацию (Deep Learning Toolbox).
sz = [outputSize 2*numHiddenUnits]; numOut = outputSize; numIn = 2*numHiddenUnits; parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn); parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);
Создайте функции modelEncoder
и modelDecoder
, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.
modelEncoder
функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходные параметры модели и LSTM скрытое состояние.
modelDecoder
функция, перечисленная в разделе Decoder Model Function примера, берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.
Создайте функциональный modelGradients
, перечисленный в разделе Model Gradients Function примера, который берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.
Обучайтесь с мини-пакетным размером 32 в течение 75 эпох со скоростью обучения 0,002.
miniBatchSize = 32; numEpochs = 75; learnRate = 0.002;
Инициализируйте опции от Адама.
gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Обучите модель с помощью пользовательского учебного цикла. Используйте minibatchqueue
обработать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера), чтобы найти длины всей последовательности в мини-пакете и заполнить последовательности к той же длине как самая длинная последовательность, для входных и выходных последовательностей, соответственно.
Переставьте вторые и третьи размерности заполненных последовательностей.
Возвратитесь мини-пакетные переменные восстановили после форматирования dlarray
объекты с базовым типом данных single
. Все другие выходные параметры являются массивами типа данных single
.
Обучайтесь на графическом процессоре, если вы доступны. Возвратите все мини-пакетные переменные на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом.
minibatchqueue
объект возвращает четыре выходных аргумента в пользу каждого мини-пакета: исходные последовательности, целевые последовательности, длины всех исходных последовательностей в мини-пакете и маска последовательности целевых последовательностей.
numMiniBatchOutputs = 4; mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));
Инициализируйте график процесса обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте значения для adamupdate
функция.
trailingAvg = []; trailingAvgSq = [];
Обучите модель. Для каждого мини-пакета:
Считайте мини-пакет заполненных последовательностей.
Вычислите потерю и градиенты.
Обновите параметры модели энкодера и декодера с помощью adamupdate
функция.
Обновите график процесса обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs reset(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; [dlX,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq); % Compute loss and gradients. [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,T,sequenceLengthsSource,... maskSequenceTarget,dropout); % Update parameters using adamupdate. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,... iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(loss))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end
Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов как тогда, когда обучение и ввело последовательности в модель декодера энкодера и преобразует получившиеся последовательности назад в текст с помощью маркерных индексов.
Предварительно обработайте текстовые данные с помощью тех же шагов как тогда, когда обучение. Используйте transformText
функция, перечисленная в конце примера, чтобы разделить текст в символы и добавить запуск и лексемы остановки.
strSource = dataTest{:,1}; strTarget = dataTest{:,2};
Переведите текст с помощью modelPredictions
функция.
maxSequenceLength = 10; delimiter = ""; strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ... encSource,encTarget,startToken,stopToken,delimiter);
Составьте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.
tbl = table; tbl.Source = strSource; tbl.Target = strTarget; tbl.Translated = strTranslated;
Просмотрите случайный выбор переводов.
idx = randperm(size(dataTest,1),miniBatchSize); tbl(idx,:)
ans=32×3 table
Source Target Translated
______ ___________ ___________
"8" "VIII" "CCCXXVII"
"595" "DXCV" "DCCV"
"523" "DXXIII" "CDXXII"
"675" "DCLXXV" "DCCLXV"
"818" "DCCCXVIII" "DCCCXVIII"
"265" "CCLXV" "CCLXV"
"770" "DCCLXX" "DCCCL"
"904" "CMIV" "CMVII"
"121" "CXXI" "CCXI"
"333" "CCCXXXIII" "CCCXXXIII"
"817" "DCCCXVII" "DCCCXVII"
"37" "XXXVII" "CCCXXXIV"
"335" "CCCXXXV" "CCCXXXV"
"902" "CMII" "CMIII"
"995" "CMXCV" "CMXCV"
"334" "CCCXXXIV" "CCCXXXIV"
⋮
transformText
функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.
function documents = transformText(str,startToken,stopToken) str = strip(replace(str,""," ")); str = startToken + str + stopToken; documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]); end
preprocessMiniBatch
функция, описанная в разделе Train Model примера, предварительно обрабатывает данные для обучения. Функция предварительно обрабатывает данные с помощью следующих шагов:
Определите длины всех входных и выходных последовательностей в мини-пакете
Заполните последовательности к той же длине как самая длинная последовательность в мини-пакете с помощью padsequences
функция.
Переставьте последние две размерности последовательностей
function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize) sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource); X = padsequences(sequencesSource,2,"PaddingValue",inputSize); X = permute(X,[1 3 2]); [T,maskTarget] = padsequences(sequencesTarget,2,"PaddingValue",outputSize); T = permute(T,[1 3 2]); maskTarget = permute(maskTarget,[1 3 2]); end
modelGradients
функционируйте берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.
function [gradients,loss] = modelGradients(parameters,dlX,T,... sequenceLengthsSource,maskTarget,dropout) % Forward through encoder. [dlZ,hiddenState] = modelEncoder(parameters.encoder,dlX,sequenceLengthsSource); % Decoder Output. doTeacherForcing = rand < 0.5; sequenceLength = size(T,3); dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,... doTeacherForcing,sequenceLength); % Masked loss. dlY = dlY(:,:,1:end-1); T = extractdata(gather(T(:,:,2:end))); T = onehotencode(T,1,'ClassNames',1:size(dlY,1)); maskTarget = maskTarget(:,:,2:end); maskTarget = repmat(maskTarget,[size(dlY,1),1,1]); loss = crossentropy(dlY,T,'Mask',maskTarget,'Dataformat','CBT'); % Update gradients. gradients = dlgradient(loss,parameters); % For plotting, return loss normalized by sequence length. loss = extractdata(loss) ./ sequenceLength; end
Функциональный modelEncoder
берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходной параметр модели и LSTM скрытое состояние.
Если sequenceLengths
пусто, затем функция не маскирует выход. Задайте и пустое значение для sequenceLengths
при использовании modelEncoder
функция для предсказания.
function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths) % Embedding. weights = parametersEncoder.emb.Weights; dlZ = embed(dlX,weights,'DataFormat','CBT'); % LSTM 1. inputWeights = parametersEncoder.lstm1.InputWeights; recurrentWeights = parametersEncoder.lstm1.RecurrentWeights; bias = parametersEncoder.lstm1.Bias; numHiddenUnits = size(recurrentWeights, 2); initialHiddenState = dlarray(zeros([numHiddenUnits 1])); initialCellState = dlarray(zeros([numHiddenUnits 1])); dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ... recurrentWeights, bias, 'DataFormat', 'CBT'); % LSTM 2. inputWeights = parametersEncoder.lstm2.InputWeights; recurrentWeights = parametersEncoder.lstm2.RecurrentWeights; bias = parametersEncoder.lstm2.Bias; [dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ... inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Masking for training. if ~isempty(sequenceLengths) miniBatchSize = size(dlZ,2); for n = 1:miniBatchSize hiddenState(:,n) = dlZ(:,n,sequenceLengths(n)); end end end
Функциональный modelDecoder
берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.
function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ... hiddenState, dlZ, dropout) % Embedding. weights = parametersDecoder.emb.Weights; dlX = embed(dlX, weights,'DataFormat','CBT'); % RNN input. sequenceLength = size(dlX,3); dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength])); % LSTM 1. inputWeights = parametersDecoder.lstm1.InputWeights; recurrentWeights = parametersDecoder.lstm1.RecurrentWeights; bias = parametersDecoder.lstm1.Bias; initialCellState = dlarray(zeros(size(hiddenState))); dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Dropout. mask = ( rand(size(dlY), 'like', dlY) > dropout ); dlY = dlY.*mask; % LSTM 2. inputWeights = parametersDecoder.lstm2.InputWeights; recurrentWeights = parametersDecoder.lstm2.RecurrentWeights; bias = parametersDecoder.lstm2.Bias; [dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Attention. weights = parametersDecoder.attn.Weights; [attentionScores, context] = attention(hiddenState, dlZ, weights); % Concatenate. dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength])); % Fully connect. weights = parametersDecoder.fc.Weights; bias = parametersDecoder.fc.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT'); % Softmax. dlY = softmax(dlY,'DataFormat','CBT'); end
attention
функция возвращает баллы внимания по данным Luong "общий" выигрыш и обновленный вектор контекста. Энергия на каждом временном шаге является скалярным произведением скрытого состояния и learnable времена весов внимания энкодера выход.
function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights) % Initialize attention energies. [miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3); attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState); % Attention energies. hWX = hiddenState .* pagemtimes(weights,encoderOutputs); for tt = 1:sequenceLength attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1); end % Attention scores. attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB'); % Context. encoderOutputs = permute(encoderOutputs, [1 3 2]); attentionScores = permute(attentionScores,[1 3 2]); context = pagemtimes(encoderOutputs,attentionScores); context = squeeze(context); end
decoderModelPredictions
функция возвращает предсказанную последовательность dlY
учитывая входную последовательность, предназначайтесь для последовательности, скрытого состояния, вероятности уволенного, флаг, чтобы включить учителю, обеспечивающему и длине последовательности.
function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ... doTeacherForcing,sequenceLength) % Convert to dlarray. dlT = dlarray(T); % Initialize context. miniBatchSize = size(dlT,2); numHiddenUnits = size(dlZ,1); context = zeros([numHiddenUnits miniBatchSize],'like',dlZ); if doTeacherForcing % Forward through decoder. dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout); else % Get first time step for decoder. decoderInput = dlT(:,:,1); % Initialize output. numClasses = numel(parametersDecoder.fc.Bias); dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput); % Loop over time steps. for t = 1:sequenceLength % Forward through decoder. [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ... hiddenState, dlZ, dropout); % Update decoder input. [~, decoderInput] = max(dlY(:,:,t),[],1); end end end
translateText
функция переводит массив текста путем итерации по мини-пакетам. Функция берет в качестве входа параметры модели, массив входной строки, максимальную длину последовательности, мини-пакетный размер, входные и выходные объекты кодирования слова, запуск и лексемы остановки и разделитель для сборки выхода.
function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ... encSource,encTarget,startToken,stopToken,delimiter) % Transform text. documentsSource = transformText(strSource,startToken,stopToken); sequencesSource = doc2sequence(encSource,documentsSource, ... 'PaddingDirection','right', ... 'PaddingValue',encSource.NumWords + 1); % Convert to dlarray. X = cat(3,sequencesSource{:}); X = permute(X,[1 3 2]); dlX = dlarray(X); % Initialize output. numObservations = numel(strSource); strTranslated = strings(numObservations,1); % Loop over mini-batches. numIterations = ceil(numObservations / miniBatchSize); for i = 1:numIterations idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); miniBatchSize = numel(idxMiniBatch); % Encode using model encoder. sequenceLengths = []; [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths); % Decoder predictions. doTeacherForcing = false; dropout = 0; decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]); decoderInput = dlarray(decoderInput); dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ... doTeacherForcing,maxSequenceLength); [~, idxPred] = max(extractdata(dlY), [], 1); % Keep translating flag. idxStop = word2ind(encTarget,stopToken); keepTranslating = idxPred ~= idxStop; % Loop over time steps. t = 1; while t <= maxSequenceLength && any(keepTranslating(:,:,t)) % Update output. newWords = ind2word(encTarget, idxPred(:,:,t))'; idxUpdate = idxMiniBatch(keepTranslating(:,:,t)); strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t)); t = t + 1; end end end
doc2sequence
| tokenizedDocument
| word2ind
| wordEncoding
| adamupdate
(Deep Learning Toolbox) | crossentropy
(Deep Learning Toolbox) | dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | dlgradient
(Deep Learning Toolbox) | dlupdate
(Deep Learning Toolbox) | lstm
(Deep Learning Toolbox) | softmax
(Deep Learning Toolbox)