В этом примере показано, как преобразовать десятичные строки в римские числа с помощью рекуррентной модели кодер-декодер последовательности в последовательность с вниманием.
Модели рекуррентного энкодера-декодера оказались успешными в таких задачах, как абстрактное суммирование текста и нейронный машинный перевод. Модель состоит из энкодера, который обычно обрабатывает входные данные с помощью рекуррентного уровня, такого как LSTM, и декодера, который преобразует кодированный вход в требуемый выход, обычно со вторым рекуррентным слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях кодированного входа при генерации перевода.
Для модели энкодера в этом примере используется простая сеть, состоящая из встраивания, за которой следуют две операции LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.
Для модели декодера в этом примере используется сеть, очень похожая на энкодер, который содержит два LSTM. Однако важным различием является то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру следить за конкретными частями выхода энкодера.
Загрузите десятичные-римские числительные пары из "romanNumerals.csv"
filename = fullfile("romanNumerals.csv"); options = detectImportOptions(filename, ... 'TextType','string', ... 'ReadVariableNames',false); options.VariableNames = ["Source" "Target"]; options.VariableTypes = ["string" "string"]; data = readtable(filename,options);
Разделите данные на обучающие и тестовые разделы, содержащие 50% данных каждый.
idx = randperm(size(data,1),500); dataTrain = data(idx,:); dataTest = data; dataTest(idx,:) = [];
Просмотрите некоторые десятичные-римские числовые пары.
head(dataTrain)
ans=8×2 table
Source Target
______ ____________
"437" "CDXXXVII"
"431" "CDXXXI"
"102" "CII"
"862" "DCCCLXII"
"738" "DCCXXXVIII"
"527" "DXXVII"
"401" "CDI"
"184" "CLXXXIV"
Предварительно обработайте текстовые данные с помощью transformText
функции, перечисленной в конце примера. The transformText
функция выполняет предварительную обработку и маркировку входного текста для перевода путем разделения текста на символы и добавления начальных и стоповых лексем. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.
startToken = "<start>"; stopToken = "<stop>"; strSource = dataTrain{:,1}; documentsSource = transformText(strSource,startToken,stopToken);
Создайте wordEncoding
объект, который преобразует лексемы в числовой индекс и наоборот с помощью словаря.
encSource = wordEncoding(documentsSource);
Используя кодировку слова, преобразуйте исходные текстовые данные в числовые последовательности.
sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');
Преобразуйте целевые данные в последовательности с помощью тех же шагов.
strTarget = dataTrain{:,2}; documentsTarget = transformText(strTarget,startToken,stopToken); encTarget = wordEncoding(documentsTarget); sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');
Сортировка последовательностей по длине. Обучение с последовательностями, отсортированными путем увеличения длины последовательности, приводит к пакетам с последовательностями примерно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются для обновления модели перед более длинными пакетами последовательности.
sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource); [~,idx] = sort(sequenceLengths); sequencesSource = sequencesSource(idx); sequencesTarget = sequencesTarget(idx);
Создание arrayDatastore
объекты, содержащие исходные и целевые данные, и объединяют их с помощью combine
функция.
sequencesSourceDs = arrayDatastore(sequencesSource,'OutputType','same'); sequencesTargetDs = arrayDatastore(sequencesTarget,'OutputType','same'); sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);
Инициализируйте параметры модели. для энкодера и декодера задайте размерность встраивания 128, два слоя LSTM с 200 скрытыми модулями и слои отсева со случайным отсоединением с вероятностью 0,05.
embeddingDimension = 128; numHiddenUnits = 200; dropout = 0.05;
Инициализируйте веса встраивания кодирования с помощью Гауссова с помощью initializeGaussian
функция, которая присоединена к этому примеру как вспомогательный файл. Задайте среднее значение 0 и стандартное отклонение 0,01. Для получения дополнительной информации смотрите Гауссову инициализацию (Deep Learning Toolbox).
inputSize = encSource.NumWords + 1; sz = [embeddingDimension inputSize]; mu = 0; sigma = 0.01; parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);
Инициализируйте настраиваемые параметры для операций LSTM энкодера:
Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot
функция, которая присоединена к этому примеру как вспомогательный файл. Для получения дополнительной информации смотрите Glorot Initialization (Deep Learning Toolbox).
Инициализируйте повторяющиеся веса с помощью ортогонального инициализатора с помощью initializeOrthogonal
функция, которая присоединена к этому примеру как вспомогательный файл. Дополнительные сведения см. в разделе Ортогональная инициализация (Deep Learning Toolbox).
Инициализируйте смещение с помощью модуля forget gate initializer с помощью initializeUnitForgetGate
функция, которая присоединена к этому примеру как вспомогательный файл. Для получения дополнительной информации смотрите Unit Forget Gate Initialization (Deep Learning Toolbox).
Инициализируйте настраиваемые параметры для операции LSTM первого энкодера.
sz = [4*numHiddenUnits embeddingDimension]; numOut = 4*numHiddenUnits; numIn = embeddingDimension; parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для операции LSTM второго энкодера.
sz = [4*numHiddenUnits numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = numHiddenUnits; parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте веса встраивания кодирования с помощью Гауссова с помощью initializeGaussian
функция. Задайте среднее значение 0 и стандартное отклонение 0,01.
outputSize = encTarget.NumWords + 1; sz = [embeddingDimension outputSize]; mu = 0; sigma = 0.01; parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);
Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot
функция.
sz = [numHiddenUnits numHiddenUnits]; numOut = numHiddenUnits; numIn = numHiddenUnits; parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);
Инициализируйте настраиваемые параметры для операций LSTM декодера:
Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot
функция.
Инициализируйте повторяющиеся веса с помощью ортогонального инициализатора с помощью initializeOrthogonal
функция.
Инициализируйте смещение с помощью модуля forget gate initializer с помощью initializeUnitForgetGate
функция.
Инициализируйте настраиваемые параметры для операции LSTM первого декодера.
sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = embeddingDimension + numHiddenUnits; parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для операции LSTM второго декодера.
sz = [4*numHiddenUnits numHiddenUnits]; numOut = 4*numHiddenUnits; numIn = numHiddenUnits; parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn); parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]); parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);
Инициализируйте настраиваемые параметры для полностью подключенной операции декодера:
Инициализируйте веса с помощью инициализатора Glorot.
Инициализируйте смещение с нулями, используя initializeZeros
функция, которая присоединена к этому примеру как вспомогательный файл. Дополнительные сведения см. в разделе Инициализация нулей (Deep Learning Toolbox).
sz = [outputSize 2*numHiddenUnits]; numOut = outputSize; numIn = 2*numHiddenUnits; parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn); parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);
Создайте функции modelEncoder
и modelDecoder
, перечисленных в конце примера, которые вычисляют выходы моделей энкодера и декодера, соответственно.
The modelEncoder
функция, перечисленная в разделе Encoder Model Function примера, берёт входные данные, параметры модели, необязательную маску, которая используется для определения правильных выходов для обучения и возвращает выходные параметры модели и скрытое состояние LSTM.
The modelDecoder
функция, перечисленная в разделе Decoder Model Function примера, берёт входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы энкодера, а также вероятность отсева и выводит выходы декодера, обновленный вектор контекста, обновленное состояние LSTM и счетов внимания.
Создайте функцию modelGradients
, приведенный в разделе Model Gradients Function примера, который принимает параметры модели энкодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно настраиваемых параметров в моделях и соответствующих потерь.
Обучайте с мини-пакетом размером 32 на 75 эпох со скоростью обучения 0,002.
miniBatchSize = 32; numEpochs = 75; learnRate = 0.002;
Инициализируйте опции из Adam.
gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Обучите модель с помощью пользовательского цикла обучения. Использование minibatchqueue
обрабатывать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы найти длины всей последовательности в мини-пакете и дополнить последовательности той же длиной, что и самая длинная последовательность, для исходной и целевой последовательностей, соответственно.
Транспозиция вторых и третьих размерностей заполненных последовательностей.
Возвращает переменные мини-пакета неформатированные dlarray
объекты с базовым типом данных single
. Все другие выходы являются массивами типа данных single
.
Обучите на графическом процессоре, если он доступен. Возвращает все мини-пакетные переменные на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах см. раздел Поддержка GPU по релизу.
The minibatchqueue
объект возвращает четыре выходных аргументов для каждого мини-пакета: исходные последовательности, целевые последовательности, длины всех исходных последовательностей в мини-пакете и маска последовательности целевых последовательностей.
numMiniBatchOutputs = 4; mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));
Инициализируйте график процесса обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте значения для adamupdate
функция.
trailingAvg = []; trailingAvgSq = [];
Обучите модель. Для каждого мини-пакета:
Считайте мини-пакет заполненных последовательностей.
Вычислите потери и градиенты.
Обновите параметры модели энкодера и декодера с помощью adamupdate
функция.
Обновите график процесса обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs reset(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; [dlX,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq); % Compute loss and gradients. [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,T,sequenceLengthsSource,... maskSequenceTarget,dropout); % Update parameters using adamupdate. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,... iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(loss))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end
Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов, что и при обучении, и вводите последовательности в модель энкодер-декодер и преобразуйте получившиеся последовательности обратно в текст с помощью индексов маркера.
Предварительно обработайте текстовые данные с помощью тех же шагов, что и при обучении. Используйте transformText
функция, перечисленная в конце примера, чтобы разделить текст на символы и добавить начальные и стоповые лексемы.
strSource = dataTest{:,1}; strTarget = dataTest{:,2};
Переведите текст с помощью modelPredictions
функция.
maxSequenceLength = 10; delimiter = ""; strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ... encSource,encTarget,startToken,stopToken,delimiter);
Составьте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.
tbl = table; tbl.Source = strSource; tbl.Target = strTarget; tbl.Translated = strTranslated;
Просмотр случайного выбора переводов.
idx = randperm(size(dataTest,1),miniBatchSize); tbl(idx,:)
ans=32×3 table
Source Target Translated
______ ___________ ___________
"8" "VIII" "CCCXXVII"
"595" "DXCV" "DCCV"
"523" "DXXIII" "CDXXII"
"675" "DCLXXV" "DCCLXV"
"818" "DCCCXVIII" "DCCCXVIII"
"265" "CCLXV" "CCLXV"
"770" "DCCLXX" "DCCCL"
"904" "CMIV" "CMVII"
"121" "CXXI" "CCXI"
"333" "CCCXXXIII" "CCCXXXIII"
"817" "DCCCXVII" "DCCCXVII"
"37" "XXXVII" "CCCXXXIV"
"335" "CCCXXXV" "CCCXXXV"
"902" "CMII" "CMIII"
"995" "CMXCV" "CMXCV"
"334" "CCCXXXIV" "CCCXXXIV"
⋮
The transformText
функция выполняет предварительную обработку и маркировку входного текста для перевода путем разделения текста на символы и добавления начальных и стоповых лексем. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.
function documents = transformText(str,startToken,stopToken) str = strip(replace(str,""," ")); str = startToken + str + stopToken; documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]); end
The preprocessMiniBatch
функция, описанная в разделе Модель примера, предварительно обрабатывает данные для обучения. Функция предварительно обрабатывает данные с помощью следующих шагов:
Определите длины всех исходных и целевых последовательностей в мини-пакете
Дополните последовательности такой же длиной, как и самую длинную последовательность в мини-пакете, используя padsequences
функция.
Транспозиция последних двух размерностей последовательностей
function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize) sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource); X = padsequences(sequencesSource,2,"PaddingValue",inputSize); X = permute(X,[1 3 2]); [T,maskTarget] = padsequences(sequencesTarget,2,"PaddingValue",outputSize); T = permute(T,[1 3 2]); maskTarget = permute(maskTarget,[1 3 2]); end
The modelGradients
функция принимает параметры модели энкодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно настраиваемых параметров в моделях и соответствующих потерь.
function [gradients,loss] = modelGradients(parameters,dlX,T,... sequenceLengthsSource,maskTarget,dropout) % Forward through encoder. [dlZ,hiddenState] = modelEncoder(parameters.encoder,dlX,sequenceLengthsSource); % Decoder Output. doTeacherForcing = rand < 0.5; sequenceLength = size(T,3); dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,... doTeacherForcing,sequenceLength); % Masked loss. dlY = dlY(:,:,1:end-1); T = extractdata(gather(T(:,:,2:end))); T = onehotencode(T,1,'ClassNames',1:size(dlY,1)); maskTarget = maskTarget(:,:,2:end); maskTarget = repmat(maskTarget,[size(dlY,1),1,1]); loss = crossentropy(dlY,T,'Mask',maskTarget,'Dataformat','CBT'); % Update gradients. gradients = dlgradient(loss,parameters); % For plotting, return loss normalized by sequence length. loss = extractdata(loss) ./ sequenceLength; end
Функция modelEncoder
принимает входные данные, параметры модели, необязательную маску, которая используется для определения правильных выходов для обучения, и возвращает выходы модели и скрытое состояние LSTM.
Если sequenceLengths
пуст, тогда функция не маскирует выход. Задайте и пустое значение для sequenceLengths
при использовании modelEncoder
функция для предсказания.
function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths) % Embedding. weights = parametersEncoder.emb.Weights; dlZ = embed(dlX,weights,'DataFormat','CBT'); % LSTM 1. inputWeights = parametersEncoder.lstm1.InputWeights; recurrentWeights = parametersEncoder.lstm1.RecurrentWeights; bias = parametersEncoder.lstm1.Bias; numHiddenUnits = size(recurrentWeights, 2); initialHiddenState = dlarray(zeros([numHiddenUnits 1])); initialCellState = dlarray(zeros([numHiddenUnits 1])); dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ... recurrentWeights, bias, 'DataFormat', 'CBT'); % LSTM 2. inputWeights = parametersEncoder.lstm2.InputWeights; recurrentWeights = parametersEncoder.lstm2.RecurrentWeights; bias = parametersEncoder.lstm2.Bias; [dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ... inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Masking for training. if ~isempty(sequenceLengths) miniBatchSize = size(dlZ,2); for n = 1:miniBatchSize hiddenState(:,n) = dlZ(:,n,sequenceLengths(n)); end end end
Функция modelDecoder
принимает входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы энкодера и вероятность отсева и выводит выход декодера, обновленный вектор контекста, обновленное состояние LSTM и счетов внимания.
function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ... hiddenState, dlZ, dropout) % Embedding. weights = parametersDecoder.emb.Weights; dlX = embed(dlX, weights,'DataFormat','CBT'); % RNN input. sequenceLength = size(dlX,3); dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength])); % LSTM 1. inputWeights = parametersDecoder.lstm1.InputWeights; recurrentWeights = parametersDecoder.lstm1.RecurrentWeights; bias = parametersDecoder.lstm1.Bias; initialCellState = dlarray(zeros(size(hiddenState))); dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Dropout. mask = ( rand(size(dlY), 'like', dlY) > dropout ); dlY = dlY.*mask; % LSTM 2. inputWeights = parametersDecoder.lstm2.InputWeights; recurrentWeights = parametersDecoder.lstm2.RecurrentWeights; bias = parametersDecoder.lstm2.Bias; [dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Attention. weights = parametersDecoder.attn.Weights; [attentionScores, context] = attention(hiddenState, dlZ, weights); % Concatenate. dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength])); % Fully connect. weights = parametersDecoder.fc.Weights; bias = parametersDecoder.fc.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT'); % Softmax. dlY = softmax(dlY,'DataFormat','CBT'); end
The attention
функция возвращает счета внимания согласно «общей» оценке Luong и обновленному вектору контекста. Энергия на каждом временном шаге является точечным продуктом скрытого состояния, и выученные веса внимания умножаются на выход энкодера.
function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights) % Initialize attention energies. [miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3); attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState); % Attention energies. hWX = hiddenState .* pagemtimes(weights,encoderOutputs); for tt = 1:sequenceLength attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1); end % Attention scores. attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB'); % Context. encoderOutputs = permute(encoderOutputs, [1 3 2]); attentionScores = permute(attentionScores,[1 3 2]); context = pagemtimes(encoderOutputs,attentionScores); context = squeeze(context); end
The decoderModelPredictions
функция возвращает предсказанную последовательность dlY
учитывая входную последовательность, целевую последовательность, скрытое состояние, вероятность отсева, флаг, чтобы включить принудительное выполнение учителем, и длину последовательности.
function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ... doTeacherForcing,sequenceLength) % Convert to dlarray. dlT = dlarray(T); % Initialize context. miniBatchSize = size(dlT,2); numHiddenUnits = size(dlZ,1); context = zeros([numHiddenUnits miniBatchSize],'like',dlZ); if doTeacherForcing % Forward through decoder. dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout); else % Get first time step for decoder. decoderInput = dlT(:,:,1); % Initialize output. numClasses = numel(parametersDecoder.fc.Bias); dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput); % Loop over time steps. for t = 1:sequenceLength % Forward through decoder. [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ... hiddenState, dlZ, dropout); % Update decoder input. [~, decoderInput] = max(dlY(:,:,t),[],1); end end end
The translateText
Функция переводит массив текста путем итерации по мини-пакетам. Функция принимает за вход параметры модели, входные строковые массивы, максимальную длину последовательности, размер мини-пакета, объекты кодирования исходного и целевого слова, начальные и стоповые лексемы и разделитель для сборки выхода.
function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ... encSource,encTarget,startToken,stopToken,delimiter) % Transform text. documentsSource = transformText(strSource,startToken,stopToken); sequencesSource = doc2sequence(encSource,documentsSource, ... 'PaddingDirection','right', ... 'PaddingValue',encSource.NumWords + 1); % Convert to dlarray. X = cat(3,sequencesSource{:}); X = permute(X,[1 3 2]); dlX = dlarray(X); % Initialize output. numObservations = numel(strSource); strTranslated = strings(numObservations,1); % Loop over mini-batches. numIterations = ceil(numObservations / miniBatchSize); for i = 1:numIterations idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); miniBatchSize = numel(idxMiniBatch); % Encode using model encoder. sequenceLengths = []; [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths); % Decoder predictions. doTeacherForcing = false; dropout = 0; decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]); decoderInput = dlarray(decoderInput); dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ... doTeacherForcing,maxSequenceLength); [~, idxPred] = max(extractdata(dlY), [], 1); % Keep translating flag. idxStop = word2ind(encTarget,stopToken); keepTranslating = idxPred ~= idxStop; % Loop over time steps. t = 1; while t <= maxSequenceLength && any(keepTranslating(:,:,t)) % Update output. newWords = ind2word(encTarget, idxPred(:,:,t))'; idxUpdate = idxMiniBatch(keepTranslating(:,:,t)); strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t)); t = t + 1; end end end
doc2sequence
| tokenizedDocument
| word2ind
| wordEncoding
| adamupdate
(Deep Learning Toolbox) | crossentropy
(Deep Learning Toolbox) | dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | dlgradient
(Deep Learning Toolbox) | dlupdate
(Deep Learning Toolbox) | lstm
(Deep Learning Toolbox) | softmax
(Deep Learning Toolbox)