exponenta event banner

Гордость и предубеждение и MATLAB

В этом примере показано, как обучить сеть LSTM глубокому обучению создавать текст с использованием вставок символов.

Чтобы обучить сеть глубокого обучения для генерации текста, обучайте сеть LSTM последовательности для прогнозирования следующего символа в последовательности символов. Чтобы обучить сеть предсказывать следующий символ, укажите ответы, которые должны быть входными последовательностями, сдвинутыми на один шаг времени.

Чтобы использовать вложения символов, преобразуйте каждое обучающее наблюдение в последовательность целых чисел, где целые числа индексируются в словарь символов. Включите в сеть слой встраивания слов, который изучает встраивание символов и отображает целые числа в векторы.

Загрузка данных обучения

Прочитайте HTML код из проекта Gutenberg EBook гордости и предубеждения, Джейн Остин и разбор его с помощью webread и htmlTree.

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

Извлеките абзацы, найдя p элементы. Укажите, чтобы игнорировать элементы абзаца с классом "toc" с помощью селектора CSS ':not(.toc)'.

paragraphs = findElement(tree,'p:not(.toc)');

Извлечение текстовых данных из абзацев с помощью extractHTMLText. и удалите пустые строки.

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

Удаление строк длиной менее 20 символов.

idx = strlength(textData) < 20;
textData(idx) = [];

Визуализация текстовых данных в облаке слов.

figure
wordcloud(textData);
title("Pride and Prejudice")

Преобразование текстовых данных в последовательности

Преобразование текстовых данных в последовательности символьных индексов для предикторов и категориальных последовательностей для ответов.

Категориальная функция обрабатывает записи новой строки и пробела как неопределенные. Чтобы создать категориальные элементы для этих символов, замените их специальными символами ""(пилкроу, "\x00B6") и «·» (средняя точка, "\x00B7") соответственно. Для предотвращения неоднозначности необходимо выбрать специальные символы, которые не отображаются в тексте. Эти символы не отображаются в данных обучения, поэтому их можно использовать для этой цели.

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

Закольцевать текстовые данные и создать последовательность индексов символов, представляющих символы каждого наблюдения и категориальную последовательность символов для ответов. Для обозначения конца каждого наблюдения включить специальный символ «␃» (конец текста, "\x2403").

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

Во время обучения программное обеспечение по умолчанию разбивает обучающие данные на мини-пакеты и размещает последовательности таким образом, чтобы они имели одинаковую длину. Слишком большое количество дополнений может оказать негативное влияние на производительность сети.

Чтобы предотвратить добавление слишком большого количества заполнений в процесс обучения, можно отсортировать данные обучения по длине последовательности и выбрать размер мини-партии таким образом, чтобы последовательности в мини-партии имели аналогичную длину.

Получение длин последовательностей для каждого наблюдения.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Сортировка данных по длине последовательности.

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Создание и обучение сети LSTM

Определите архитектуру LSTM. Укажите классификационную сеть последовательности к последовательности LSTM с 400 скрытыми единицами. Задайте размер ввода в качестве размера элемента данных обучения. Для последовательностей индексов символов размер элемента равен 1. Укажите слой встраивания слов с размерностью 200 и укажите количество слов (которые соответствуют символам), которое должно быть самым высоким значением символа во входных данных. Задайте размер вывода полностью подключенного слоя как количество категорий в ответах. Чтобы предотвратить переоборудование, включите уровень отсева после уровня LSTM.

Слой встраивания слов распознает встраивание символов и отображает каждый символ в 200-мерный вектор.

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения. Укажите обучение с размером мини-партии 32 и начальной скоростью обучения 0,01. Для предотвращения разузлования градиентов установите порог градиента равным 1. Для обеспечения сортировки данных установите 'Shuffle' кому 'never'. Для контроля за ходом обучения установите 'Plots' опция для 'training-progress'. Для подавления подробных выходных данных установите 'Verbose' кому false.

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучение сети.

net = trainNetwork(XTrain,YTrain,layers,options);

Создать новый текст

Создайте первый символ текста путем выборки символа из распределения вероятностей в соответствии с первыми символами текста в данных обучения. Создайте оставшиеся символы с помощью обученной сети LSTM для прогнозирования следующей последовательности с использованием текущей последовательности сгенерированного текста. Продолжайте генерировать символы один за другим, пока сеть не прогнозирует символ «конец текста».

Образец первого символа в соответствии с распределением первых символов в данных обучения.

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

Преобразование первого символа в числовой индекс.

X = double(char(firstCharacter));

Для остальных предсказаний выполните выборку следующего символа в соответствии с оценками предсказания сети. Оценки предсказания представляют распределение вероятности следующего символа. Выборка символов из словаря символов, заданных именами классов выходного уровня сети. Получите словарный запас из классификационного уровня сети.

vocabulary = string(net.Layers(end).ClassNames);

Выполнение предсказаний по символам с помощью predictAndUpdateState. Для каждого предсказания введите индекс предыдущего символа. Прекратите прогнозирование, когда сеть прогнозирует конец текстового символа или когда созданный текст имеет длину 500 символов. Для больших коллекций данных, длинных последовательностей или больших сетей прогнозы на GPU обычно вычисляются быстрее, чем прогнозы на CPU. В противном случае предсказания на CPU обычно вычисляются быстрее. Для прогнозирования одного шага времени используйте CPU. Чтобы использовать CPU для прогнозирования, установите 'ExecutionEnvironment' вариант predictAndUpdateState кому 'cpu'.

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

Реконструируйте созданный текст путем замены специальных символов соответствующими пробелами и новыми символами строки.

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

Чтобы создать несколько фрагментов текста, сбросьте состояние сети между поколениями с помощью resetState.

net = resetState(net);

См. также

| | | | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста)

Связанные темы