Гордитесь и нанесите ущерб и MATLAB

Этот пример показывает, как обучить сеть LSTM глубокого обучения, чтобы сгенерировать текст с помощью символьных вложений.

Чтобы обучить нейронную сеть для глубокого обучения текстовой генерации, обучите сеть LSTM от последовательности к последовательности, чтобы предсказать следующий символ в последовательности символов. Чтобы обучить сеть, чтобы предсказать следующий символ, задайте ответы, чтобы быть входными последовательностями, переключенными одним временным шагом.

Чтобы использовать символьные вложения, преобразуйте каждое учебное наблюдение в последовательность целых чисел, где целые числа индексируют в словарь символов. Включайте слой встраивания слова в сеть, которая изучает встраивание символов и сопоставляет целые числа с векторами.

Загрузите данные тренировки

Считайте код HTML из Проекта Электронная книга Гутенберга Гордости и Предубеждения, Джейн Остин и проанализируйте его с помощью webread и htmlTree.

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

Извлеките абзацы путем нахождения элементов p. Задайте, чтобы проигнорировать элементы абзаца с классом "toc" с помощью селектора CSS ':not(.toc)'.

paragraphs = findElement(tree,'p:not(.toc)');

Извлеките текстовые данные из абзацев с помощью extractHTMLText. и удалите пустые строки.

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

Удалите строки короче, чем 20 символов.

idx = strlength(textData) < 20;
textData(idx) = [];

Визуализируйте текстовые данные, одним словом, облако.

figure
wordcloud(textData);
title("Pride and Prejudice")

Преобразуйте текстовые данные в последовательности

Преобразуйте текстовые данные в последовательности индексов символа для предикторов и категориальные последовательности для ответов.

Категориальная функция обрабатывает новую строку и пробельные записи как неопределенные. Чтобы создать категориальные элементы для этих символов, замените их на специальные символы "" (pilcrow, "\x00B6") и "·" (средняя точка, "\x00B7") соответственно. Чтобы предотвратить неоднозначность, необходимо выбрать специальные символы, которые не появляются в тексте. Эти символы не появляются в данных тренировки, так может использоваться с этой целью.

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

Цикл по текстовым данным и создает последовательность индексов символа, представляющих символы каждого наблюдения и категориальной последовательности символов для ответов. Чтобы обозначить конец каждого наблюдения, включайте специальный символ "␃" (конец текста, "\x2403").

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

Во время обучения, по умолчанию, программное обеспечение разделяет данные тренировки в мини-пакеты и заполняет последовательности так, чтобы у них была та же длина. Слишком много дополнения может оказать негативное влияние на производительность сети.

Чтобы препятствовать тому, чтобы учебный процесс добавил слишком много дополнения, можно отсортировать данные тренировки по длине последовательности и выбрать мини-пакетный размер так, чтобы последовательности в мини-пакете имели подобную длину.

Получите длины последовательности для каждого наблюдения.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Сортировка данных длиной последовательности.

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Создайте и обучите сеть LSTM

Задайте архитектуру LSTM. Задайте от последовательности к последовательности сеть классификации LSTM с 400 скрытыми модулями. Установите входной размер быть размерностью признаков данных тренировки. Для последовательностей индексов символа размерность признаков равняется 1. Задайте слой встраивания слова с размерностью 200 и задайте количество слов (которые соответствуют символам) быть самым высоким символьным значением во входных данных. Установите выходной размер полносвязного слоя быть количеством категорий в ответах. Чтобы помочь предотвратить сверхподбор кривой, включайте слой уволенного после слоя LSTM.

Слой встраивания слова изучает встраивание символов и сопоставляет каждый символ с вектором с 200 размерностями.

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения. Задайте, чтобы обучаться с мини-пакетным размером 32, и начальная буква изучают уровень 0.01. Чтобы препятствовать тому, чтобы градиенты взорвались, установите порог градиента к 1. Гарантировать данные остается отсортированным, набор 'Shuffle' к 'never'. Чтобы контролировать учебный прогресс, установите опцию 'Plots' на 'training-progress'. Чтобы подавить многословный вывод, установите 'Verbose' на false.

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть.

net = trainNetwork(XTrain,YTrain,layers,options);

Сгенерируйте новый текст

Сгенерируйте первый символ текста путем выборки символа от распределения вероятностей согласно первым символам текста в данных тренировки. Сгенерируйте оставшиеся символы при помощи обученной сети LSTM, чтобы предсказать следующую последовательность с помощью текущей последовательности сгенерированного текста. Продолжите генерировать символы один за другим, пока сеть не предскажет "конец текста" символ.

Выберите первый символ согласно распределению первых символов в данных тренировки.

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

Преобразуйте первый символ в числовой индекс.

X = double(char(firstCharacter));

Для остающихся прогнозов выберите следующий символ согласно множеству прогноза сети. Очки прогноза представляют распределение вероятностей следующего символа. Выберите символы из словаря символов, данных именами классов выходного слоя сети. Получите словарь от слоя классификации сети.

vocabulary = string(net.Layers(end).ClassNames);

Сделайте символ прогнозов символом с помощью predictAndUpdateState. Для каждого прогноза, вход индекс предыдущего символа. Прекратите предсказывать, когда сеть предсказывает конец текстового символа или когда сгенерированный текст является 500 символами долго. Для большого количества данных, длинных последовательностей или больших сетей, прогнозы на графическом процессоре обычно быстрее, чтобы вычислить, чем прогнозы на центральном процессоре. В противном случае прогнозы на центральном процессоре обычно быстрее, чтобы вычислить. Для одного прогнозов временного шага используйте центральный процессор. Чтобы использовать центральный процессор для прогноза, установите опцию 'ExecutionEnvironment' predictAndUpdateState к 'cpu'.

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

Восстановите сгенерированный текст, заменив специальные символы на их соответствующий пробел и символы новой строки.

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

Чтобы сгенерировать несколько частей текста, сбросьте сетевое состояние между поколениями с помощью resetState.

net = resetState(net);

Смотрите также

| | | | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте