exponenta event banner

Преобразование последовательности в последовательность с использованием внимания

В этом примере показано, как преобразовывать десятичные строки в римские цифры с помощью повторяющейся модели кодера-декодера последовательности.

Повторяющиеся модели кодера-декодера оказались успешными при выполнении таких задач, как абстрактное суммирование текста и нейронный машинный перевод. Модель состоит из кодера, который обычно обрабатывает входные данные с повторяющимся уровнем, таким как LSTM, и декодера, который отображает кодированный вход в желаемый выход, обычно со вторым повторяющимся уровнем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях кодированного ввода при генерации трансляции.

Для модели кодера в этом примере используется простая сеть, состоящая из встраивания, за которым следуют две операции LSTM. Встраивание - метод преобразования категориальных токенов в числовые векторы.

Для модели декодера в этом примере используется сеть, очень похожая на кодер, который содержит два LSTM. Однако важным отличием является то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру следить за конкретными частями выходного сигнала кодера.

Загрузка данных обучения

Загрузить десятичные-римские числовые пары из "romanNumerals.csv"

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

Разбейте данные на разделы обучения и тестирования, содержащие по 50% данных.

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите некоторые десятичные-римские числовые пары.

head(dataTrain)
ans=8×2 table
    Source       Target   
    ______    ____________

    "437"     "CDXXXVII"  
    "431"     "CDXXXI"    
    "102"     "CII"       
    "862"     "DCCCLXII"  
    "738"     "DCCXXXVIII"
    "527"     "DXXVII"    
    "401"     "CDI"       
    "184"     "CLXXXIV"   

Данные предварительной обработки

Предварительная обработка текстовых данных с помощью transformText функция, перечисленная в конце примера. transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста на символы и добавления маркеров запуска и остановки. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain{:,1};
documentsSource = transformText(strSource,startToken,stopToken);

Создать wordEncoding объект, который сопоставляет маркеры числовому индексу и наоборот с использованием словаря.

encSource = wordEncoding(documentsSource);

С помощью кодировки слов преобразуйте данные исходного текста в числовые последовательности.

sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');

Преобразуйте целевые данные в последовательности, используя те же шаги.

strTarget = dataTrain{:,2};
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

Сортировка последовательностей по длине. Обучение с последовательностями, отсортированными путем увеличения длины последовательности, приводит к пакетам с последовательностями приблизительно одинаковой длины и гарантирует использование меньших партий последовательности для обновления модели перед более длинными пакетами последовательности.

sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Создать arrayDatastore объекты, содержащие исходные и целевые данные и объединяющие их с помощью combine функция.

sequencesSourceDs = arrayDatastore(sequencesSource,'OutputType','same');
sequencesTargetDs = arrayDatastore(sequencesTarget,'OutputType','same');

sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);

Инициализация параметров модели

Инициализируйте параметры модели. как для кодера, так и для декодера укажите размер встраивания 128, два уровня LSTM с 200 скрытыми блоками и уровни отсева с случайным отсевом с вероятностью 0,05.

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

Инициализация параметров модели кодировщика

Инициализируйте веса внедрения кодирования с помощью гауссова initializeGaussian функция, которая присоединена к этому примеру в качестве вспомогательного файла. Укажите среднее значение 0 и стандартное отклонение 0,01. Дополнительные сведения см. в разделе Гауссовская инициализация (Deep Learning Toolbox).

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте обучаемые параметры для операций кодера LSTM:

  • Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot функция, которая присоединена к этому примеру в качестве вспомогательного файла. Дополнительные сведения см. в разделе Инициализация Glorot (Deep Learning Toolbox).

  • Инициализируйте повторяющиеся веса ортогональным инициализатором с помощью initializeOrthogonal функция, которая присоединена к этому примеру в качестве вспомогательного файла. Дополнительные сведения см. в разделе Ортогональная инициализация (инструментарий глубокого обучения).

  • Инициализируйте смещение с помощью инициализатора литника unit forget с помощью initializeUnitForgetGate функция, которая присоединена к этому примеру в качестве вспомогательного файла. Дополнительные сведения см. в разделе Инициализация Gate Unit Forget (инструментарий глубокого обучения).

Инициализируйте обучаемые параметры для первой операции кодера LSTM.

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте обучаемые параметры для второй операции кодера LSTM.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализация параметров модели декодера

Инициализируйте веса внедрения кодирования с помощью гауссова initializeGaussian функция. Укажите среднее значение 0 и стандартное отклонение 0,01.

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

Инициализируйте веса механизма внимания с помощью инициализатора Glorot с помощью initializeGlorot функция.

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

Инициализируйте обучаемые параметры для операций декодера LSTM:

  • Инициализируйте входные веса с помощью инициализатора Glorot с помощью initializeGlorot функция.

  • Инициализируйте повторяющиеся веса ортогональным инициализатором с помощью initializeOrthogonal функция.

  • Инициализируйте смещение с помощью инициализатора литника unit forget с помощью initializeUnitForgetGate функция.

Инициализируйте обучаемые параметры для первой операции декодера LSTM.

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте обучаемые параметры для операции LSTM второго декодера.

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

Инициализируйте обучаемые параметры для полностью подключенной операции декодера:

  • Инициализируйте веса с помощью инициализатора Glorot.

  • Инициализируйте смещение нулями с помощью initializeZeros функция, которая присоединена к этому примеру в качестве вспомогательного файла. Дополнительные сведения см. в разделе Инициализация нулей (Deep Learning Toolbox).

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

Определение функций модели

Создание функций modelEncoder и modelDecoder, перечисленных в конце примера, которые вычисляют выходные сигналы моделей кодера и декодера соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, необязательную маску, которая используется для определения правильных выходных сигналов для обучения и возвращает выходные данные модели и скрытое состояние LSTM.

modelDecoder функция, перечисленная в разделе «Функция модели декодера» примера, принимает входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы кодера и вероятность отсева и выводит выходной сигнал декодера, обновленный вектор контекста, обновленное состояние LSTM и оценки внимания.

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в разделе «Функция градиентов модели» примера, который принимает параметры модели кодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно обучаемых параметров в моделях и соответствующих потерь.

Укажите параметры обучения

Поезд с размером мини-партии 32 на 75 эпох с показателем обучения 0,002.

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Инициализируйте параметры из Adam.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. Использовать minibatchqueue обрабатывать и управлять мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определенный в конце этого примера), чтобы найти длины всех последовательностей в мини-партии и поместить последовательности на ту же длину, что и самая длинная последовательность, для исходной и целевой последовательностей соответственно.

  • Переставьте второе и третье измерения дополненных последовательностей.

  • Возврат неформатированных переменных мини-пакета dlarray объекты с базовым типом данных single. Все остальные выходы являются массивами типа данных single.

  • Обучение на GPU, если он доступен. Верните все переменные мини-пакета на GPU, если они доступны. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по версии.

minibatchqueue объект возвращает четыре выходных аргумента для каждого мини-пакета: исходные последовательности, целевые последовательности, длины всех исходных последовательностей в мини-пакете и маску последовательности целевых последовательностей.

numMiniBatchOutputs = 4;

mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));

Инициализируйте график хода обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])

xlabel("Iteration")
ylabel("Loss")
grid on

Инициализация значений для adamupdate функция.

trailingAvg = [];
trailingAvgSq = [];

Тренируйте модель. Для каждой мини-партии:

  • Прочитайте мини-пакет дополненных последовательностей.

  • Вычислить потери и градиенты.

  • Обновление параметров модели кодера и декодера с помощью adamupdate функция.

  • Обновите график хода обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    reset(mbq);
        
    % Loop over mini-batches.
    while hasdata(mbq)
    
        iteration = iteration + 1;
        
        [dlX,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq);
        
        % Compute loss and gradients.
        [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,T,sequenceLengthsSource,...
            maskSequenceTarget,dropout);
        
        % Update parameters using adamupdate.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(loss)))
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Создание переводов

Чтобы генерировать трансляции для новых данных с использованием обученной модели, преобразуйте текстовые данные в числовые последовательности, используя те же шаги, что и при обучении и вводе последовательностей в модель кодера-декодера, и преобразуйте результирующие последовательности обратно в текст с использованием индексов маркеров.

Предварительная обработка текстовых данных выполняется с использованием тех же шагов, что и при обучении. Используйте transformText функция, перечисленная в конце примера, для разделения текста на символы и добавления маркеров начала и окончания.

strSource = dataTest{:,1};
strTarget = dataTest{:,2};

Перевести текст с помощью modelPredictions функция.

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

Создайте таблицу, содержащую тестовый исходный текст, целевой текст и переводы.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

Просмотр случайного выбора переводов.

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target       Translated 
    ______    ___________    ___________

    "8"       "VIII"         "CCCXXVII" 
    "595"     "DXCV"         "DCCV"     
    "523"     "DXXIII"       "CDXXII"   
    "675"     "DCLXXV"       "DCCLXV"   
    "818"     "DCCCXVIII"    "DCCCXVIII"
    "265"     "CCLXV"        "CCLXV"    
    "770"     "DCCLXX"       "DCCCL"    
    "904"     "CMIV"         "CMVII"    
    "121"     "CXXI"         "CCXI"     
    "333"     "CCCXXXIII"    "CCCXXXIII"
    "817"     "DCCCXVII"     "DCCCXVII" 
    "37"      "XXXVII"       "CCCXXXIV" 
    "335"     "CCCXXXV"      "CCCXXXV"  
    "902"     "CMII"         "CMIII"    
    "995"     "CMXCV"        "CMXCV"    
    "334"     "CCCXXXIV"     "CCCXXXIV" 
      ⋮

Функция преобразования текста

transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста на символы и добавления маркеров запуска и остановки. Чтобы перевести текст путем разделения текста на слова вместо символов, пропустите первый шаг.

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция, описанная в разделе «Модель поезда» примера, предварительно обрабатывает данные для обучения. Функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Определение длины всех исходных и целевых последовательностей в мини-пакете

  2. Поместите последовательности на ту же длину, что и самая длинная последовательность в мини-партии, используя padsequences функция.

  3. Перестановка двух последних измерений последовательностей

function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize)

sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource);

X = padsequences(sequencesSource,2,"PaddingValue",inputSize);
X = permute(X,[1 3 2]);

[T,maskTarget] = padsequences(sequencesTarget,2,"PaddingValue",outputSize);
T = permute(T,[1 3 2]);
maskTarget = permute(maskTarget,[1 3 2]);

end

Функция градиентов модели

modelGradients функция принимает параметры модели кодера и декодера, мини-пакет входных данных и маски заполнения, соответствующие входным данным, и вероятность отсева и возвращает градиенты потерь относительно обучаемых параметров в моделях и соответствующих потерь.

function [gradients,loss] = modelGradients(parameters,dlX,T,...
    sequenceLengthsSource,maskTarget,dropout)

% Forward through encoder.
[dlZ,hiddenState] = modelEncoder(parameters.encoder,dlX,sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
dlY = dlY(:,:,1:end-1);
T = extractdata(gather(T(:,:,2:end)));
T = onehotencode(T,1,'ClassNames',1:size(dlY,1));

maskTarget = maskTarget(:,:,2:end); 
maskTarget = repmat(maskTarget,[size(dlY,1),1,1]);

loss = crossentropy(dlY,T,'Mask',maskTarget,'Dataformat','CBT');

% Update gradients.
gradients = dlgradient(loss,parameters);

% For plotting, return loss normalized by sequence length.
loss = extractdata(loss) ./ sequenceLength;

end

Функция модели кодировщика

Функция modelEncoder принимает входные данные, параметры модели, опциональную маску, которая используется для определения правильных выходных данных для обучения и возвращает выходные данные модели и скрытое состояние LSTM.

Если sequenceLengths пуст, то функция не маскирует вывод. Укажите и пустое значение для sequenceLengths при использовании modelEncoder функция прогнозирования.

function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
dlZ = embed(dlX,weights,'DataFormat','CBT');

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(dlZ,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = dlZ(:,n,sequenceLengths(n));
    end
end

end

Функция модели декодера

Функция modelDecoder принимает входные данные, параметры модели, вектор контекста, начальное скрытое состояние LSTM, выходы кодера и вероятность отсева и выводит выходной сигнал декодера, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ...
    hiddenState, dlZ, dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
dlX = embed(dlX, weights,'DataFormat','CBT');

% RNN input.
sequenceLength = size(dlX,3);
dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Dropout.
mask = ( rand(size(dlY), 'like', dlY) > dropout );
dlY = dlY.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, dlZ, weights);

% Concatenate.
dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT');

% Softmax.
dlY = softmax(dlY,'DataFormat','CBT');

end

Функция внимания

attention функция возвращает оценки внимания согласно «общей» оценке Луонга и обновленному вектору контекста. Энергия на каждом временном шаге является скалярным произведением скрытого состояния и распознаваемых весов внимания, умноженных на выходной сигнал кодера.

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

Функция прогнозирования модели декодера

decoderModelPredictions функция возвращает предсказанную последовательность dlY с учетом входной последовательности, целевой последовательности, скрытого состояния, вероятности отсева, флага для активизации форсирования учителя и длины последовательности.

function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
dlT = dlarray(T);

% Initialize context.
miniBatchSize = size(dlT,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],'like',dlZ);

if doTeacherForcing
    % Forward through decoder.
    dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout);
else
    % Get first time step for decoder.
    decoderInput = dlT(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ...
            hiddenState, dlZ, dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(dlY(:,:,t),[],1);
    end
end

end

Функция перевода текста

translateText функция переводит массив текста путем итерации по мини-пакетам. Функция принимает за входные параметры модели, массив входных строк, максимальную длину последовательности, размер мини-пакета, объекты кодирования исходного и целевого слов, маркеры начала и остановки и разделитель для сборки выходных данных.

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
dlX = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(dlY), [], 1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

См. также

| | | | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения)

Связанные темы