Перевод от последовательности к последовательности Используя внимание

В этом примере показано, как преобразовать десятичные строки в Римские цифры с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.

Текущие модели декодера энкодера оказались успешными в задачах как абстрактное текстовое резюмирование и нейронный машинный перевод. Модель состоит из энкодера, который обычно входные данные процессов с текущим слоем, такие как LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно со вторым текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода.

Для модели энкодера этот пример использует простую сеть, состоящую из встраивания, сопровождаемого двумя операциями LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.

Для модели декодера этот пример использует сеть, очень похожую на энкодер, который содержит два LSTMs. Однако важное различие - то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру проявлять внимание к определенным частям энкодера выход.

Загрузите обучающие данные

Загрузите пары десятичной Римской цифры с "romanNumerals.csv"

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

Разделите данные в обучение и протестируйте разделы, содержащие 50% данных каждый.

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите некоторые пары десятичной Римской цифры.

head(dataTrain)
ans=8×2 table
    Source       Target   
    ______    ____________

    "437"     "CDXXXVII"  
    "431"     "CDXXXI"    
    "102"     "CII"       
    "862"     "DCCCLXII"  
    "738"     "DCCXXXVIII"
    "527"     "DXXVII"    
    "401"     "CDI"       
    "184"     "CLXXXIV"   

Предварительная Обработка Данных

Предварительно обработайте текстовые данные с помощью transformText функция, перечисленная в конце примера. transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain{:,1};
documentsSource = transformText(strSource,startToken,stopToken);

Создайте wordEncoding возразите что лексемы карт против числового индекса и наоборот использования словаря.

encSource = wordEncoding(documentsSource);

Используя кодирование слова, преобразуйте данные об исходном тексте в числовые последовательности.

sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');

Преобразуйте целевые данные в последовательности с помощью тех же шагов.

strTarget = dataTrain{:,2};
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

Сортировка последовательностей длиной. Обучение с последовательностями, отсортированными путем увеличения длины последовательности, приводит к пакетам с последовательностями приблизительно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются, чтобы обновить модель перед более длинными пакетами последовательности.

sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Создайте arrayDatastore объекты, содержащие входные и выходные данные и, комбинируют их использующий combine функция.

sequencesSourceDs = arrayDatastore(sequencesSource,'OutputType','same');
sequencesTargetDs = arrayDatastore(sequencesTarget,'OutputType','same');

sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);

Инициализируйте параметры модели

Инициализируйте параметры модели. и для энкодера и для декодера, задайте размерность встраивания 128, два слоя LSTM с 200 скрытыми модулями и слои уволенного со случайным уволенным с вероятностью 0.05.

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

Инициализируйте параметры модели энкодера

Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian функция, которая присоединена к этому примеру как к вспомогательному файлу. Задайте среднее значение 0 и стандартное отклонение 0,01. Чтобы узнать больше, смотрите Гауссову Инициализацию.

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте настраиваемые параметры для операций LSTM энкодера:

  • Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, см. Инициализацию Glorot.

  • Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Ортогональную Инициализацию.

  • Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите, что Модуль Забывает Инициализацию Логического элемента.

Инициализируйте настраиваемые параметры для первой операции LSTM энкодера.

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для второй операции LSTM энкодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте параметры модели декодера

Инициализируйте веса встраивания кодирования с помощью Гауссова использования initializeGaussian функция. Задайте среднее значение 0 и стандартное отклонение 0,01.

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot функция.

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

Инициализируйте настраиваемые параметры для операций LSTM декодера:

  • Инициализируйте входные веса инициализатором Glorot с помощью initializeGlorot функция.

  • Инициализируйте текущие веса ортогональным инициализатором с помощью initializeOrthogonal функция.

  • Инициализируйте смещение модулем, забывают инициализатор логического элемента с помощью initializeUnitForgetGate функция.

Инициализируйте настраиваемые параметры для первой операции LSTM декодера.

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для второй операции LSTM декодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для декодера, полностью соединил операцию:

  • Инициализируйте веса инициализатором Glorot.

  • Инициализируйте смещение нулями с помощью initializeZeros функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы узнать больше, смотрите Нулевую Инициализацию.

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

Функции модели Define

Создайте функции modelEncoder и modelDecoder, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходные параметры модели и LSTM скрытое состояние.

modelDecoder функция, перечисленная в разделе Decoder Model Function примера, берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 32 в течение 75 эпох со скоростью обучения 0,002.

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Инициализируйте опции от Адама.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Используйте minibatchqueue обработать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы найти длины всей последовательности в мини-пакете и заполнить последовательности к той же длине как самая длинная последовательность, для входных и выходных последовательностей, соответственно.

  • Переставьте вторые и третьи размерности заполненных последовательностей.

  • Возвратитесь мини-пакетные переменные восстановили после форматирования dlarray объекты с базовым типом данных single. Все другие выходные параметры являются массивами типа данных single.

  • Обучайтесь на графическом процессоре, если вы доступны. Возвратите все мини-пакетные переменные на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом.

minibatchqueue объект возвращает четыре выходных аргумента в пользу каждого мини-пакета: исходные последовательности, целевые последовательности, длины всех исходных последовательностей в мини-пакете и маска последовательности целевых последовательностей.

numMiniBatchOutputs = 4;

mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])

xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте значения для adamupdate функция.

trailingAvg = [];
trailingAvgSq = [];

Обучите модель. Для каждого мини-пакета:

  • Считайте мини-пакет заполненных последовательностей.

  • Вычислите потерю и градиенты.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Обновите график процесса обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    reset(mbq);
        
    % Loop over mini-batches.
    while hasdata(mbq)
    
        iteration = iteration + 1;
        
        [dlX,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq);
        
        % Compute loss and gradients.
        [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,T,sequenceLengthsSource,...
            maskSequenceTarget,dropout);
        
        % Update parameters using adamupdate.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(loss)))
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Сгенерируйте переводы

Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов как тогда, когда обучение и ввело последовательности в модель декодера энкодера и преобразует получившиеся последовательности назад в текст с помощью маркерных индексов.

Предварительно обработайте текстовые данные с помощью тех же шагов как тогда, когда обучение. Используйте transformText функция, перечисленная в конце примера, чтобы разделить текст в символы и добавить запуск и лексемы остановки.

strSource = dataTest{:,1};
strTarget = dataTest{:,2};

Переведите текст с помощью modelPredictions функция.

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

Составьте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

Просмотрите случайный выбор переводов.

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target       Translated 
    ______    ___________    ___________

    "8"       "VIII"         "CCCXXVII" 
    "595"     "DXCV"         "DCCV"     
    "523"     "DXXIII"       "CDXXII"   
    "675"     "DCLXXV"       "DCCLXV"   
    "818"     "DCCCXVIII"    "DCCCXVIII"
    "265"     "CCLXV"        "CCLXV"    
    "770"     "DCCLXX"       "DCCCL"    
    "904"     "CMIV"         "CMVII"    
    "121"     "CXXI"         "CCXI"     
    "333"     "CCCXXXIII"    "CCCXXXIII"
    "817"     "DCCCXVII"     "DCCCXVII" 
    "37"      "XXXVII"       "CCCXXXIV" 
    "335"     "CCCXXXV"      "CCCXXXV"  
    "902"     "CMII"         "CMIII"    
    "995"     "CMXCV"        "CMXCV"    
    "334"     "CCCXXXIV"     "CCCXXXIV" 
      ⋮

Текстовая функция преобразования

transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция, описанная в разделе Train Model примера, предварительно обрабатывает данные для обучения. Функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Определите длины всех входных и выходных последовательностей в мини-пакете

  2. Заполните последовательности к той же длине как самая длинная последовательность в мини-пакете с помощью padsequences функция.

  3. Переставьте последние две размерности последовательностей

function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize)

sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource);

X = padsequences(sequencesSource,2,"PaddingValue",inputSize);
X = permute(X,[1 3 2]);

[T,maskTarget] = padsequences(sequencesTarget,2,"PaddingValue",outputSize);
T = permute(T,[1 3 2]);
maskTarget = permute(maskTarget,[1 3 2]);

end

Функция градиентов модели

modelGradients функционируйте берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

function [gradients,loss] = modelGradients(parameters,dlX,T,...
    sequenceLengthsSource,maskTarget,dropout)

% Forward through encoder.
[dlZ,hiddenState] = modelEncoder(parameters.encoder,dlX,sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
dlY = dlY(:,:,1:end-1);
T = extractdata(gather(T(:,:,2:end)));
T = onehotencode(T,1,'ClassNames',1:size(dlY,1));

maskTarget = maskTarget(:,:,2:end); 
maskTarget = repmat(maskTarget,[size(dlY,1),1,1]);

loss = crossentropy(dlY,T,'Mask',maskTarget,'Dataformat','CBT');

% Update gradients.
gradients = dlgradient(loss,parameters);

% For plotting, return loss normalized by sequence length.
loss = extractdata(loss) ./ sequenceLength;

end

Функция модели энкодера

Функциональный modelEncoder берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходной параметр модели и LSTM скрытое состояние.

Если sequenceLengths пусто, затем функция не маскирует выход. Задайте и пустое значение для sequenceLengths при использовании modelEncoder функция для предсказания.

function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
dlZ = embed(dlX,weights,'DataFormat','CBT');

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(dlZ,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = dlZ(:,n,sequenceLengths(n));
    end
end

end

Функция модели декодера

Функциональный modelDecoder берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ...
    hiddenState, dlZ, dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
dlX = embed(dlX, weights,'DataFormat','CBT');

% RNN input.
sequenceLength = size(dlX,3);
dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Dropout.
mask = ( rand(size(dlY), 'like', dlY) > dropout );
dlY = dlY.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, dlZ, weights);

% Concatenate.
dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT');

% Softmax.
dlY = softmax(dlY,'DataFormat','CBT');

end

Функция внимания

attention функция возвращает баллы внимания по данным Luong "общий" выигрыш и обновленный вектор контекста. Энергия на каждом временном шаге является скалярным произведением скрытого состояния и learnable времена весов внимания энкодера выход.

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

Функция предсказаний модели декодера

decoderModelPredictions функция возвращает предсказанную последовательность dlY учитывая входную последовательность, предназначайтесь для последовательности, скрытого состояния, вероятности уволенного, флаг, чтобы включить учителю, обеспечивающему и длине последовательности.

function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
dlT = dlarray(T);

% Initialize context.
miniBatchSize = size(dlT,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],'like',dlZ);

if doTeacherForcing
    % Forward through decoder.
    dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout);
else
    % Get first time step for decoder.
    decoderInput = dlT(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ...
            hiddenState, dlZ, dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(dlY(:,:,t),[],1);
    end
end

end

Функция перевода текста

translateText функция переводит массив текста путем итерации по мини-пакетам. Функция берет в качестве входа параметры модели, массив входной строки, максимальную длину последовательности, мини-пакетный размер, входные и выходные объекты кодирования слова, запуск и лексемы остановки и разделитель для сборки выхода.

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
dlX = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(dlY), [], 1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

Смотрите также

| | | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте