Гордость и предубеждения и MATLAB

В этом примере показано, как обучить сеть LSTM глубокого обучения для генерации текста с помощью вложений символов.

Чтобы обучить нейронную сеть для глубокого обучения для генерации текста, обучите сеть LSTM с последовательностью в последовательности для предсказания следующего символа в последовательности символов. Чтобы обучить сеть предсказывать следующий символ, задайте отклики, которые будут входными последовательностями, сдвинутыми на один временной шаг.

Чтобы использовать вложения символов, преобразуйте каждое обучающее наблюдение в последовательность целых чисел, где целые числа индексируются в словарь символов. Включите слой встраивания слов в сеть, который изучает встраивание символов и сопоставляет целые числа с векторами.

Загрузка обучающих данных

Прочитайте HTML кода из The Project Gutenberg EBook of Pride and Prejudice, Джейн Остин и проанализируйте его с помощью webread и htmlTree.

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

Извлеките абзацы, найдя p элементы. Задайте, чтобы игнорировать элементы абзаца с "toc" классов использование селектора CSS ':not(.toc)'.

paragraphs = findElement(tree,'p:not(.toc)');

Извлеките текстовые данные из абзацев с помощью extractHTMLText. и удалите пустые строки.

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

Удалите строки короче 20 символов.

idx = strlength(textData) < 20;
textData(idx) = [];

Визуализируйте текстовые данные в облаке слов.

figure
wordcloud(textData);
title("Pride and Prejudice")

Преобразование текстовых данных в последовательности

Преобразуйте текстовые данные в последовательности индексов символа для предикторов и категориальные последовательности для ответов.

Категориальная функция обрабатывает записи newline и whitespace как неопределенные. Чтобы создать категориальные элементы для этих символов, замените их специальными символами ""(pilcrow, "\x00B6") и «·» (средняя точка, "\x00B7") соответственно. Чтобы предотвратить неоднозначность, необходимо выбрать специальные символы, которые не появляются в тексте. Эти символы не появляются в обучающих данных, поэтому их можно использовать для этой цели.

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

Закольцовывайте текстовые данные и создайте последовательность индексов символов, представляющих символы каждого наблюдения и категориальную последовательность символов для ответов. Чтобы обозначить конец каждого наблюдения, включите специальный символ «␃» (конец текста, "\x2403").

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

Во время обучения по умолчанию программное обеспечение разделяет обучающие данные на мини-пакеты и заполняет последовательности так, чтобы они имели одинаковую длину. Слишком большое количество заполнения может негативно влияние на эффективность сети.

Чтобы предотвратить добавление слишком большого количества заполнений в процесс обучения, можно отсортировать обучающие данные по длине последовательности и выбрать размер мини-пакета, чтобы последовательности в мини-пакете имели одинаковую длину.

Получите длины последовательности для каждого наблюдения.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Отсортируйте данные по длине последовательности.

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Создание и обучение сети LSTM

Определите архитектуру LSTM. Задайте сеть классификации LSTM в последовательности с 400 скрытыми модулями. Установите размер входного сигнала в качестве размерности признаков обучающих данных. Для последовательностей индексов символов размерность признаков равна 1. Задайте слой встраивания слов с размерностью 200 и укажите количество слов (которые соответствуют символам), которое будет самым высоким значением символов во входных данных. Установите размер выхода полносвязного слоя равным количеству категорий в ответах. Чтобы предотвратить сверхподбор кривой, включите слой отсева после уровня LSTM.

Слой встраивания слов учит встраивание символов и сопоставляет каждый символ с 200-размерным вектором.

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения. Укажите, чтобы обучаться с мини-пакетом размером 32 и начальной скоростью обучения 0.01. Чтобы предотвратить разнесение градиентов, установите порог градиента равным 1. Чтобы убедиться, что данные остаются отсортированными, установите 'Shuffle' на 'never'. Чтобы контролировать процесс обучения, установите 'Plots' опция для 'training-progress'. Чтобы подавить подробный выход, установите 'Verbose' на false.

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть.

net = trainNetwork(XTrain,YTrain,layers,options);

Сгенерируйте новый текст

Сгенерируйте первый символ текста путем выборки символа из распределения вероятностей согласно первым символам текста в обучающих данных. Сгенерируйте оставшиеся символы при помощи обученной сети LSTM, чтобы предсказать следующую последовательность с помощью текущей последовательности сгенерированного текста. Продолжайте генерировать символы один за одним, пока сеть не предсказает символ «конец текста».

Выборка первого символа в соответствии с распределением первых символов в обучающих данных.

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

Преобразуйте первый символ в числовой индекс.

X = double(char(firstCharacter));

Для остальных предсказаний выберете следующий символ в соответствии с счетами прогноза сети. Счета предсказания представляют распределение вероятностей следующего символа. Выборка символов из словаря символов, заданных именами классов выходного слоя сети. Получите словарь из классификационного слоя сети.

vocabulary = string(net.Layers(end).ClassNames);

Делайте предсказания по символам, используя predictAndUpdateState. Для каждого предсказания введите индекс предыдущего символа. Остановите предсказание, когда сеть предсказывает конец текстового символа или когда сгенерированный текст имеет длину 500 символов. Для больших наборов данных, длинных последовательностей или больших сетей предсказания на графическом процессоре обычно вычисляются быстрее, чем предсказания на центральном процессоре. В противном случае предсказания на центральном процессоре обычно вычисляются быстрее. Для одного временного шага предсказаний используйте центральный процессор. Чтобы использовать центральный процессор для предсказания, установите 'ExecutionEnvironment' опция predictAndUpdateState на 'cpu'.

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

Восстановите сгенерированный текст путем замены специальных символов соответствующим пробелом и новыми символами линии.

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

Чтобы сгенерировать несколько фрагментов текста, сбросьте состояние сети между поколениями, используя resetState.

net = resetState(net);

См. также

| | | | | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы