Перевод языка Используя глубокое обучение

В этом примере показано, как обучить немца к английскому переводчику языка с помощью модели декодера энкодера повторяющейся последовательности к последовательности с вниманием.

Текущие модели декодера энкодера оказались успешными в задачах, таких как абстрактное текстовое резюмирование и нейронный машинный перевод. Эти модели состоят из энкодера, который обычно входные данные процессов с текущим слоем, такие как слой LSTM и декодер, который сопоставляет закодированный вход в желаемый выход, обычно также с текущим слоем. Модели, которые включают механизмы внимания в модели, позволяют декодеру фокусироваться на частях закодированного входа при генерации перевода один временной шаг за один раз. Этот пример реализует внимание Bahdanau [1] использование пользовательского слоя attentionLayer, присоединенный к этому примеру как вспомогательный файл. Чтобы получить доступ к этому слою, откройте этот пример как live скрипт.

Эта схема показывает структуру модели перевода языка. Входной текст в виде последовательности слов, передается через энкодер, который выводит закодированную версию входной последовательности, и скрытое состояние раньше инициализировало состояние декодера. Декодер делает предсказания одним словом во время с помощью предыдущего предсказания в качестве входа и также выходного состояния обновления и значений контекста.

Для получения дополнительной информации и детали о сетях энкодера и декодера, используемых в этом примере, смотрите раздел Define Encoder и Decoder Networks примера.

Предсказание наиболее вероятного слова для каждого шага в последовательности может привести к субоптимальным результатам. Любые неправильные предсказания могут вызвать еще больше неправильных предсказаний в более поздних временных шагах. Например, для целевого текста "An eagle flew by.", если декодер предсказывает первое слово перевода как "A", затем вероятность предсказания "орла" для следующего слова становится намного более маловероятной из-за низкой вероятности фразы "орел", появляющийся в английском тексте. Процесс генерации перевода отличается для обучения и предсказания. Этот пример использует разные подходы, чтобы стабилизировать обучение и предсказания:

  • Чтобы стабилизировать обучение, можно случайным образом использовать целевые значения в качестве входных параметров к декодеру. В частности, можно настроить, вероятность раньше вводила целевые значения, в то время как обучение прогрессирует. Например, можно обучить использование целевых значений на намного более высоком уровне в начале обучения, затем затухнуть вероятность, таким образом, что к концу обучения модель использует только предыдущие предсказания. Этот метод известен, как запланировано выборка [2]. Для получения дополнительной информации смотрите раздел Decoder Predictions Function примера.

  • Чтобы улучшить предсказания во время перевода, для каждого временного шага, можно рассмотреть верхнюю часть K предсказания для некоторого положительного целого числа K и исследуйте различные последовательности предсказаний, чтобы идентифицировать лучшую комбинацию. Этот метод известен как поиск луча. Для получения дополнительной информации смотрите раздел Beam Search Function примера.

В этом примере показано, как загрузить и предварительно обработать текстовые данные, чтобы обучить немца к английскому переводчику языка, задать сети энкодера и декодера, обучить модель с помощью пользовательского учебного цикла и сгенерировать переводы с помощью поиска луча.

Примечание: перевод Языка является в вычислительном отношении интенсивной задачей. Обучение на полном наборе данных, используемом в этом примере, может занять много часов, чтобы запуститься. Чтобы сделать пример запущенным более быстрый, можно уменьшать учебное время за счет точности предсказаний с ранее невидимыми данными путем отбрасывания фрагмента обучающих данных. Удаление наблюдений может ускорить обучение, потому что это уменьшает объем данных до процесса в эпоху и уменьшает размер словаря обучающих данных.

Чтобы сократить время, это исполняется пример, отбрасывание 70% данных. Обратите внимание на то, что отбрасывание больших объемов данных негативно влияет на точность изученной модели. Для более точных результатов уменьшайте сумму отброшенных данных. Чтобы ускорить пример, увеличьте сумму отброшенных данных.

discardProp = 0.70;

Загрузите обучающие данные

Загрузите и извлеките англо-немецкий разграниченный Вкладкой Двуязычный набор данных Пар Предложения. Данные прибывают из http://www.manythings.org/anki и https://tatoeba.org, и обеспечиваются в соответствии с Условиями использования Tatoeba и CC лицензию-BY.

downloadFolder = tempdir;
url = "http://www.manythings.org/anki/deu-eng.zip";
filename = fullfile(downloadFolder,"deu-eng.zip");
dataFolder = fullfile(downloadFolder,"deu-eng");

if ~exist(dataFolder,"dir")
    fprintf("Downloading English-German Tab-delimited Bilingual Sentence Pairs data set (7.6 MB)... ")
    websave(filename,url);
    unzip(filename,dataFolder);
    fprintf("Done.\n")
end

Составьте таблицу, которая содержит пары предложения, заданные как строки. Считайте разграниченные вкладкой пары предложений с помощью readtable. Задайте немецкий текст как источник и английский текст как цель.

filename = fullfile(dataFolder,"deu.txt");

opts = delimitedTextImportOptions(...
    Delimiter="\t", ...
    VariableNames=["Target" "Source" "License"], ...
    SelectedVariableNames=["Source" "Target"], ...
    VariableTypes=["string" "string" "string"], ...
    Encoding="UTF-8");

Просмотрите первые несколько пар предложения в данных.

data = readtable(filename, opts);
head(data)
ans=8×2 table
        Source         Target 
    _______________    _______

    "Geh."             "Go."  
    "Hallo!"           "Hi."  
    "Grüß Gott!"       "Hi."  
    "Lauf!"            "Run!" 
    "Lauf!"            "Run." 
    "Potzdonner!"      "Wow!" 
    "Donnerwetter!"    "Wow!" 
    "Feuer!"           "Fire!"

Обучение на полном наборе данных может занять много времени, чтобы запуститься. Чтобы уменьшать учебное время за счет точности, можно отбросить фрагмент обучающих данных. Удаление наблюдений может ускорить обучение, потому что это уменьшает объем данных до процесса в эпоху, а также сокращения размера словаря обучающих данных.

Отбросьте фрагмент данных согласно discardProp переменная задана в начале примера. Обратите внимание на то, что отбрасывание больших объемов данных негативно влияет на точность изученной модели. Для более точных результатов уменьшайте сумму отброшенных данных установкой discardProp к нижнему значению.

idx = size(data,1) - floor(discardProp*size(data,1)) + 1;
data(idx:end,:) = [];

Просмотрите количество остающихся наблюдений.

size(data,1)
ans = 68124

Разделите данные в обучение и протестируйте разделы, содержащие 90% и 10% данных, соответственно.

trainingProp = 0.9;
idx = randperm(size(data,1),floor(trainingProp*size(data,1)));
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

Просмотрите первые несколько строк обучающих данных.

head(dataTrain)
ans=8×2 table
                  Source                            Target          
    ___________________________________    _________________________

    "Tom erschoss Mary."                   "Tom shot Mary."         
    "Ruf mich bitte an."                   "Call me, please."       
    "Kann das einer nachprüfen?"           "Can someone check this?"
    "Das lasse ich mir nicht gefallen!"    "I won't stand for it."  
    "Ich mag Englisch nicht."              "I don't like English."  
    "Er ist auf dem Laufenden."            "He is up to date."      
    "Sie sieht glücklich aus."             "She seems happy."       
    "Wo wurden sie geboren?"               "Where were they born?"  

Просмотрите количество учебных наблюдений.

numObservationsTrain = size(dataTrain,1)
numObservationsTrain = 61311

Предварительная Обработка Данных

Предварительно обработайте текстовые данные с помощью preprocessText функция, перечисленная в конце примера. preprocessText функция предварительно обрабатывает и маркирует входной текст для перевода путем разделения текста в слова, и добавление запускают и останавливают лексемы.

documentsGerman = preprocessText(dataTrain.Source);

Создайте wordEncoding возразите что лексемы карт против числового индекса и наоборот использования словаря.

encGerman = wordEncoding(documentsGerman);

Преобразуйте целевые данные в последовательности с помощью тех же шагов.

documentsEnglish = preprocessText(dataTrain.Target);
encEnglish = wordEncoding(documentsEnglish);

Просмотрите размеры словаря входной и выходной кодировки.

numWordsGerman = encGerman.NumWords
numWordsGerman = 12117
numWordsEnglish = encEnglish.NumWords
numWordsEnglish = 7226

Задайте сети энкодера и декодера

Эта схема показывает структуру модели перевода языка. Входной текст в виде последовательности слов, передается через энкодер, который выводит закодированную версию входной последовательности, и скрытое состояние раньше инициализировало состояние декодера. Декодер делает предсказания одним словом во время с помощью предыдущего предсказание, как введено и также выходные параметры обновленное состояние и значения контекста.

Создайте сети энкодера и декодера с помощью languageTranslationLayers функция, присоединенная к этому примеру как вспомогательный файл. Чтобы получить доступ к этой функции, откройте пример как live скрипт.

Для сети энкодера, languageTranslationLayers функция задает простую сеть, состоящую из слоя встраивания, сопровождаемого слоем LSTM. Операция встраивания преобразует категориальные лексемы в числовые векторы, где числовые векторы изучены сетью.

Для сети декодера, languageTranslationLayers функция задает сеть, которая передает входные данные, конкатенированные с входным контекстом через слой LSTM, и берет обновленное скрытое состояние и энкодер выход и передает ее через механизм внимания, чтобы определить вектор контекста. LSTM выход и вектор контекста затем конкатенирован и проходится полностью связанный и softmax слой для классификации.

Создайте сети энкодера и декодера с помощью languageTranslationLayers функция, присоединенная к этому примеру как вспомогательный файл. Чтобы получить доступ к этой функции, откройте пример как live скрипт. Задайте размерность встраивания 128, и 128 скрытых модулей в слоях LSTM.

embeddingDimension = 128;
numHiddenUnits = 128;

[lgraphEncoder,lgraphDecoder] = languageTranslationLayers(embeddingDimension,numHiddenUnits,numWordsGerman,numWordsEnglish);

Чтобы обучить сеть в пользовательском учебном цикле, преобразуйте сети энкодера и декодера в dlnetwork объекты.

dlnetEncoder = dlnetwork(lgraphEncoder);
dlnetDecoder = dlnetwork(lgraphDecoder);

Декодер имеет несколько выходных параметров включая контекст выход слоя внимания, который также передается другому слою. Задайте сетевые выходные параметры с помощью OutputNames свойство декодера dlnetwork объект.

dlnetDecoder.OutputNames = ["softmax" "context" "lstm2/hidden" "lstm2/cell"];

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет в качестве входа параметры модели энкодера и декодера, мини-пакет входных данных и дополнительных масок, соответствующих входным данным и вероятности уволенного, и возвращает градиенты потери относительно настраиваемых параметров в моделях и соответствующей потери.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 64 в течение 15 эпох и скорости обучения 0,005.

miniBatchSize = 64;
numEpochs = 15;
learnRate = 0.005;

Инициализируйте опции для оптимизации Адама.

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучайте использование постепенно затухающий значения ϵ для запланированной выборки. Начните со значения ϵ=0.5 и линейно затухните, чтобы закончиться значением ϵ=0. Для получения дополнительной информации о запланированной выборке, смотрите раздел Decoder Predictions Function примера.

epsilonStart = 0.5;
epsilonEnd = 0;

Обучите SortaGrad [3] использования, который является стратегией улучшить обучение неровных последовательностей по образованию в течение одной эпохи с последовательностями, отсортированными по последовательности, затем переставляющей однажды в эпоху после этого.

Сортировка обучающих последовательностей длиной последовательности.

sequenceLengths = doclength(documentsGerman);
[~,idx] = sort(sequenceLengths);
documentsGerman = documentsGerman(idx);
documentsEnglish = documentsEnglish(idx);

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

Создайте хранилища данных массивов для входных и выходных данных с помощью arrayDatastore функция. Объедините хранилища данных с помощью combine функция.

adsSource = arrayDatastore(documentsGerman);
adsTarget = arrayDatastore(documentsEnglish);
cds = combine(adsSource,adsTarget);

Создайте мини-пакетную очередь, чтобы автоматически подготовить мини-пакеты к обучению.

  • Предварительно обработайте обучающие данные с помощью preprocessMiniBatch функция, которая возвращает мини-пакет исходных последовательностей, целевых последовательностей, соответствующей маски и начальной лексемы запуска.

  • Выведите dlarray объекты с форматом "CTB" (канал, время, пакет).

  • Отбросьте любые частичные мини-пакеты.

mbq = minibatchqueue(cds,4, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(X,Y) preprocessMiniBatch(X,Y,encGerman,encEnglish), ...
    MiniBatchFormat=["CTB" "CTB" "CTB" "CTB"], ...
    PartialMiniBatch="discard");

Инициализируйте график процесса обучения.

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));

xlabel("Iteration")
ylabel("Loss")
ylim([0 inf])
grid on

Для сетей энкодера и декодера инициализируйте значения для оптимизации Адама.

trailingAvgEncoder = [];
trailingAvgSqEncoder = [];
trailingAvgDecder = [];
trailingAvgSqDecoder = [];

Создайте массив ϵ значения для запланированной выборки.

numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
numIterations = numIterationsPerEpoch * numEpochs;
epsilon = linspace(epsilonStart,epsilonEnd,numIterations);

Обучите модель. Для каждой итерации:

  • Считайте мини-пакет данных из мини-пакетной очереди.

  • Вычислите градиенты модели и потерю.

  • Обновите сети энкодера и декодера с помощью adamupdate функция.

  • Обновите график процесса обучения и отобразите перевод в качестве примера с помощью ind2str функция, присоединенная к этому примеру как вспомогательный файл. Чтобы получить доступ к этой функции, откройте этот пример как live скрипт.

  • Если итерация дает к самой низкой учебной потере, то сохраните сеть.

В конце каждой эпохи переставьте мини-пакетную очередь.

Для больших наборов данных обучение может занять много часов, чтобы запуститься.

iteration = 0;
start = tic;
lossMin = inf;
reset(mbq)

% Loop over epochs.
for epoch = 1:numEpochs

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        [dlX,dlT,maskT,decoderInput] = next(mbq);
        
        % Compute loss and gradients.
        [gradientsEncoder,gradientsDecoder,loss,dlYPred] = dlfeval(@modelGradients,dlnetEncoder,dlnetDecoder,dlX,dlT,maskT,decoderInput,epsilon(iteration));
        
        % Update network learnable parameters using adamupdate.
        [dlnetEncoder, trailingAvgEncoder, trailingAvgSqEncoder] = adamupdate(dlnetEncoder,gradientsEncoder,trailingAvgEncoder,trailingAvgSqEncoder, ...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);

        [dlnetDecoder, trailingAvgDecder, trailingAvgSqDecoder] = adamupdate(dlnetDecoder,gradientsDecoder,trailingAvgDecder,trailingAvgSqDecoder, ...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);

        % Generate translation for plot.
        if iteration == 1 || mod(iteration,10) == 0
            strGerman = ind2str(dlX(:,1,:),encGerman);
            strEnglish = ind2str(dlT(:,1,:),encEnglish,Mask=maskT);
            strTranslated = ind2str(dlYPred(:,1,:),encEnglish);
        end

        % Display training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(gather(extractdata(loss)));
        addpoints(lineLossTrain,iteration,loss)
        title( ...
            "Epoch: " + epoch + ", Elapsed: " + string(D) + newline + ...
            "Source: " + strGerman + newline + ...
            "Target: " + strEnglish + newline + ...
            "Training Translation: " + strTranslated)

        drawnow
        
        % Save best network.
        if loss < lossMin
            lossMin = loss;
            netBest.dlnetEncoder = dlnetEncoder;
            netBest.dlnetDecoder = dlnetDecoder;
            netBest.loss = loss;
            netBest.iteration = iteration;
            netBest.D = D;
        end
    end

    % Shuffle.
    shuffle(mbq);
end

График показывает два перевода исходного текста. Цель является целевым переводом, обеспеченным обучающими данными, которые сеть пытается воспроизвести. Учебный перевод является предсказанным переводом, который использует информацию из целевого текста с помощью запланированного механизма выборки.

Добавьте кодировку слова в netBest структура и сохраняет структуру в файле MAT.

netBest.encGerman = encGerman;
netBest.encEnglish = encEnglish;

D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
filename = "dlnet_best__" + D + ".mat";
save(filename,"netBest");

Извлеките лучшую сеть из netBest.

dlnetEncoder = netBest.dlnetEncoder;
dlnetDecoder = netBest.dlnetDecoder;

Тестовая модель

Чтобы оценить качество переводов, используйте Дублера Оценки BiLingual (BLEU) алгоритм выигрыша [4].

Переведите тестовые данные с помощью translateText функция перечислена в конце примера.

strTranslatedTest = translateText(dlnetEncoder,dlnetDecoder,encGerman,encEnglish,dataTest.Source);

Просмотрите случайный выбор тестового исходного текста, целевого текста и предсказанных переводов в таблице.

numObservationsTest = size(dataTest,1);
idx = randperm(numObservationsTest,8);
tbl = table;
tbl.Source = dataTest.Source(idx);
tbl.Target = dataTest.Target(idx);
tbl.Translated = strTranslatedTest(idx)
tbl=8×3 table
                   Source                              Target                         Translated            
    _____________________________________    __________________________    _________________________________

    "Er sieht krank aus."                    "He seems ill."               "he looks sick ."                
    "Ich werde das Buch holen."              "I'll get the book."          "i'll get the book . . it . ."   
    "Ruhst du dich jemals aus?"              "Do you ever rest?"           "do you look out of ? ? ? ?"     
    "Was willst du?"                         "What are you after?"         "what do you want want ? ? ? ?"  
    "Du hast keinen Beweis."                 "You have no proof."          "you have no proof . . . . ."    
    "Macht es, wann immer ihr wollt."        "Do it whenever you want."    "do it you like it . . it ."     
    "Tom ist gerade nach Hause gekommen."    "Tom has just come home."     "tom just came home home . . . ."
    "Er lügt nie."                           "He never tells a lie."       "he never lie lies . . . . ."    

Чтобы оценить качество переводов с помощью счета подобия BLEU, сначала предварительно обработайте текстовые данные с помощью тех же шагов что касается обучения. Задайте пустой запуск и лексемы остановки, когда они не используются в переводе.

candidates = preprocessText(strTranslatedTest,StartToken="",StopToken="");
references = preprocessText(dataTest.Target,StartToken="",StopToken="");

bleuEvaluationScore функция, по умолчанию, оценивает баллы подобия путем сравнения N-грамм длины один - четыре (фразы многословные с четырьмя или меньшим количеством слов или отдельных слов). Если у кандидата или справочных документов есть меньше чем четыре лексемы, то получившийся счет оценки BLEU является нулем. Гарантировать тот bleuEvaluationScore возвращает ненулевую музыку к этим коротким документам кандидата, установите веса n-граммы на вектор с меньшим количеством элементов, чем количество слов в candidate.

Определите длину самого короткого документа кандидата.

minLength = min([doclength(candidates); doclength(references)])
minLength = 2

Если самый короткий документ имеет меньше чем четыре лексемы, то установленный веса n-граммы в вектор с длиной, совпадающей с самым коротким документом равным весам та сумма одной. В противном случае задайте веса n-граммы [0.25 0.25 0.25 0.25]. Обратите внимание на то, что, если minLength 1 (и следовательно весами n-граммы является также 1), затем bleuEvaluationScore функция может возвратить меньше значимых результатов, когда она только сравнивает отдельные слова (униграммы) и не сравнивает N-грамм (фразы многословные).

if minLength < 4
    ngramWeights = ones(1,minLength) / minLength;
else
    ngramWeights = [0.25 0.25 0.25 0.25];
end

Вычислите баллы оценки BLEU путем итерации по переводам и использования bleuEvaluationScore функция.

for i = 1:numObservationsTest
    score(i) = bleuEvaluationScore(candidates(i),references(i),NgramWeights=ngramWeights);
end

Визуализируйте баллы оценки BLEU в гистограмме.

figure
histogram(score);
title("BLEU Evaluation Scores")
xlabel("Score")
ylabel("Frequency")

Просмотрите таблицу некоторых лучших переводов.

[~,idxSorted] = sort(score,"descend");
idx = idxSorted(1:8);
tbl = table;
tbl.Source = dataTest.Source(idx);
tbl.Target = dataTest.Target(idx);
tbl.Translated = strTranslatedTest(idx)
tbl=8×3 table
             Source                Target        Translated  
    ________________________    ____________    _____________

    "Legen Sie sich hin!"       "Lie low."      "lie low ."  
    "Ich gähnte."               "I yawned."     "i yawned ." 
    "Küsse Tom!"                "Kiss Tom."     "kiss tom ." 
    "Küssen Sie Tom!"           "Kiss Tom."     "kiss tom ." 
    "Nimm Tom."                 "Take Tom."     "take tom ." 
    "Komm bald."                "Come soon."    "come soon ."
    "Ich habe es geschafft."    "I made it."    "i made it ."
    "Ich sehe Tom."             "I see Tom."    "i see tom ."

Просмотрите таблицу некоторых худших переводов.

idx = idxSorted(end-7:end);
tbl = table;
tbl.Source = dataTest.Source(idx);
tbl.Target = dataTest.Target(idx);
tbl.Translated = strTranslatedTest(idx)
tbl=8×3 table
                                Source                                           Target                       Translated         
    _______________________________________________________________    __________________________    ____________________________

    "Diese Schnecken kann man essen."                                  "These snails are edible."    "this can be eat ."         
    "Sie stehen noch zu Verfügung."                                    "They're still available."    "it's still at . . . . . ." 
    "Diese Schraube passt zu dieser Mutter."                           "This bolt fits this nut."    "this life is too . . . . ."
    "Diese Puppe gehört mir."                                          "This doll belongs to me."    "this one is mine ."        
    "Das ist eine japanische Puppe."                                   "This is a Japanese doll."    "that's a old trick ."      
    "Das ist eine Kreuzung, an der alle Fahrzeuge anhalten müssen."    "This is a four-way stop."    "that's a to to to . . . ." 
    "Diese Sendung ist eine Wiederholung."                             "This program is a rerun."    "this is is quiet ."        
    "Die heutige Folge ist eine Wiederholung."                         "Today's show is a rerun."    "uranus is care ."          

Сгенерируйте переводы

Сгенерируйте переводы для новых данных с помощью translateText функция.

strGermanNew = [
    "Wie geht es Dir heute?"
    "Wie heißen Sie?"
    "Das Wetter ist heute gut."];

Переведите текст с помощью translateText, функция перечислена в конце примера.

strTranslatedNew = translateText(dlnetEncoder,dlnetDecoder,encGerman,encEnglish,strGermanNew)
strTranslatedNew = 3×1 string
    "how do you feel today ?"
    "what's your your name ? ? ? ? ?"
    "the is is today . . today . ."

Функции предсказания

Излучите поисковую функцию

Поиск луча является методом для исследования различных комбинаций предсказаний такта, чтобы помочь найти лучшую последовательность предсказания. Предпосылка поиска луча для каждого предсказания такта, идентифицируйте верхнюю часть K предсказания для некоторого положительного целого числа K (также известный как индекс луча или ширину луча), и обеспечивают верхнюю часть K предсказанные последовательности до сих пор на каждом временном шаге.

Эта схема показывает структуру поиска луча в качестве примера с индексом луча K=3. Для каждого предсказания обеспечены лучшие три последовательности.

beamSearch функционируйте берет в качестве входа входные данные dlX, сети энкодера и декодера и целевое кодирование слова, и возвращают предсказанные переведенные слова с помощью алгоритма поиска луча с индексом луча 3 и максимальной длиной последовательности 10. Можно также задать дополнительные аргументы с помощью аргументов name-value:

  • BeamIndex — Излучите индекс. Значение по умолчанию равняется 3.

  • MaxNumWords — Максимальная длина последовательности. Значение по умолчанию равняется 10.

function str = beamSearch(dlX,dlnetEncoder,dlnetDecoder,encEnglish,args)

% Parse input arguments.
arguments
    dlX
    dlnetEncoder
    dlnetDecoder
    encEnglish
    
    args.BeamIndex = 3;
    args.MaxNumWords = 10;
end

beamIndex = args.BeamIndex;
maxNumWords = args.MaxNumWords;
startToken = "<start>";
stopToken = "<stop>";

% Encoder predictions.
[dlZ, hiddenState, cellState] = predict(dlnetEncoder,dlX);

% Initialize context.
miniBatchSize = size(dlX,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],"like",dlZ);
context = dlarray(context,"CB");

% Initialize candidates.
candidates = struct;
candidates.Words = startToken;
candidates.Score = 0;
candidates.StopFlag = false;
candidates.HiddenState = hiddenState;
candidates.CellState = cellState;

% Loop over words.
t = 0;
while t < maxNumWords
    t = t + 1;

    candidatesNew = [];

    % Loop over candidates.
    for i = 1:numel(candidates)

        % Stop generating when stop token is predicted.
        if candidates(i).StopFlag
            continue
        end

        % Candidate details.
        words = candidates(i).Words;
        score = candidates(i).Score;
        hiddenState = candidates(i).HiddenState;
        cellState = candidates(i).CellState;

        % Predict next token.
        decoderInput = word2ind(encEnglish,words(end));
        decoderInput = dlarray(decoderInput,"CBT");

        [dlYPred,context,hiddenState,cellState] = predict(dlnetDecoder,decoderInput,hiddenState,cellState,context,dlZ, ...
            Outputs=["softmax" "context" "lstm2/hidden" "lstm2/cell"]);

        % Find top predictions.
        [scoresTop,idxTop] = maxk(extractdata(dlYPred),beamIndex);
        idxTop = gather(idxTop);

        % Loop over top predictions.
        for j = 1:beamIndex
            candidate = struct;

            % Determine candidate word and score.
            candidateWord = ind2word(encEnglish,idxTop(j));
            candidateScore = scoresTop(j);

            % Set stop translating flag.
            if candidateWord == stopToken
                candidate.StopFlag = true;
            else
                candidate.StopFlag = false;
            end

            % Update candidate details.
            candidate.Words = [words candidateWord];
            candidate.Score = score + log(candidateScore);
            candidate.HiddenState = hiddenState;
            candidate.CellState = cellState;

            % Add to new candidates.
            candidatesNew = [candidatesNew candidate];
        end
    end

    % Get top candidates.
    [~,idx] = maxk([candidatesNew.Score],beamIndex);
    candidates = candidatesNew(idx);

    % Stop predicting when all candidates have stop token.
    if all([candidates.StopFlag])
        break
    end
end

% Get top candidate.
words = candidates(1).Words;

% Convert to string scalar.
words(ismember(words,[startToken stopToken])) = [];
str = join(words);

end

Переведите текстовую функцию

translateText функционируйте берет в качестве входа сети энкодера и декодера, входную строку и входную и выходную кодировку слова и возвращает переведенный текст.

function strTranslated = translateText(dlnetEncoder,dlnetDecoder,encGerman,encEnglish,strGerman,args)

% Parse input arguments.
arguments
    dlnetEncoder
    dlnetDecoder
    encGerman
    encEnglish
    strGerman
    
    args.BeamIndex = 3;
end

beamIndex = args.BeamIndex;

% Preprocess text.
documentsGerman = preprocessText(strGerman);
dlX = preprocessPredictors(documentsGerman,encGerman);
dlX = dlarray(dlX,"CTB");

% Loop over observations.
numObservations = numel(strGerman);
strTranslated = strings(numObservations,1);             
for n = 1:numObservations
    
    % Translate text.
    strTranslated(n) = beamSearch(dlX(:,n,:),dlnetEncoder,dlnetDecoder,encEnglish,BeamIndex=beamIndex);
end

end

Функции модели

Функция градиентов модели

modelGradients функционируйте берет в качестве входа сеть энкодера, сеть декодера, мини-пакеты предикторов dlX, цели dlT, дополнение маски, соответствующей целям maskT, и ϵ значение для запланированной выборки. Функция возвращает градиенты потери относительно настраиваемых параметров в сетях gradientsE и gradientsD, соответствующая потеря и предсказания декодера dlYPred закодированный как последовательности одногорячих векторов.

function [gradientsE,gradientsD,loss,dlYPred] = modelGradients(dlnetEncoder,dlnetDecoder,dlX,dlT,maskT,decoderInput,epsilon)

% Forward through encoder.
[dlZ, hiddenState, cellState] = forward(dlnetEncoder,dlX);

% Decoder output.
dlY = decoderPredictions(dlnetDecoder,dlZ,dlT,hiddenState,cellState,decoderInput,epsilon);

% Sparse cross-entropy loss.
loss = sparseCrossEntropy(dlY,dlT,maskT);

% Update gradients.
[gradientsE,gradientsD] = dlgradient(loss,dlnetEncoder.Learnables,dlnetDecoder.Learnables);

% For plotting, return loss normalized by sequence length.
sequenceLength = size(dlT,3);
loss = loss ./ sequenceLength;

% For plotting example translations, return the decoder output.
dlYPred = onehotdecode(dlY,1:size(dlY,1),1,"single");

end

Функция предсказаний декодера

decoderPredictions функционируйте берет в качестве входа, сети декодера, энкодер выход dlZ, цели dlT, декодер ввел скрытый и значения состояния ячейки, и ϵ значение для запланированной выборки.

Чтобы стабилизировать обучение, можно случайным образом использовать целевые значения в качестве входных параметров к декодеру. В частности, можно настроить, вероятность раньше вводила целевые значения, в то время как обучение прогрессирует. Например, можно обучить использование целевых значений на намного более высоком уровне в начале обучения, затем затухнуть вероятность, таким образом, что к концу обучения модель использует только предыдущие предсказания. Этот метод известен, как запланировано выборка [2]. Эта схема показывает механизм выборки, включенный в один временной шаг предсказания декодера.

Декодер делает предсказания одним временным шагом за один раз. На каждом временном шаге вход случайным образом выбран согласно ϵ значение для запланированной выборки. В частности, функция использует целевое значение в качестве входа с вероятностью ϵ и использует предыдущее предсказание в противном случае.

function dlY = decoderPredictions(dlnetDecoder,dlZ,dlT,hiddenState,cellState,decoderInput,epsilon)

% Initialize context.
numHiddenUnits = size(dlZ,1);
miniBatchSize = size(dlZ,2);
context = zeros([numHiddenUnits miniBatchSize],"like",dlZ);
context = dlarray(context,"CB");

% Initialize output.
idx = (dlnetDecoder.Learnables.Layer == "fc" & dlnetDecoder.Learnables.Parameter=="Bias");
numClasses = numel(dlnetDecoder.Learnables.Value{idx});
sequenceLength = size(dlT,3);
dlY = zeros([numClasses miniBatchSize sequenceLength],"like",dlZ);
dlY = dlarray(dlY,"CBT");

% Forward start token through decoder.
[dlY(:,:,1),context,hiddenState,cellState] = forward(dlnetDecoder,decoderInput,hiddenState,cellState,context,dlZ);

% Loop over remaining time steps.
for t = 2:sequenceLength

    % Scheduled sampling. Randomly select previous target or previous
    % prediction.
    if rand < epsilon
        % Use target value.
        decoderInput = dlT(:,:,t-1);
    else
        % Use previous prediction.
        [~,dlYhat] = max(dlY(:,:,t-1),[],1);
        decoderInput = dlYhat;
    end

    % Forward through decoder.
    [dlY(:,:,t),context,hiddenState,cellState] = forward(dlnetDecoder,decoderInput,hiddenState,cellState,context,dlZ);
end

end

Разреженная потеря перекрестной энтропии

sparseCrossEntropy функция вычисляет потерю перекрестной энтропии между предсказаниями dlY и цели dlT с целевой маской maskT, где dlY массив вероятностей и dlT закодирован как последовательность целочисленных значений.

function loss = sparseCrossEntropy(dlY,dlT,maskT)

% Initialize loss.
[~,miniBatchSize,sequenceLength] = size(dlY);
loss = zeros([miniBatchSize sequenceLength],"like",dlY);

% To prevent calculating log of 0, bound away from zero.
precision = underlyingType(dlY);
dlY(dlY < eps(precision)) = eps(precision);

% Loop over time steps.
for n = 1:miniBatchSize
    for t = 1:sequenceLength
        idx = dlT(1,n,t);
        loss(n,t) = -log(dlY(idx,n,t));
    end
end

% Apply masking.
maskT = squeeze(maskT);
loss = loss .* maskT;

% Calculate sum and normalize.
loss = sum(loss,"all");
loss = loss / miniBatchSize;

end

Предварительная обработка функций

Текст, предварительно обрабатывающий функцию

preprocessText функция предварительно обрабатывает входной текст для перевода путем преобразования текста в нижний регистр, добавление запускают и останавливают лексемы и маркирование.

function documents = preprocessText(str,args)

arguments
    str
    args.StartToken = "<start>";
    args.StopToken = "<stop>";
end

startToken = args.StartToken;
stopToken = args.StopToken;

str = lower(str);
str = startToken + str + stopToken;
documents = tokenizedDocument(str,CustomTokens=[startToken stopToken]);

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция предварительно обрабатывает маркируемые документы для обучения. Функция кодирует мини-пакеты документов как последовательности числовых индексов и заполняет последовательности, чтобы иметь ту же длину.

function [XSource,XTarget,mask,decoderInput] = preprocessMiniBatch(dataSource,dataTarget,encGerman,encEnglish)

documentsGerman = cat(1,dataSource{:});
XSource = preprocessPredictors(documentsGerman,encGerman);

documentsEngligh = cat(1,dataTarget{:});
sequencesTarget = doc2sequence(encEnglish,documentsEngligh,PaddingDirection="none");

[XTarget,mask] = padsequences(sequencesTarget,2,PaddingValue=1);

decoderInput = XTarget(:,1,:);
XTarget(:,1,:) = [];
mask(:,1,:) = [];

end

Предикторы, предварительно обрабатывающие функцию

preprocessPredictors функция предварительно обрабатывает исходные документы для обучения или предсказание. Функция кодирует массив маркируемых документов как последовательности числовых индексов.

function XSource = preprocessPredictors(documentsGerman,encGerman)

sequencesSource = doc2sequence(encGerman,documentsGerman,PaddingDirection="none");
XSource = padsequences(sequencesSource,2);

end

Библиография

  1. Чоровский, январь, Dzmitry Bahdanau, Дмитрий Сердюк, Кюнгюн Чо и Иосуа Бенхио. “Основанные на внимании Модели для Распознавания речи”. Предварительно распечатайте, siubmitted 24 июня 2015. https://arxiv.org/abs/1506.07503.

  2. Bengio, Samy, Oriol Vinyals, Нэвдип Джэйтли и Ноам Шазир. “Запланированная Выборка для Предсказания Последовательности с Рекуррентными нейронными сетями”. Предварительно распечатайте, представленный 23 сентября 2015. https://arxiv.org/abs/1506.03099.

  3. Amodei, Дарио, Sundaram Ananthanarayanan, Rishita Anubhai, Цзинлян Бай, Эрик Баттенберг, Карл Кэз, Джаред Каспер и др. "Глубокая Речь 2: сквозное Распознавание речи на английском и Мандарине". В <цитируют> Продолжения Исследования Машинного обучения </, цитируют> 48 (2016): 173–182.

  4. Papineni, Кишор, Салим Рукос, Тодд Уорд и Вэй-Цзин Чжу. "BLEU: метод для автоматической оценки машинного перевода". В продолжениях 40-го годового собрания на ассоциации для компьютерной лингвистики (2002): 311–318.

Смотрите также

| | | |

Похожие темы