Перевод от последовательности к последовательности Используя внимание

В этом примере показано, как преобразовать десятичные строки в Римские цифры с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.

Текущие модели декодера энкодера оказались успешными в задачах как абстрактное текстовое резюмирование и нейронный машинный перевод. Модели, сопоставимые из энкодера, который обычно входные данные процессов с текущим слоем, такие как LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно со вторым текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода.

Для модели энкодера этот пример использует простую сеть, состоящую из встраивания, сопровождаемого двумя операциями LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.

Для модели декодера этот пример использует сеть, очень похожую на энкодер, который содержит два LSTMs. Однако важное различие - то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру проявлять внимание к определенным частям энкодера выход.

Загрузите обучающие данные

Загрузите пары десятичной Римской цифры с "romanNumerals.csv"

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

Разделите данные в обучение и протестируйте разделы, содержащие 50% данных каждый.

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите некоторые пары десятичной римской цифры.

head(dataTrain)
ans=8×2 table
    Source     Target  
    ______    _________

    "168"     "CLXVIII"
    "154"     "CLIV"   
    "765"     "DCCLXV" 
    "714"     "DCCXIV" 
    "649"     "DCXLIX" 
    "346"     "CCCXLVI"
    "77"      "LXXVII" 
    "83"      "LXXXIII"

Предварительная Обработка Данных

Предварительно обработайте обучающие данные с помощью preprocessSourceTargetPairs функция, перечисленная в конце примера. preprocessSourceTargetPairs функция преобразует входные текстовые данные в числовые последовательности. Элементами последовательностей являются положительные целые числа, которые индексируют в соответствующий wordEncoding объект. wordEncoding лексемы карт в числовой индекс и наоборот использование словаря. Чтобы подсветить начало и концы последовательностей, кодирование также инкапсулирует специальные лексемы "<start>" и "<stop>".

startToken = "<start>";
stopToken = "<stop>";
[sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(dataTrain,startToken,stopToken);

Представление текста как числовые последовательности

Например, десятичная строка "441" закодирован можно следующим образом:

strSource = "441";

Вставьте пробелы между символами.

strSource = strip(replace(strSource,""," "));

Добавьте специальный запуск и лексемы остановки.

strSource = startToken + strSource + stopToken
strSource = 
"<start>4 4 1<stop>"

Маркируйте текст с помощью tokenizedDocument функция и набор 'CustomTokens' опция к специальным лексемам.

documentSource = tokenizedDocument(strSource,'CustomTokens',[startToken stopToken])
documentSource = 
  tokenizedDocument:

   5 tokens: <start> 4 4 1 <stop>

Преобразуйте документ последовательности маркерных индексов с помощью word2ind функция с соответствующим wordEncoding объект.

tokens = string(documentSource);
sequenceSource = word2ind(encSource,tokens)
sequenceSource = 1×5

     1     7     7     2     5

Дополнение и маскирование

Данные о последовательности, такие как текст естественно имеют различные длины последовательности. Чтобы обучить модель с помощью последовательностей переменной длины, заполните мини-пакеты входных данных, чтобы иметь ту же длину. Чтобы гарантировать, что дополнительные значения не влияют на вычисления потерь, создайте маску, которая записывает, какие элементы последовательности действительны, и которые только дополняют.

Например, полагайте, что мини-пакет, содержащий десятичное число, представляет в виде строки "437", "431", и "102" с соответствующей Римской цифрой представляет в виде строки "CDXXXVII", "CDXXXI" и "CII". Для познаковых последовательностей входные последовательности имеют ту же длину и не должны быть дополнены. Соответствующая маска является массивом из единиц.

Выходные последовательности имеют различные длины, таким образом, они требуют дополнения. Соответствующая дополнительная маска содержит нули, где соответствующие временные шаги дополняют значения.

Инициализируйте параметры модели

Инициализируйте параметры модели. и для энкодера и для декодера, задайте размерность встраивания 256, два слоя LSTM с 200 скрытыми модулями и слои уволенного со случайным уволенным с вероятностью 0.05.

embeddingDimension = 256;
numHiddenUnits = 200;
dropout = 0.05;

Инициализируйте параметры модели энкодера:

  • Задайте размерность встраивания 256 и размер словаря исходного словаря плюс 1, где дополнительное значение соответствует дополнительной лексеме.

  • Задайте две операции LSTM с 200 скрытыми модулями.

  • Инициализируйте веса встраивания путем выборки от случайного нормального распределения.

  • Инициализируйте веса LSTM и смещения путем выборки от равномерного распределения с помощью uniformNoise функция, перечисленная в конце примера.

inputSize = encSource.NumWords + 1;
parametersEncoder.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parametersEncoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension],1/numHiddenUnits));
parametersEncoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersEncoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersEncoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

Инициализируйте параметры модели декодера.

  • Задайте размерность встраивания 256 и размер словаря целевого словаря плюс 1, где дополнительное значение соответствует дополнительной лексеме.

  • Инициализируйте веса механизма внимания с помощью uniformNoise функция.

  • Инициализируйте веса встраивания путем выборки от случайного нормального распределения.

  • Инициализируйте веса LSTM и смещения путем выборки от равномерного распределения с помощью uniformNoise функция.

outputSize = encTarget.NumWords + 1;
parametersDecoder.emb.Weights = dlarray(randn([embeddingDimension outputSize]));

parametersDecoder.attn.Weights = dlarray(uniformNoise([numHiddenUnits numHiddenUnits],1/numHiddenUnits));

parametersDecoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension+numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm1.Bias = dlarray( uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));

parametersDecoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits));
parametersDecoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1], 1/numHiddenUnits));

parametersDecoder.fc.Weights = dlarray(uniformNoise([outputSize 2*numHiddenUnits],1/(2*numHiddenUnits)));
parametersDecoder.fc.Bias = dlarray(uniformNoise([outputSize 1], 1/(2*numHiddenUnits)));

Функции модели Define

Создайте функции modelEncoder и modelDecoder, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.

modelEncoder функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходные параметры модели и LSTM скрытое состояние.

modelDecoder функция, перечисленная в разделе Decoder Model Function примера, берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 32 в течение 40 эпох. Задайте скорость обучения 0,002 и отсеките градиенты с порогом 5.

miniBatchSize = 32;
numEpochs = 40;
learnRate = 0.002;
gradientThreshold = 5;

Инициализируйте опции от Адама.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Задайте, чтобы построить процесс обучения. Чтобы отключить график процесса обучения, установите plots значение к "none".

plots = "training-progress";

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

В течение первой эпохи обучайтесь с последовательностями, отсортированными путем увеличения длины последовательности. Это приводит к пакетам с последовательностями приблизительно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются, чтобы обновить модель перед более длинными пакетами последовательности. В течение последующих эпох переставьте данные.

Для каждого мини-пакета:

  • Преобразуйте данные в dlarray.

  • Вычислите потерю и градиенты.

  • Отсеките градиенты.

  • Обновите параметры модели энкодера и декодера с помощью adamupdate функция.

  • Обновите график процесса обучения.

Сортировка последовательностей в течение первой эпохи.

sequenceLengthsEncoder = cellfun(@(sequence) size(sequence,2), sequencesSource);
[~,idx] = sort(sequenceLengthsEncoder);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте значения для adamupdate функция.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];

trailingAvgDecoder = [];
trailingAvgSqDecoder = [];

Обучите модель.

numObservations = numel(sequencesSource);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource(idx), ...
            sequencesTarget(idx), inputSize, outputSize);
        
        % Convert mini-batch of data to dlarray.
        dlXSource = dlarray(XSource);
        dlXTarget = dlarray(XTarget);
        
        % Compute loss and gradients.
        [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ...
            parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout);
        
        % Gradient clipping.
        gradientsEncoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsEncoder);
        gradientsDecoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsDecoder);
        
        % Update encoder using adamupdate.
        [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ...
            gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update decoder using adamupdate.
        [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ...
            gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(loss)))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservations);
    sequencesSource = sequencesSource(idx);
    sequencesTarget = sequencesTarget(idx);
end

Сгенерируйте переводы

Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов как тогда, когда обучение и ввело последовательности в модель декодера энкодера и преобразует получившиеся последовательности назад в текст с помощью маркерных индексов.

Подготовка данных для перевода

Выберите мини-пакет тестовых наблюдений.

numObservationsTest = 16;
idx = randperm(size(dataTest,1),numObservationsTest);
dataTest(idx,:)
ans=16×2 table
    Source      Target   
    ______    ___________

    "857"     "DCCCLVII" 
    "991"     "CMXCI"    
    "143"     "CXLIII"   
    "924"     "CMXXIV"   
    "752"     "DCCLII"   
    "85"      "LXXXV"    
    "131"     "CXXXI"    
    "124"     "CXXIV"    
    "858"     "DCCCLVIII"
    "103"     "CIII"     
    "497"     "CDXCVII"  
    "76"      "LXXVI"    
    "815"     "DCCCXV"   
    "829"     "DCCCXXIX" 
    "940"     "CMXL"     
    "94"      "XCIV"     

Предварительно обработайте текстовые данные с помощью тех же шагов как тогда, когда обучение. Используйте transformText функция, перечисленная в конце примера, чтобы разделить текст в символы и добавить запуск и лексемы остановки.

strSource = dataTest{idx,1};
strTarget = dataTest{idx,2};

documentsSource = transformText(strSource,startToken,stopToken);

Преобразуйте маркируемый текст в пакет заполненных последовательностей при помощи doc2sequence функция. Чтобы автоматически заполнить последовательности, установите 'PaddingDirection' опция к 'right' и установленный дополнительное значение к входному размеру (маркерный индекс дополнительной лексемы).

sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',inputSize);

Конкатенация и переставляет данные о последовательности в необходимой форме для функции модели энкодера (1 N S, где N является количеством наблюдений, и S является длиной последовательности).

XSource = cat(3,sequencesSource{:});
XSource = permute(XSource,[1 3 2]);

Преобразуйте входные данные в dlarray и вычислите модель энкодера выходные параметры.

dlXSource = dlarray(XSource);
[dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder);

Переведите исходный текст

Сгенерировать переводы для нового ввода данных последовательности в модель декодера энкодера и преобразовать получившиеся последовательности назад в текст с помощью маркерных индексов.

Чтобы инициализировать переводы, создайте вектор, содержащий только индексы, соответствующие лексеме запуска.

decoderInput = repmat(word2ind(encTarget,startToken),[1 numObservationsTest]);
decoderInput = dlarray(decoderInput);

Инициализируйте вектор контекста и массивы ячеек, содержащие переведенные последовательности и музыку внимания к каждому наблюдению.

context = dlarray(zeros([size(dlZ, 1) numObservationsTest]));
sequencesTranslated = cell(1,numObservationsTest);
attentionScores = cell(1,numObservationsTest);

Цикл в зависимости от времени продвигается, и переведите последовательности. Сохраните цикличное выполнение по временным шагам, пока все последовательности не перевели. Для каждого наблюдения, когда перевод закончен (когда декодер предсказывает лексему остановки), устанавливает флаг прекращать переводить ту последовательность.

stopIdx = word2ind(encTarget,stopToken);

stopTranslating = false(1, numObservationsTest);
maxSequenceLength = 10;

while ~all(stopTranslating)
    
    % Forward through decoder.
    [dlY, context, hiddenState, attn] = modelDecoder(decoderInput, parametersDecoder, context, ...
        hiddenState, dlZ);
    
    % Loop over observations.
    for i = 1:numObservationsTest
        % Skip already-translated sequences.
        if stopTranslating(i)
            continue
        end

        
        % Update attention scores.
        attentionScores{i} = [attentionScores{i} extractdata(attn(:,i))];
        
        % Predict next time step.
        prob = softmax(dlY(:,i), 'DataFormat', 'CB');
        [~, idx] = max(prob(1:end-1,:), [], 1);
        
        % Set stopTranslating flag when translation done.
        if idx == stopIdx || numel(sequencesTranslated{i}) == maxSequenceLength
            stopTranslating(i) = true;
        else
            sequencesTranslated{i} = [sequencesTranslated{i} extractdata(idx)];
            decoderInput(i) = idx;
        end
    end
end

Просмотрите исходный текст, целевой текст и переводы в таблице.

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = cellfun(@(sequence) join(ind2word(encTarget,sequence),""),sequencesTranslated)';
tbl
tbl=16×3 table
    Source      Target       Translated 
    ______    ___________    ___________

    "857"     "DCCCLVII"     "DCCCLVII" 
    "991"     "CMXCI"        "CMXCI"    
    "143"     "CXLIII"       "CXLIII"   
    "924"     "CMXXIV"       "CMXXIV"   
    "752"     "DCCLII"       "DCCLII"   
    "85"      "LXXXV"        "DCCCLVI"  
    "131"     "CXXXI"        "CXXXI"    
    "124"     "CXXIV"        "CXXIV"    
    "858"     "DCCCLVIII"    "DCCCLVIII"
    "103"     "CIII"         "CIII"     
    "497"     "CDXCVII"      "CDXCVII"  
    "76"      "LXXVI"        "DCCLVII"  
    "815"     "DCCCXV"       "DCCCXV"   
    "829"     "DCCCXXIX"     "DCCCXXIX" 
    "940"     "CMXL"         "CMXL"     
    "94"      "XCIV"         "CMXLVI"   

Постройте баллы внимания

Постройте множество внимания первой последовательности в карте тепла. Подсветка баллов внимания, какие области источника и переведенный упорядочивают модель, проявляет внимание при обработке перевода.

idx = 1;
figure
xlabs = [ind2word(encTarget,sequencesTranslated{idx}) stopToken];
ylabs = string(documentsSource(idx));

heatmap(attentionScores{idx}, ...
    'CellLabelColor','none', ...
    'XDisplayLabels',xlabs, ...
    'YDisplayLabels',ylabs);

xlabel("Translation")
ylabel("Source")
title("Attention Scores")

Предварительная обработка функции

preprocessSourceTargetPairs берет таблицу data содержание целевых источником пар в двух столбцах и для каждого столбца возвращает последовательности маркерных индексов и соответствующего wordEncoding возразите, что сопоставляет индексы со словами и наоборот.

function [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(data,startToken,stopToken)

% Extract text data.
strSource = data{:,1};
strTarget = data{:,2};

% Create tokenized document arrays.
documentsSource = transformText(strSource,startToken,stopToken);
documentsTarget = transformText(strTarget,startToken,stopToken);

% Create word encodings.
encSource = wordEncoding(documentsSource);
encTarget = wordEncoding(documentsTarget);

% Convert documents to numeric sequences.
sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

end

Текстовая функция преобразования

transformText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.

function documents = transformText(str,startToken,stopToken)

% Split text into characters.
str = strip(replace(str,""," "));

% Add start and stop tokens.
str = startToken + str + stopToken;

% Create tokenized document array.
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

Пакетная функция создания

createBatch функционируйте берет мини-пакет входных и выходных последовательностей и возвращает дополненные последовательности с соответствующими дополнительными масками.

function [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource, sequencesTarget, ...
    paddingValueSource, paddingValueTarget)

numObservations = size(sequencesSource,1);
sequenceLengthSource = max(cellfun(@(x) size(x,2), sequencesSource));
sequenceLengthTarget = max(cellfun(@(x) size(x,2), sequencesTarget));

% Initialize masks.
maskSource = false(numObservations, sequenceLengthSource);
maskTarget = false(numObservations, sequenceLengthTarget);

% Initialize mini-batch.
XSource = zeros(1,numObservations,sequenceLengthSource);
XTarget = zeros(1,numObservations,sequenceLengthTarget);

% Pad sequences and create masks.
for i = 1:numObservations
    
    % Source
    L = size(sequencesSource{i},2);
    paddingSize = sequenceLengthSource - L;
    padding = repmat(paddingValueSource, [1 paddingSize]);
    
    XSource(1,i,:) = [sequencesSource{i} padding];
    maskSource(i,1:L) = true;
    
    % Target
    L = size(sequencesTarget{i},2);
    paddingSize = sequenceLengthTarget - L;
    padding = repmat(paddingValueTarget, [1 paddingSize]);
    
    XTarget(1,i,:) = [sequencesTarget{i} padding];
    maskTarget(i,1:L) = true;
end

end

Функция модели энкодера

Функциональный modelEncoder берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходной параметр модели и LSTM скрытое состояние.

function [dlZ, hiddenState] = modelEncoder(dlX, parametersEncoder, maskSource)

% Embedding
weights = parametersEncoder.emb.Weights;
dlZ = embedding(dlX,weights);

% LSTM
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;
numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Mask output for training
if nargin > 2
    dlZ = dlZ.*permute(maskSource, [3 1 2]);
    sequenceLengths = sum(maskSource, 2);
    
    % Mask final hidden state
    for ii = 1:size(dlZ, 2)
        hiddenState(:, ii) = dlZ(:, ii, sequenceLengths(ii));
    end
end

end

Функция модели декодера

Функциональный modelDecoder берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.

function [dlY, context, hiddenState, attentionScores] = modelDecoder(dlX, parameters, context, ...
    hiddenState, encoderOutputs, dropout)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% RNN input
dlY = cat(1, dlX, context);

% LSTM 1
initialCellState = dlarray(zeros(size(hiddenState)));

inputWeights = parameters.lstm1.InputWeights;
recurrentWeights = parameters.lstm1.RecurrentWeights;
bias = parameters.lstm1.Bias;

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

if nargin > 5
    % Dropout
    mask = ( rand(size(dlY), 'like', dlY) > dropout );
    dlY = dlY.*mask;
end

% LSTM 2
inputWeights = parameters.lstm2.InputWeights;
recurrentWeights = parameters.lstm2.RecurrentWeights;
bias = parameters.lstm2.Bias;
[~, hiddenState] = lstm(dlY, hiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention
weights = parameters.attn.Weights;
attentionScores = attention(hiddenState, encoderOutputs, weights);

% Context
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = dlmtimes(encoderOutputs,attentionScores);
context = squeeze(context);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = weights*cat(1, hiddenState, context) + bias;

end

Встраивание функции

embedding функционируйте сопоставляет числовые индексы с соответствующим вектором, данным входными весами.

function Z = embedding(X, weights)
% Reshape inputs into a vector
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix
Z = weights(:, X);

% Reshape outputs by separating out batch and sequence dimensions
Z = reshape(Z, [], N, T);
end

Функция внимания

attention функция вычисляет баллы внимания по данным Luong "общий" выигрыш.

function attentionScores = attention(hiddenState, encoderOutputs, weights)

[N, S] = size(encoderOutputs, 2:3);
attentionEnergies = dlarray(zeros( [S N] ));

% The energy at each time step is the dot product of the hidden state
% and the learnable attention weights times the encoder output
hWX = hiddenState .* dlmtimes(weights,encoderOutputs);
for tt = 1:S
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Compute softmax scores
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

end

Функция градиентов модели

modelGradients функционируйте берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

function [gradientsEncoder, gradientsDecoder, maskedLoss] = modelGradients(parametersEncoder, ...
    parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout)

% Forward through encoder.
[dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder, maskSource);

% Get parameter sizes.
[miniBatchSize, sequenceLength] = size(dlXTarget,2:3);
sequenceLength = sequenceLength - 1;
numHiddenUnits = size(dlZ,1);

% Initialize context vector.
context = dlarray(zeros([numHiddenUnits miniBatchSize]));

% Initialize loss.
loss = dlarray(zeros([miniBatchSize sequenceLength]));

% Get first time step for decoder.
decoderInput = dlXTarget(:,:,1);

% Choose whether to use teacher forcing.
doTeacherForcing = rand < 0.5;

if doTeacherForcing
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);
        
        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);
        
        % Get next time step.
        decoderInput = dlXTarget(:,:,t+1);
    end
else
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ...
            hiddenState, dlZ, dropout);
        
        % Update loss.
        dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1)));
        loss(:,t) = crossEntropyAndSoftmax(dlY, dlT);
        
        % Greedily update next input time step.
        prob = softmax(dlY,'DataFormat','CB');
        [~, decoderInput] = max(prob,[],1);
    end
end

% Determine masked loss.
maskedLoss = sum(sum(loss.*maskTarget(:,2:end))) / miniBatchSize;

% Update gradients.
[gradientsEncoder, gradientsDecoder] = dlgradient(maskedLoss, parametersEncoder, parametersDecoder);

% For plotting, return loss normalized by sequence length.
maskedLoss = extractdata(maskedLoss) ./ sequenceLength;

end

Перекрестная энтропия и функция потерь Softmax

crossEntropyAndSoftmax потеря вычисляет перекрестную энтропию и softmax потерю.

function loss = crossEntropyAndSoftmax(dlY, dlT)

offset = max(dlY);
logSoftmax = dlY - offset - log(sum(exp(dlY-offset)));
loss = -sum(dlT.*logSoftmax);

end

Универсальная шумовая функция

uniformNoise функциональные демонстрационные веса от равномерного распределения.

function weights = uniformNoise(sz, k)

weights = -sqrt(k) + 2*sqrt(k).*rand(sz);

end

Функция усечения градиента

clipGradient функционируйте отсекает градиенты модели.

function g = clipGradient(g, gradientThreshold)

wnorm = norm(extractdata(g));
if wnorm > gradientThreshold
    g = (gradientThreshold/wnorm).*g;
end

end

Функция прямого кодирования

oneHot функция кодирует словари как одногорячие векторы.

function oh = oneHot(idx, numTokens)
tokens = (1:numTokens)';
oh = (tokens == idx);
end

Смотрите также

| | | | | | | | | | |

Похожие темы