Перевод от последовательности к последовательности Используя внимание

В этом примере показано, как преобразовать десятичные строки в Римские цифры с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.

Текущие модели декодера энкодера оказались успешными в задачах как абстрактное текстовое резюмирование и нейронный машинный перевод. Модели, сопоставимые из энкодера, который обычно входные данные процессов с текущим слоем, такие как LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно со вторым текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода.

Для модели энкодера этот пример использует простую сеть, состоящую из встраивания, сопровождаемого двумя операциями LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.

Для модели декодера этот пример использует сеть, очень похожую на энкодер, который содержит два LSTMs. Однако важное различие - то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру проявлять внимание к определенным частям энкодера выход.

Загрузите обучающие данные

Загрузите пары десятичной Римской цифры с "romanNumerals.csv"

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

Разделите данные в обучение и протестируйте разделы, содержащие 50% данных каждый.

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите некоторые пары десятичной Римской цифры.

head(dataTrain)
ans=8×2 table
    Source      Target  
    ______    __________

    "228"     "CCXXVIII"
    "267"     "CCLXVII" 
    "294"     "CCXCIV"  
    "179"     "CLXXIX"  
    "396"     "CCCXCVI" 
    "2"       "II"      
    "4"       "IV"      
    "270"     "CCLXX"   

Предварительная Обработка Данных

Предварительно обработайте текстовые данные с помощью transformText функция, перечисленная в конце примера. transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain{:,1};
documentsSource = transformText(strSource,startToken,stopToken);

Создайте wordEncoding возразите что лексемы карт против числового индекса и наоборот использования словаря.

encSource = wordEncoding(documentsSource);

Используя кодирование слова, преобразуйте данные об исходном тексте в числовые последовательности.

sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');

Преобразуйте целевые данные в последовательности с помощью тех же шагов.

strTarget = dataTrain{:,2};
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

Инициализируйте параметры модели

Инициализируйте параметры модели. и для энкодера и для декодера, задайте размерность встраивания 128, два слоя LSTM с 200 скрытыми модулями и слои уволенного со случайным уволенным с вероятностью 0.05.

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

Инициализируйте параметры модели энкодера

Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian функция, которая присоединена к этому примеру как к вспомогательному файлу. Задайте среднее значение 0 и стандартное отклонение 0,01. Чтобы узнать больше, смотрите Гауссову Инициализацию.

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте настраиваемые параметры для операций LSTM энкодера:

  • Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, см. Инициализацию Glorot.

  • Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Ортогональную Инициализацию.

  • Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите, что Модуль Забывает Инициализацию Логического элемента.

Инициализируйте настраиваемые параметры для первой операции LSTM энкодера.

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для второй операции LSTM энкодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте параметры модели декодера

Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian функция. Задайте среднее значение 0 и стандартное отклонение 0,01.

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot функция.

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

Инициализируйте настраиваемые параметры для операций LSTM декодера:

  • Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot функция.

  • Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal функция.

  • Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate функция.

Инициализируйте настраиваемые параметры для первой операции LSTM декодера.

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для второй операции LSTM декодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для декодера, полностью соединил операцию:

  • Инициализируйте веса инициализатором Glorot.

  • Инициализируйте смещение нулями с помощью initializeZeros функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Нулевую Инициализацию.

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

Функции модели Define

Создайте функции modelEncoder и modelDecoder, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходные параметры модели и LSTM скрытое состояние.

modelDecoder функция, перечисленная в разделе Decoder Model Function примера, берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 32 в течение 75 эпох со скоростью обучения 0,002.

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Инициализируйте опции от Адама.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

Обучайтесь с последовательностями, отсортированными путем увеличения длины последовательности. Это приводит к пакетам с последовательностями приблизительно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются, чтобы обновить модель перед более длинными пакетами последовательности.

Сортировка последовательностей длиной.

sequenceLengths = cellfun(@(sequence) size(sequence,2), sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])

xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте значения для adamupdate функция.

trailingAvg = [];
trailingAvgSq = [];

Обучите модель. Для каждого мини-пакета:

  • Считайте мини-пакет последовательностей и добавьте дополнение.

  • Преобразуйте данные в dlarray.

  • Вычислите потерю и градиенты.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Обновите график процесса обучения.

numObservations = numel(sequencesSource);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
        
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and pad.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        [X, sequenceLengthsSource] = padSequences(sequencesSource(idx), inputSize);
        [T, sequenceLengthsTarget] = padSequences(sequencesTarget(idx), outputSize);

        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X);
        
        % Compute loss and gradients.
        [gradients, loss] = dlfeval(@modelGradients, parameters, dlX, T, ...
            sequenceLengthsSource, sequenceLengthsTarget, dropout);
        
        % Update parameters using adamupdate.
        [parameters, trailingAvg, trailingAvgSq] = adamupdate(parameters, gradients, trailingAvg, trailingAvgSq, ...
            iteration, learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(loss)))
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Сгенерируйте переводы

Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов как тогда, когда обучение и ввело последовательности в модель декодера энкодера и преобразует получившиеся последовательности назад в текст с помощью маркерных индексов.

Предварительно обработайте текстовые данные с помощью тех же шагов как тогда, когда обучение. Используйте transformText функция, перечисленная в конце примера, чтобы разделить текст в символы и добавить запуск и лексемы остановки.

strSource = dataTest{:,1};
strTarget = dataTest{:,2};

Переведите текст с помощью modelPredictions функция.

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

Составьте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

Просмотрите случайный выбор переводов.

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target      Translated
    ______    __________    __________

    "936"     "CMXXXVI"     "CMXXXVI" 
    "423"     "CDXXIII"     "CDXXIII" 
    "981"     "CMLXXXI"     "CMLXXXIX"
    "200"     "CC"          "CC"      
    "224"     "CCXXIV"      "CCXXIV"  
    "56"      "LVI"         "DLVI"    
    "330"     "CCCXXX"      "CCCXXX"  
    "336"     "CCCXXXVI"    "CCCXXXVI"
    "524"     "DXXIV"       "DXXIV"   
    "860"     "DCCCLX"      "DCCCLX"  
    "318"     "CCCXVIII"    "CCCXVIII"
    "902"     "CMII"        "CMII"    
    "681"     "DCLXXXI"     "DCLXXXI" 
    "299"     "CCXCIX"      "CCXCIX"  
    "931"     "CMXXXI"      "CMXXXIX" 
    "859"     "DCCCLIX"     "DCCCLIX" 
      ⋮

Текстовая функция преобразования

transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Дополнительная функция последовательности

padSequences функционируйте берет мини-пакет последовательностей и дополнительного значения и возвращает дополненные последовательности с соответствующими дополнительными масками.

function [X, sequenceLengths] = padSequences(sequences, paddingValue)

% Initialize mini-batch with padding.
numObservations = size(sequences,1);
sequenceLengths = cellfun(@(x) size(x,2), sequences);
maxLength = max(sequenceLengths);
X = repmat(paddingValue, [1 numObservations maxLength]);

% Insert sequences.
for n = 1:numObservations
    X(1,n,1:sequenceLengths(n)) = sequences{n};
end

end

Функция градиентов модели

modelGradients функционируйте берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

function [gradients, loss] = modelGradients(parameters, dlX, T, ...
    sequenceLengthsSource, sequenceLengthsTarget, dropout)

% Forward through encoder.
[dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX, sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
dlY = dlY(:,:,1:end-1);
T = T(:,:,2:end);
T = onehotencode(T,1,'ClassNames',1:size(dlY,1));
loss = maskedCrossEntropy(dlY,T,sequenceLengthsTarget-1);

% Update gradients.
gradients = dlgradient(loss, parameters);

% For plotting, return loss normalized by sequence length.
loss = extractdata(loss) ./ sequenceLength;

end

Функция модели энкодера

Функциональный modelEncoder берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходной параметр модели и LSTM скрытое состояние.

Если sequenceLengths пусто, затем функция не маскирует выход. Задайте и пустое значение для sequenceLengths при использовании modelEncoder функция для предсказания.

function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
dlZ = embed(dlX,weights,'DataFormat','CBT');

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(dlZ,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = dlZ(:,n,sequenceLengths(n));
    end
end

end

Функция модели декодера

Функциональный modelDecoder берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ...
    hiddenState, dlZ, dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
dlX = embed(dlX, weights,'DataFormat','CBT');

% RNN input.
sequenceLength = size(dlX,3);
dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Dropout.
mask = ( rand(size(dlY), 'like', dlY) > dropout );
dlY = dlY.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, dlZ, weights);

% Concatenate.
dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT');

% Softmax.
dlY = softmax(dlY,'DataFormat','CBT');

end

Функция внимания

attention функция возвращает баллы внимания по данным Luong "общий" выигрыш и обновленный вектор контекста. Энергия на каждом временном шаге является скалярным произведением скрытого состояния и learnable времена весов внимания энкодера выход.

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

Потеря перекрестной энтропии маскированная

maskedCrossEntropy функция вычисляет потерю между заданными входными последовательностями и целевыми последовательностями, игнорирующими любые временные шаги, содержащие дополняющий использование заданного вектора из длин последовательности.

function loss = maskedCrossEntropy(dlY,T,sequenceLengths)

% Initialize loss.
loss = 0;

% Loop over mini-batch.
miniBatchSize = size(dlY,2);
for n = 1:miniBatchSize
    idx = 1:sequenceLengths(n);
    loss = loss + crossentropy(dlY(:,n,idx), T(:,n,idx),'DataFormat','CBT');
end

% Normalize.
loss = loss / miniBatchSize;

end

Модель декодера функция Predicitons

decoderModelPredictions функция возвращает предсказанную последовательность dlY учитывая входную последовательность, предназначайтесь для последовательности, скрытого состояния, вероятности уволенного, флаг, чтобы включить учителю, обеспечивающему и длине последовательности.

function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
dlT = dlarray(T);

% Initialize context.
miniBatchSize = size(dlT,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],'like',dlZ);

if doTeacherForcing
    % Forward through decoder.
    dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout);
else
    % Get first time step for decoder.
    decoderInput = dlT(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ...
            hiddenState, dlZ, dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(dlY(:,:,t),[],1);
    end
end

end

Функция перевода текста

translateText функция переводит массив текста путем итерации по мини-пакетам. Функция берет в качестве входа параметры модели, массив входной строки, максимальную длину последовательности, мини-пакетный размер, входные и выходные объекты кодирования слова, запуск и лексемы остановки и разделитель для сборки выхода.

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
dlX = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(dlY), [], 1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

Смотрите также

| | | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

Похожие темы