Перемещение последовательности в последовательность с использованием внимания

В этом примере показано, как преобразовать десятичные строки в римские числа с помощью рекуррентной модели кодер-декодер последовательности в последовательность с вниманием.

Модели рекуррентного энкодера-декодера оказались успешными в таких задачах, как абстрактное суммирование текста и нейронный машинный перевод. Модель состоит из энкодера, который обычно обрабатывает входные данные с помощью рекуррентного уровня, такого как LSTM, и декодера, который преобразует кодированный вход в требуемый выход, обычно со вторым рекуррентным слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях кодированного входа при генерации перевода.

Для модели энкодера в этом примере используется простая сеть, состоящая из встраивания, за которой следуют две операции LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.

Для модели декодера в этом примере используется сеть, очень похожая на энкодер, который содержит два LSTM. Однако важным различием является то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру следить за конкретными частями выхода энкодера.

Загрузка обучающих данных

Загрузите десятичные-римские числительные пары из "romanNumerals.csv"

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

Разделите данные на обучающие и тестовые разделы, содержащие 50% данных каждый.

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите некоторые десятичные-римские числовые пары.

head(dataTrain)
ans=8×2 table
    Source       Target   
    ______    ____________

    "437"     "CDXXXVII"  
    "431"     "CDXXXI"    
    "102"     "CII"       
    "862"     "DCCCLXII"  
    "738"     "DCCXXXVIII"
    "527"     "DXXVII"    
    "401"     "CDI"       
    "184"     "CLXXXIV"   

Предварительная обработка данных

Предварительно обработайте текстовые данные с помощью transformText функции, перечисленной в конце примера. The transformText функция выполняет предварительную обработку и маркировку входного текста для перевода путем разделения текста на символы и добавления начальных и стоповых лексем. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain{:,1};
documentsSource = transformText(strSource,startToken,stopToken);

Создайте wordEncoding объект, который преобразует лексемы в числовой индекс и наоборот с помощью словаря.

encSource = wordEncoding(documentsSource);

Используя кодировку слова, преобразуйте исходные текстовые данные в числовые последовательности.

sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');

Преобразуйте целевые данные в последовательности с помощью тех же шагов.

strTarget = dataTrain{:,2};
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

Сортировка последовательностей по длине. Обучение с последовательностями, отсортированными путем увеличения длины последовательности, приводит к пакетам с последовательностями примерно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются для обновления модели перед более длинными пакетами последовательности.

sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Создание arrayDatastore объекты, содержащие исходные и целевые данные, и объединяют их с помощью combine функция.

sequencesSourceDs = arrayDatastore(sequencesSource,'OutputType','same');
sequencesTargetDs = arrayDatastore(sequencesTarget,'OutputType','same');

sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);

Инициализация параметров модели

Инициализируйте параметры модели. для энкодера и декодера задайте размерность встраивания 128, два слоя LSTM с 200 скрытыми модулями и слои отсева со случайным отсоединением с вероятностью 0,05.

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

Инициализируйте параметры модели энкодера

Инициализируйте веса встраивания кодирования с помощью Гауссова с помощью initializeGaussian функция, которая присоединена к этому примеру как вспомогательный файл. Задайте среднее значение 0 и стандартное отклонение 0,01. Для получения дополнительной информации смотрите Гауссову инициализацию.

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте настраиваемые параметры для операций LSTM энкодера:

  • Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot функция, которая присоединена к этому примеру как вспомогательный файл. Дополнительные сведения см. в разделе «Инициализация Glorot».

  • Инициализируйте повторяющиеся веса с помощью ортогонального инициализатора с помощью initializeOrthogonal функция, которая присоединена к этому примеру как вспомогательный файл. Дополнительные сведения см. в разделе Ортогональная инициализация.

  • Инициализируйте смещение с помощью модуля forget gate initializer с помощью initializeUnitForgetGate функция, которая присоединена к этому примеру как вспомогательный файл. Для получения дополнительной информации смотрите Unit Forget Gate Initialization.

Инициализируйте настраиваемые параметры для операции LSTM первого энкодера.

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для операции LSTM второго энкодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте параметры модели декодера

Инициализируйте веса встраивания кодирования с помощью Гауссова с помощью initializeGaussian функция. Задайте среднее значение 0 и стандартное отклонение 0,01.

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot функция.

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

Инициализируйте настраиваемые параметры для операций LSTM декодера:

  • Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot функция.

  • Инициализируйте повторяющиеся веса с помощью ортогонального инициализатора с помощью initializeOrthogonal функция.

  • Инициализируйте смещение с помощью модуля forget gate initializer с помощью initializeUnitForgetGate функция.

Инициализируйте настраиваемые параметры для операции LSTM первого декодера.

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для операции LSTM второго декодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте настраиваемые параметры для полностью подключенной операции декодера:

  • Инициализируйте веса с помощью инициализатора Glorot.

  • Инициализируйте смещение с нулями, используя initializeZeros функция, которая присоединена к этому примеру как вспомогательный файл. Дополнительные сведения см. в разделе Инициализация нулей.

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

Задайте функции модели

Создайте функции modelEncoder и modelDecoder, перечисленных в конце примера, которые вычисляют выходы моделей энкодера и декодера, соответственно.

The modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берёт входные данные, параметры модели, необязательную маску, которая используется для определения правильных выходов для обучения и возвращает выходные параметры модели и скрытое состояние LSTM.

The modelDecoder функция, перечисленная в разделе Decoder Model Function примера, берёт входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы энкодера, а также вероятность отсева и выводит выходы декодера, обновленный вектор контекста, обновленное состояние LSTM и счетов внимания.

Задайте функцию градиентов модели

Создайте функцию modelGradients, приведенный в разделе Model Gradients Function примера, который принимает параметры модели энкодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно настраиваемых параметров в моделях и соответствующих потерь.

Настройка опций обучения

Обучайте с мини-пакетом размером 32 на 75 эпох со скоростью обучения 0,002.

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Инициализируйте опции из Adam.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. Использование minibatchqueue обрабатывать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы найти длины всей последовательности в мини-пакете и дополнить последовательности той же длиной, что и самая длинная последовательность, для исходной и целевой последовательностей, соответственно.

  • Транспозиция вторых и третьих размерностей заполненных последовательностей.

  • Возвращает переменные мини-пакета неформатированные dlarray объекты с базовым типом данных single. Все другие выходы являются массивами типа данных single.

  • Обучите на графическом процессоре, если он доступен. Возвращает все мини-пакетные переменные на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах см. раздел Поддержка GPU по релизу.

The minibatchqueue объект возвращает четыре выходных аргументов для каждого мини-пакета: исходные последовательности, целевые последовательности, длины всех исходных последовательностей в мини-пакете и маска последовательности целевых последовательностей.

numMiniBatchOutputs = 4;

mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])

xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте значения для adamupdate функция.

trailingAvg = [];
trailingAvgSq = [];

Обучите модель. Для каждого мини-пакета:

  • Считайте мини-пакет заполненных последовательностей.

  • Вычислите потери и градиенты.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Обновите график процесса обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    reset(mbq);
        
    % Loop over mini-batches.
    while hasdata(mbq)
    
        iteration = iteration + 1;
        
        [dlX,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq);
        
        % Compute loss and gradients.
        [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,T,sequenceLengthsSource,...
            maskSequenceTarget,dropout);
        
        % Update parameters using adamupdate.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(loss)))
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Сгенерируйте переводы

Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов, что и при обучении, и вводите последовательности в модель энкодер-декодер и преобразуйте получившиеся последовательности обратно в текст с помощью индексов маркера.

Предварительно обработайте текстовые данные с помощью тех же шагов, что и при обучении. Используйте transformText функция, перечисленная в конце примера, чтобы разделить текст на символы и добавить начальные и стоповые лексемы.

strSource = dataTest{:,1};
strTarget = dataTest{:,2};

Переведите текст с помощью modelPredictions функция.

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

Составьте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

Просмотр случайного выбора переводов.

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target       Translated 
    ______    ___________    ___________

    "8"       "VIII"         "CCCXXVII" 
    "595"     "DXCV"         "DCCV"     
    "523"     "DXXIII"       "CDXXII"   
    "675"     "DCLXXV"       "DCCLXV"   
    "818"     "DCCCXVIII"    "DCCCXVIII"
    "265"     "CCLXV"        "CCLXV"    
    "770"     "DCCLXX"       "DCCCL"    
    "904"     "CMIV"         "CMVII"    
    "121"     "CXXI"         "CCXI"     
    "333"     "CCCXXXIII"    "CCCXXXIII"
    "817"     "DCCCXVII"     "DCCCXVII" 
    "37"      "XXXVII"       "CCCXXXIV" 
    "335"     "CCCXXXV"      "CCCXXXV"  
    "902"     "CMII"         "CMIII"    
    "995"     "CMXCV"        "CMXCV"    
    "334"     "CCCXXXIV"     "CCCXXXIV" 
      ⋮

Функция преобразования текста

The transformText функция выполняет предварительную обработку и маркировку входного текста для перевода путем разделения текста на символы и добавления начальных и стоповых лексем. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция, описанная в разделе Модель примера, предварительно обрабатывает данные для обучения. Функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Определите длины всех исходных и целевых последовательностей в мини-пакете

  2. Дополните последовательности такой же длиной, как и самую длинную последовательность в мини-пакете, используя padsequences функция.

  3. Транспозиция последних двух размерностей последовательностей

function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize)

sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource);

X = padsequences(sequencesSource,2,"PaddingValue",inputSize);
X = permute(X,[1 3 2]);

[T,maskTarget] = padsequences(sequencesTarget,2,"PaddingValue",outputSize);
T = permute(T,[1 3 2]);
maskTarget = permute(maskTarget,[1 3 2]);

end

Функция градиентов модели

The modelGradients функция принимает параметры модели энкодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно настраиваемых параметров в моделях и соответствующих потерь.

function [gradients,loss] = modelGradients(parameters,dlX,T,...
    sequenceLengthsSource,maskTarget,dropout)

% Forward through encoder.
[dlZ,hiddenState] = modelEncoder(parameters.encoder,dlX,sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
dlY = dlY(:,:,1:end-1);
T = extractdata(gather(T(:,:,2:end)));
T = onehotencode(T,1,'ClassNames',1:size(dlY,1));

maskTarget = maskTarget(:,:,2:end); 
maskTarget = repmat(maskTarget,[size(dlY,1),1,1]);

loss = crossentropy(dlY,T,'Mask',maskTarget,'Dataformat','CBT');

% Update gradients.
gradients = dlgradient(loss,parameters);

% For plotting, return loss normalized by sequence length.
loss = extractdata(loss) ./ sequenceLength;

end

Функция модели энкодера

Функция modelEncoder принимает входные данные, параметры модели, необязательную маску, которая используется для определения правильных выходов для обучения, и возвращает выходы модели и скрытое состояние LSTM.

Если sequenceLengths пуст, тогда функция не маскирует выход. Задайте и пустое значение для sequenceLengths при использовании modelEncoder функция для предсказания.

function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
dlZ = embed(dlX,weights,'DataFormat','CBT');

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(dlZ,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = dlZ(:,n,sequenceLengths(n));
    end
end

end

Функция модели декодера

Функция modelDecoder принимает входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы энкодера и вероятность отсева и выводит выход декодера, обновленный вектор контекста, обновленное состояние LSTM и счетов внимания.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ...
    hiddenState, dlZ, dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
dlX = embed(dlX, weights,'DataFormat','CBT');

% RNN input.
sequenceLength = size(dlX,3);
dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Dropout.
mask = ( rand(size(dlY), 'like', dlY) > dropout );
dlY = dlY.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, dlZ, weights);

% Concatenate.
dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT');

% Softmax.
dlY = softmax(dlY,'DataFormat','CBT');

end

Функция внимания

The attention функция возвращает счета внимания согласно «общей» оценке Luong и обновленному вектору контекста. Энергия на каждом временном шаге является точечным продуктом скрытого состояния, и выученные веса внимания умножаются на выход энкодера.

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

Функция предсказаний модели декодера

The decoderModelPredictions функция возвращает предсказанную последовательность dlY учитывая входную последовательность, целевую последовательность, скрытое состояние, вероятность отсева, флаг, чтобы включить принудительное выполнение учителем, и длину последовательности.

function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
dlT = dlarray(T);

% Initialize context.
miniBatchSize = size(dlT,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],'like',dlZ);

if doTeacherForcing
    % Forward through decoder.
    dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout);
else
    % Get first time step for decoder.
    decoderInput = dlT(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ...
            hiddenState, dlZ, dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(dlY(:,:,t),[],1);
    end
end

end

Функция перевода текста

The translateText Функция переводит массив текста путем итерации по мини-пакетам. Функция принимает за вход параметры модели, входные строковые массивы, максимальную длину последовательности, размер мини-пакета, объекты кодирования исходного и целевого слова, начальные и стоповые лексемы и разделитель для сборки выхода.

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
dlX = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(dlY), [], 1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

См. также

| | | | | | | | (Symbolic Math Toolbox) | (Symbolic Math Toolbox) | (Symbolic Math Toolbox) | (Symbolic Math Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте