Гордитесь и нанесите ущерб и MATLAB

В этом примере показано, как обучить сеть LSTM глубокого обучения, чтобы сгенерировать текст с помощью символьных вложений.

Чтобы обучить нейронную сеть для глубокого обучения текстовой генерации, обучите сеть LSTM от последовательности к последовательности, чтобы предсказать следующий символ в последовательности символов. Чтобы обучить сеть, чтобы предсказать следующий символ, задайте ответы, чтобы быть входными последовательностями, переключенными одним временным шагом.

Чтобы использовать символьные вложения, преобразуйте каждое учебное наблюдение в последовательность целых чисел, где целые числа индексируют в словарь символов. Включайте слой встраивания слова в сеть, которая изучает встраивание символов и сопоставляет целые числа с векторами.

Загрузите обучающие данные

Считайте код HTML из Проекта Электронная книга Гутенберга Гордости и Предубеждения, Джейн Остин и проанализируйте его с помощью webread и htmlTree.

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

Извлеките абзацы путем нахождения p элементы. Задайте, чтобы проигнорировать элементы абзаца с классом "toc" использование селектора CSS ':not(.toc)'.

paragraphs = findElement(tree,'p:not(.toc)');

Извлеките текстовые данные из абзацев с помощью extractHTMLText. и удалите пустые строки.

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

Удалите строки короче, чем 20 символов.

idx = strlength(textData) < 20;
textData(idx) = [];

Визуализируйте текстовые данные, одним словом, облако.

figure
wordcloud(textData);
title("Pride and Prejudice")

Преобразуйте текстовые данные в последовательности

Преобразуйте текстовые данные в последовательности индексов символа для предикторов и категориальные последовательности для ответов.

Категориальная функция обрабатывает новую строку и пробельные записи как неопределенные. Чтобы создать категориальные элементы для этих символов, замените их на специальные символы ""(pilcrow, "\x00B6") и "·" (средняя точка, "\x00B7") соответственно. Чтобы предотвратить неоднозначность, необходимо выбрать специальные символы, которые не появляются в тексте. Эти символы не появляются в обучающих данных, так может использоваться с этой целью.

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

Цикл по текстовым данным и создает последовательность индексов символа, представляющих символы каждого наблюдения и категориальной последовательности символов для ответов. Чтобы обозначить конец каждого наблюдения, включайте специальный символ "␃" (конец текста, "\x2403").

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

Во время обучения, по умолчанию, программное обеспечение разделяет обучающие данные в мини-пакеты и заполняет последовательности так, чтобы у них была та же длина. Слишком много дополнения может оказать негативное влияние на производительность сети.

Чтобы препятствовать тому, чтобы учебный процесс добавил слишком много дополнения, можно отсортировать обучающие данные по длине последовательности и выбрать мини-пакетный размер так, чтобы последовательности в мини-пакете имели подобную длину.

Получите длины последовательности для каждого наблюдения.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Сортировка данных длиной последовательности.

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Создайте и обучите сеть LSTM

Задайте архитектуру LSTM. Задайте от последовательности к последовательности сеть классификации LSTM с 400 скрытыми модулями. Установите входной размер быть размерностью признаков обучающих данных. Для последовательностей индексов символа размерность признаков равняется 1. Задайте слой встраивания слова с размерностью 200 и задайте количество слов (которые соответствуют символам) быть самым высоким символьным значением во входных данных. Установите выходной размер полносвязного слоя быть количеством категорий в ответах. Чтобы помочь предотвратить сверхподбор кривой, включайте слой уволенного после слоя LSTM.

Слой встраивания слова изучает встраивание символов и сопоставляет каждый символ с вектором с 200 размерностями.

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения. Задайте, чтобы обучаться с мини-пакетным размером 32, и начальная буква изучают уровень 0.01. Чтобы препятствовать тому, чтобы градиенты взорвались, установите порог градиента к 1. Гарантировать данные остается отсортированным, набор 'Shuffle' к 'never'. Чтобы контролировать процесс обучения, установите 'Plots' опция к 'training-progress'. Чтобы подавить многословный выход, установите 'Verbose' к false.

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть.

net = trainNetwork(XTrain,YTrain,layers,options);

Сгенерируйте новый текст

Сгенерируйте первый символ текста путем выборки символа от вероятностного распределения согласно первым символам текста в обучающих данных. Сгенерируйте оставшиеся символы при помощи обученной сети LSTM, чтобы предсказать следующую последовательность с помощью текущей последовательности сгенерированного текста. Продолжите генерировать символы один за другим, пока сеть не предскажет "конец текста" символ.

Произведите первый символ согласно распределению первых символов в обучающих данных.

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

Преобразуйте первый символ в числовой индекс.

X = double(char(firstCharacter));

Для остающихся предсказаний произведите следующий символ согласно множеству предсказания сети. Баллы предсказания представляют вероятностное распределение следующего символа. Произведите символы из словаря символов, данных именами классов выходного слоя сети. Получите словарь от слоя классификации сети.

vocabulary = string(net.Layers(end).ClassNames);

Сделайте символ предсказаний символом с помощью predictAndUpdateState. Для каждого предсказания, вход индекс предыдущего символа. Прекратите предсказывать, когда сеть предсказывает конец текстового символа или когда сгенерированный текст является 500 символами долго. Для большого количества данных, длинных последовательностей или больших сетей, предсказания на графическом процессоре обычно быстрее, чтобы вычислить, чем предсказания на центральном процессоре. В противном случае предсказания на центральном процессоре обычно быстрее, чтобы вычислить. Для одного предсказаний временного шага используйте центральный процессор. Чтобы использовать центральный процессор для предсказания, установите 'ExecutionEnvironment' опция predictAndUpdateState к 'cpu'.

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

Восстановите сгенерированный текст, заменив специальные символы на их соответствующий пробел и символы новой строки.

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

Чтобы сгенерировать несколько частей текста, сбросьте сетевое состояние между поколениями с помощью resetState.

net = resetState(net);

Смотрите также

| | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте