В этом примере показано, как преобразовать десятичные строки в Римские цифры с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.
Текущие модели декодера энкодера оказались успешными в задачах как абстрактное текстовое резюмирование и нейронный машинный перевод. Модели, сопоставимые из энкодера, который обычно входные данные процессов с текущим слоем, такие как LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно со вторым текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода.
Для модели энкодера этот пример использует простую сеть, состоящую из встраивания, сопровождаемого двумя операциями LSTM. Встраивание является методом преобразования категориальных лексем в числовые векторы.
Для модели декодера этот пример использует сеть, очень похожую на энкодер, который содержит два LSTMs. Однако важное различие - то, что декодер содержит механизм внимания. Механизм внимания позволяет декодеру проявлять внимание к определенным частям энкодера выход.
Загрузите пары десятичной Римской цифры с "romanNumerals.csv"
filename = fullfile("romanNumerals.csv"); options = detectImportOptions(filename, ... 'TextType','string', ... 'ReadVariableNames',false); options.VariableNames = ["Source" "Target"]; options.VariableTypes = ["string" "string"]; data = readtable(filename,options);
Разделите данные в обучение и протестируйте разделы, содержащие 50% данных каждый.
idx = randperm(size(data,1),500); dataTrain = data(idx,:); dataTest = data; dataTest(idx,:) = [];
Просмотрите некоторые пары десятичной римской цифры.
head(dataTrain)
ans=8×2 table
Source Target
______ ____________
"437" "CDXXXVII"
"431" "CDXXXI"
"102" "CII"
"862" "DCCCLXII"
"738" "DCCXXXVIII"
"527" "DXXVII"
"401" "CDI"
"184" "CLXXXIV"
Предварительно обработайте обучающие данные с помощью preprocessSourceTargetPairs
функция, перечисленная в конце примера. preprocessSourceTargetPairs
функция преобразует входные текстовые данные в числовые последовательности. Элементы последовательностей являются положительными целыми числами, которые индексируют в соответствующий wordEncoding
объект. wordEncoding
лексемы карт в числовой индекс и наоборот использование словаря. Чтобы подсветить начало и концы последовательностей, кодирование также инкапсулирует специальные лексемы "<start>"
и "<stop>"
.
startToken = "<start>"; stopToken = "<stop>"; [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(dataTrain,startToken,stopToken);
Например, десятичная строка "441"
закодирован можно следующим образом:
strSource = "441";
Вставьте пробелы между символами.
strSource = strip(replace(strSource,""," "));
Добавьте специальный запуск и лексемы остановки.
strSource = startToken + strSource + stopToken
strSource = 1×1 string
"<start>4 4 1<stop>"
Маркируйте текст с помощью tokenizedDocument
функция и набор 'CustomTokens'
опция к специальным лексемам.
documentSource = tokenizedDocument(strSource,'CustomTokens',[startToken stopToken])
documentSource = tokenizedDocument: 5 tokens: <start> 4 4 1 <stop>
Преобразуйте документ последовательности маркерных индексов с помощью word2ind
функция с соответствующим wordEncoding
объект.
tokens = string(documentSource); sequenceSource = word2ind(encSource,tokens)
sequenceSource = 1×5
1 2 2 6 5
Данные о последовательности, такие как текст естественно имеют различные длины последовательности. Чтобы обучить модель с помощью последовательностей переменной длины, заполните мини-пакеты входных данных, чтобы иметь ту же длину. Чтобы гарантировать, что дополнительные значения не влияют на вычисления потерь, создайте маску, которая записывает, какие элементы последовательности действительны, и которые только дополняют.
Например, полагайте, что мини-пакет, содержащий десятичное число, представляет в виде строки "437", "431", и "102" с соответствующей Римской цифрой представляет в виде строки "CDXXXVII", "CDXXXI" и "CII". Для познаковых последовательностей входные последовательности имеют ту же длину и не должны быть дополнены. Соответствующая маска является массивом из единиц.
Выходные последовательности имеют различные длины, таким образом, они требуют дополнения. Соответствующая дополнительная маска содержит нули, где соответствующие временные шаги дополняют значения.
Инициализируйте параметры модели. и для энкодера и для декодера, задайте размерность встраивания 256, два слоя LSTM с 200 скрытыми модулями и слои уволенного со случайным уволенным с вероятностью 0.05.
embeddingDimension = 256; numHiddenUnits = 200; dropout = 0.05;
Инициализируйте параметры модели энкодера:
Задайте размерность встраивания 256 и размер словаря исходного словаря плюс 1, где дополнительное значение соответствует дополнительной лексеме.
Задайте две операции LSTM с 200 скрытыми модулями.
Инициализируйте веса встраивания путем выборки от случайного нормального распределения.
Инициализируйте веса LSTM и смещения путем выборки от равномерного распределения с помощью uniformNoise
функция, перечисленная в конце примера.
inputSize = encSource.NumWords + 1; parametersEncoder.emb.Weights = dlarray(randn([embeddingDimension inputSize])); parametersEncoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension],1/numHiddenUnits)); parametersEncoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersEncoder.lstm1.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits)); parametersEncoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersEncoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersEncoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits));
Инициализируйте параметры модели декодера.
Задайте размерность встраивания 256 и размер словаря целевого словаря плюс 1, где дополнительное значение соответствует дополнительной лексеме.
Инициализируйте веса механизма внимания с помощью uniformNoise
функция.
Инициализируйте веса встраивания путем выборки от случайного нормального распределения.
Инициализируйте веса LSTM и смещения путем выборки от равномерного распределения с помощью uniformNoise
функция.
outputSize = encTarget.NumWords + 1; parametersDecoder.emb.Weights = dlarray(randn([embeddingDimension outputSize])); parametersDecoder.attn.Weights = dlarray(uniformNoise([numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersDecoder.lstm1.InputWeights = dlarray(uniformNoise([4*numHiddenUnits embeddingDimension+numHiddenUnits],1/numHiddenUnits)); parametersDecoder.lstm1.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersDecoder.lstm1.Bias = dlarray( uniformNoise([4*numHiddenUnits 1],1/numHiddenUnits)); parametersDecoder.lstm2.InputWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersDecoder.lstm2.RecurrentWeights = dlarray(uniformNoise([4*numHiddenUnits numHiddenUnits],1/numHiddenUnits)); parametersDecoder.lstm2.Bias = dlarray(uniformNoise([4*numHiddenUnits 1], 1/numHiddenUnits)); parametersDecoder.fc.Weights = dlarray(uniformNoise([outputSize 2*numHiddenUnits],1/(2*numHiddenUnits))); parametersDecoder.fc.Bias = dlarray(uniformNoise([outputSize 1], 1/(2*numHiddenUnits)));
Создайте функции modelEncoder
и modelDecoder
, перечисленный в конце примера, которые вычисляют выходные параметры моделей энкодера и декодера, соответственно.
modelEncoder
функция, перечисленная в разделе Encoder Model Function примера, берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходные параметры модели и LSTM скрытое состояние.
modelDecoder
функция, перечисленная в разделе Decoder Model Function примера, берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.
Создайте функциональный modelGradients
, перечисленный в разделе Model Gradients Function примера, который берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно learnable параметров в моделях и соответствующей потери.
Обучайтесь с мини-пакетным размером 32 в течение 40 эпох. Задайте темп обучения 0,002 и отсеките градиенты с порогом 5.
miniBatchSize = 32; numEpochs = 40; learnRate = 0.002; gradientThreshold = 5;
Инициализируйте опции от Адама.
gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Задайте, чтобы построить процесс обучения. Чтобы отключить график процесса обучения, установите plots
значение к "none"
.
plots = "training-progress";
Обучите модель с помощью пользовательского учебного цикла.
В течение первой эпохи обучайтесь с последовательностями, отсортированными путем увеличения длины последовательности. Это приводит к пакетам с последовательностями приблизительно той же длины последовательности и гарантирует, что меньшие пакеты последовательности используются, чтобы обновить модель перед более длинными пакетами последовательности. В течение последующих эпох переставьте данные.
Для каждого мини-пакета:
Преобразуйте данные в dlarray
.
Вычислите потерю и градиенты.
Отсеките градиенты.
Обновите параметры модели энкодера и декодера с помощью adamupdate
функция.
Обновите график процесса обучения.
Сортировка последовательностей в течение первой эпохи.
sequenceLengthsEncoder = cellfun(@(sequence) size(sequence,2), sequencesSource); [~,idx] = sort(sequenceLengthsEncoder); sequencesSource = sequencesSource(idx); sequencesTarget = sequencesTarget(idx);
Инициализируйте график процесса обучения.
if plots == "training-progress" figure lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") end
Инициализируйте значения для adamupdate
функция.
trailingAvgEncoder = []; trailingAvgSqEncoder = []; trailingAvgDecoder = []; trailingAvgSqDecoder = [];
Обучите модель.
numObservations = numel(sequencesSource); numIterationsPerEpoch = floor(numObservations/miniBatchSize); iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data idx = (i-1)*miniBatchSize+1:i*miniBatchSize; [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource(idx), ... sequencesTarget(idx), inputSize, outputSize); % Convert mini-batch of data to dlarray. dlXSource = dlarray(XSource); dlXTarget = dlarray(XTarget); % Compute loss and gradients. [gradientsEncoder, gradientsDecoder, loss] = dlfeval(@modelGradients, parametersEncoder, ... parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout); % Gradient clipping. gradientsEncoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsEncoder); gradientsDecoder = dlupdate(@(w) clipGradient(w,gradientThreshold), gradientsDecoder); % Update encoder using adamupdate. [parametersEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(parametersEncoder, ... gradientsEncoder, trailingAvgEncoder, trailingAvgSqEncoder, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Update decoder using adamupdate. [parametersDecoder, trailingAvgDecoder, trailingAvgSqDecoder] = adamupdate(parametersDecoder, ... gradientsDecoder, trailingAvgDecoder, trailingAvgSqDecoder, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(loss))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end % Shuffle data. idx = randperm(numObservations); sequencesSource = sequencesSource(idx); sequencesTarget = sequencesTarget(idx); end
Чтобы сгенерировать переводы для новых данных с помощью обученной модели, преобразуйте текстовые данные в числовые последовательности с помощью тех же шагов как тогда, когда обучение и ввело последовательности в модель декодера энкодера и преобразует получившиеся последовательности назад в текст с помощью маркерных индексов.
Выберите мини-пакет тестовых наблюдений.
numObservationsTest = 16; idx = randperm(size(dataTest,1),numObservationsTest); dataTest(idx,:)
ans=16×2 table
Source Target
______ ____________
"412" "CDXII"
"274" "CCLXXIV"
"231" "CCXXXI"
"558" "DLVIII"
"187" "CLXXXVII"
"828" "DCCCXXVIII"
"1" "I"
"217" "CCXVII"
"309" "CCCIX"
"489" "CDLXXXIX"
"406" "CDVI"
"840" "DCCCXL"
"757" "DCCLVII"
"268" "CCLXVIII"
"371" "CCCLXXI"
"988" "CMLXXXVIII"
Предварительно обработайте текстовые данные с помощью тех же шагов как тогда, когда обучение. Используйте transformText
функция, перечисленная в конце примера, чтобы разделить текст в символы и добавить запуск и лексемы остановки.
strSource = dataTest{idx,1}; strTarget = dataTest{idx,2}; documentsSource = transformText(strSource,startToken,stopToken);
Преобразуйте маркируемый текст в пакет заполненных последовательностей при помощи doc2sequence
функция. Чтобы автоматически заполнить последовательности, установите 'PaddingDirection'
опция к 'right'
и установленный дополнительное значение к входному размеру (маркерный индекс дополнительной лексемы).
sequencesSource = doc2sequence(encSource,documentsSource, ... 'PaddingDirection','right', ... 'PaddingValue',inputSize);
Конкатенация и переставляет данные о последовательности в необходимой форме для функции модели энкодера (1 N S, где N является количеством наблюдений, и S является длиной последовательности).
XSource = cat(3,sequencesSource{:}); XSource = permute(XSource,[1 3 2]);
Преобразуйте входные данные в dlarray
и вычислите модель энкодера выходные параметры.
dlXSource = dlarray(XSource); [dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder);
Сгенерировать переводы для нового ввода данных последовательности в модель декодера энкодера и преобразовать получившиеся последовательности назад в текст с помощью маркерных индексов.
Чтобы инициализировать переводы, создайте вектор, содержащий только индексы, соответствующие лексеме запуска.
decoderInput = repmat(word2ind(encTarget,startToken),[1 numObservationsTest]); decoderInput = dlarray(decoderInput);
Инициализируйте вектор контекста и массивы ячеек, содержащие переведенные последовательности и музыку внимания к каждому наблюдению.
context = dlarray(zeros([size(dlZ, 1) numObservationsTest])); sequencesTranslated = cell(1,numObservationsTest); attentionScores = cell(1,numObservationsTest);
Цикл в зависимости от времени продвигается, и переведите последовательности. Сохраните цикличное выполнение по временным шагам, пока все последовательности не перевели. Для каждого наблюдения, когда перевод закончен (когда декодер предсказывает лексему остановки), устанавливает флаг прекращать переводить ту последовательность.
stopIdx = word2ind(encTarget,stopToken); stopTranslating = false(1, numObservationsTest); while ~all(stopTranslating) % Forward through decoder. [dlY, context, hiddenState, attn] = modelDecoder(decoderInput, parametersDecoder, context, ... hiddenState, dlZ); % Loop over observations. for i = 1:numObservationsTest % Skip already-translated sequences. if stopTranslating(i) continue end % Update attention scores. attentionScores{i} = [attentionScores{i} extractdata(attn(:,i))]; % Predict next time step. prob = softmax(dlY(:,i), 'DataFormat', 'CB'); [~, idx] = max(prob(1:end-1,:), [], 1); % Set stopTranslating flag when translation done. if idx == stopIdx stopTranslating(i) = true; else sequencesTranslated{i} = [sequencesTranslated{i} extractdata(idx)]; decoderInput(i) = idx; end end end
Просмотрите исходный текст, целевой текст и переводы в таблице.
tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = cellfun(@(sequence) join(ind2word(encTarget,sequence),""),sequencesTranslated)';
tbl
tbl=16×3 table
Source Target Translated
______ ____________ ____________
"412" "CDXII" "CDXII"
"274" "CCLXXIV" "CCLXXIV"
"231" "CCXXXI" "CCXXXI"
"558" "DLVIII" "DLVIII"
"187" "CLXXXVII" "CLXXXVII"
"828" "DCCCXXVIII" "DCCCXXVIII"
"1" "I" "CII"
"217" "CCXVII" "CCXVII"
"309" "CCCIX" "CCCIX"
"489" "CDLXXXIX" "CDLXXXIX"
"406" "CDVI" "CDVI"
"840" "DCCCXL" "DCCCXL"
"757" "DCCLVII" "DCCLVII"
"268" "CCLXVIII" "CCLXVIII"
"371" "CCCLXXI" "CCCLXXI"
"988" "CMLXXXVIII" "CMLXXXVIII"
Постройте множество внимания первой последовательности в карте тепла. Подсветка баллов внимания, какие области источника и переведенный упорядочивают модель, проявляет внимание при обработке перевода.
idx = 1; figure xlabs = [ind2word(encTarget,sequencesTranslated{idx}) stopToken]; ylabs = string(documentsSource(idx)); heatmap(attentionScores{idx}, ... 'CellLabelColor','none', ... 'XDisplayLabels',xlabs, ... 'YDisplayLabels',ylabs); xlabel("Translation") ylabel("Source") title("Attention Scores")
preprocessSourceTargetPairs
берет таблицу data
содержание целевых источником пар в двух столбцах и для каждого столбца возвращает последовательности маркерных индексов и соответствующего wordEncoding
возразите, что сопоставляет индексы со словами и наоборот.
function [sequencesSource, sequencesTarget, encSource, encTarget] = preprocessSourceTargetPairs(data,startToken,stopToken) % Extract text data. strSource = data{:,1}; strTarget = data{:,2}; % Create tokenized document arrays. documentsSource = transformText(strSource,startToken,stopToken); documentsTarget = transformText(strTarget,startToken,stopToken); % Create word encodings. encSource = wordEncoding(documentsSource); encTarget = wordEncoding(documentsTarget); % Convert documents to numeric sequences. sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none'); sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none'); end
transformText
функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в символы, и добавление запускают и останавливают лексемы. Чтобы перевести текст путем разделения текста в слова вместо символов, пропустите первый шаг.
function documents = transformText(str,startToken,stopToken) % Split text into characters. str = strip(replace(str,""," ")); % Add start and stop tokens. str = startToken + str + stopToken; % Create tokenized document array. documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]); end
createBatch
функционируйте берет мини-пакет входных и выходных последовательностей и возвращает дополненные последовательности с соответствующими дополнительными масками.
function [XSource, XTarget, maskSource, maskTarget] = createBatch(sequencesSource, sequencesTarget, ... paddingValueSource, paddingValueTarget) numObservations = size(sequencesSource,1); sequenceLengthSource = max(cellfun(@(x) size(x,2), sequencesSource)); sequenceLengthTarget = max(cellfun(@(x) size(x,2), sequencesTarget)); % Initialize masks. maskSource = false(numObservations, sequenceLengthSource); maskTarget = false(numObservations, sequenceLengthTarget); % Initialize mini-batch. XSource = zeros(1,numObservations,sequenceLengthSource); XTarget = zeros(1,numObservations,sequenceLengthTarget); % Pad sequences and create masks. for i = 1:numObservations % Source L = size(sequencesSource{i},2); paddingSize = sequenceLengthSource - L; padding = repmat(paddingValueSource, [1 paddingSize]); XSource(1,i,:) = [sequencesSource{i} padding]; maskSource(i,1:L) = true; % Target L = size(sequencesTarget{i},2); paddingSize = sequenceLengthTarget - L; padding = repmat(paddingValueTarget, [1 paddingSize]); XTarget(1,i,:) = [sequencesTarget{i} padding]; maskTarget(i,1:L) = true; end end
Функциональный modelEncoder
берет входные данные, параметры модели, дополнительная маска, которая используется, чтобы определить правильные выходные параметры для обучения и возвращает выходной параметр модели и LSTM скрытое состояние.
function [dlZ, hiddenState] = modelEncoder(dlX, parametersEncoder, maskSource) % Embedding weights = parametersEncoder.emb.Weights; dlZ = embedding(dlX,weights); % LSTM inputWeights = parametersEncoder.lstm1.InputWeights; recurrentWeights = parametersEncoder.lstm1.RecurrentWeights; bias = parametersEncoder.lstm1.Bias; numHiddenUnits = size(recurrentWeights, 2); initialHiddenState = dlarray(zeros([numHiddenUnits 1])); initialCellState = dlarray(zeros([numHiddenUnits 1])); dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ... recurrentWeights, bias, 'DataFormat', 'CBT'); % LSTM inputWeights = parametersEncoder.lstm2.InputWeights; recurrentWeights = parametersEncoder.lstm2.RecurrentWeights; bias = parametersEncoder.lstm2.Bias; [dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ... inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Mask output for training if nargin > 2 dlZ = dlZ.*permute(maskSource, [3 1 2]); sequenceLengths = sum(maskSource, 2); % Mask final hidden state for ii = 1:size(dlZ, 2) hiddenState(:, ii) = dlZ(:, ii, sequenceLengths(ii)); end end end
Функциональный modelDecoder
берет входные данные, параметры модели, вектор контекста, начальная буква LSTM скрытое состояние, выходные параметры энкодера и вероятность уволенного и выводит декодер выход, обновленный вектор контекста, обновленное состояние LSTM и баллы внимания.
function [dlY, context, hiddenState, attentionScores] = modelDecoder(dlX, parameters, context, ... hiddenState, encoderOutputs, dropout) % Embedding weights = parameters.emb.Weights; dlX = embedding(dlX, weights); % RNN input dlY = cat(1, dlX, context); % LSTM 1 initialCellState = dlarray(zeros(size(hiddenState))); inputWeights = parameters.lstm1.InputWeights; recurrentWeights = parameters.lstm1.RecurrentWeights; bias = parameters.lstm1.Bias; dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, ... recurrentWeights, bias, 'DataFormat', 'CBT'); if nargin > 5 % Dropout mask = ( rand(size(dlY), 'like', dlY) > dropout ); dlY = dlY.*mask; end % LSTM 2 inputWeights = parameters.lstm2.InputWeights; recurrentWeights = parameters.lstm2.RecurrentWeights; bias = parameters.lstm2.Bias; [~, hiddenState] = lstm(dlY, hiddenState, initialCellState, ... inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT'); % Attention weights = parameters.attn.Weights; [attentionScores, N] = attention(hiddenState, encoderOutputs, weights); % Context encoderOutputs = permute(encoderOutputs, [1 3 2]); for ii = 1:N context(:, ii) = encoderOutputs(:, :, ii)*attentionScores(:, ii); end % Fully connect weights = parameters.fc.Weights; bias = parameters.fc.Bias; dlY = weights*cat(1, hiddenState, context) + bias; end
embedding
функционируйте сопоставляет числовые индексы с соответствующим вектором, данным входными весами.
function Z = embedding(X, weights) % Reshape inputs into a vector [N, T] = size(X, 2:3); X = reshape(X, N*T, 1); % Index into embedding matrix Z = weights(:, X); % Reshape outputs by separating out batch and sequence dimensions Z = reshape(Z, [], N, T); end
attention
функция вычисляет баллы внимания accoring к Luong "общий" выигрыш.
function [attentionScores, N] = attention(hiddenState, encoderOutputs, weights) [N, S] = size(encoderOutputs, 2:3); attentionEnergies = dlarray(zeros( [S N] )); for tt = 1:S % The energy at each time step is the dot product of the hidden state % and the learnable attention weights times the encoder output attentionEnergies(tt, :) = sum(hiddenState.*(weights*encoderOutputs(:, :, tt)), 1); end % Compute softmax scores attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB'); end
modelGradients
функционируйте берет параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно learnable параметров в моделях и соответствующей потери.
function [gradientsEncoder, gradientsDecoder, maskedLoss] = modelGradients(parametersEncoder, ... parametersDecoder, dlXSource, dlXTarget, maskSource, maskTarget, dropout) % Forward through encoder. [dlZ, hiddenState] = modelEncoder(dlXSource, parametersEncoder, maskSource); % Get parameter sizes. [miniBatchSize, sequenceLength] = size(dlXTarget,2:3); sequenceLength = sequenceLength - 1; numHiddenUnits = size(dlZ,1); % Initialize context vector. context = dlarray(zeros([numHiddenUnits miniBatchSize])); % Initialize loss. loss = dlarray(zeros([miniBatchSize sequenceLength])); % Get first time step for decoder. decoderInput = dlXTarget(:,:,1); % Choose whether to use teacher forcing. doTeacherForcing = rand < 0.5; if doTeacherForcing for t = 1:sequenceLength % Forward through decoder. [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ... hiddenState, dlZ, dropout); % Update loss. dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1))); loss(:,t) = crossEntropyAndSoftmax(dlY, dlT); % Get next time step. decoderInput = dlXTarget(:,:,t+1); end else for t = 1:sequenceLength % Forward through decoder. [dlY, context, hiddenState] = modelDecoder(decoderInput, parametersDecoder, context, ... hiddenState, dlZ, dropout); % Update loss. dlT = dlarray(oneHot(dlXTarget(:,:,t+1), size(dlY,1))); loss(:,t) = crossEntropyAndSoftmax(dlY, dlT); % Greedily update next input time step. prob = softmax(dlY,'DataFormat','CB'); [~, decoderInput] = max(prob,[],1); end end % Determine masked loss. maskedLoss = sum(sum(loss.*maskTarget(:,2:end))) / miniBatchSize; % Update gradients. [gradientsEncoder, gradientsDecoder] = dlgradient(maskedLoss, parametersEncoder, parametersDecoder); % For plotting, return loss normalized by sequence length. maskedLoss = extractdata(maskedLoss) ./ sequenceLength; end
crossEntropyAndSoftmax
потеря вычисляет перекрестную энтропию и softmax потерю.
function loss = crossEntropyAndSoftmax(dlY, dlT) offset = max(dlY); logSoftmax = dlY - offset - log(sum(exp(dlY-offset))); loss = -sum(dlT.*logSoftmax); end
uniformNoise
функциональные демонстрационные веса от равномерного распределения.
function weights = uniformNoise(sz, k) weights = -sqrt(k) + 2*sqrt(k).*rand(sz); end
clipGradient
функционируйте отсекает градиенты модели.
function g = clipGradient(g, gradientThreshold) wnorm = norm(extractdata(g)); if wnorm > gradientThreshold g = (gradientThreshold/wnorm).*g; end end
oneHot
функция кодирует словари как одногорячие векторы.
function oh = oneHot(idx, numTokens) tokens = (1:numTokens)'; oh = (tokens == idx); end
adamupdate
| crossentropy
| dlarray
| dlfeval
| dlgradient
| dlupdate
| doc2sequence
| lstm
| softmax
| tokenizedDocument
| word2ind
| wordEncoding