В этом примере показано, как классифицировать текстовые данные с помощью сети глубокого обучения с длительной краткосрочной памятью (LSTM).
Текстовые данные естественно последовательны. Часть текста является последовательностью слов, которая может иметь зависимости между ними. Чтобы узнать и использовать долгосрочные зависимости для классификации данных последовательности, используйте нейронную сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), который может изучать долгосрочные зависимости между временными шагами данных последовательности.
Чтобы ввести текст в сеть LSTM, сначала преобразуйте текстовые данные в числовые последовательности. Добиться этого можно с помощью кодировки слов, которая преобразует документы в последовательности числовых индексов. Для лучших результатов также включите слой встраивания слов в сеть. Вложения в Word сопоставляют слова в словаре с числовыми векторами, а не скалярными индексами. Эти вложения захватывают семантические детали слов, так что слова с подобными значениями имеют сходные векторы. Они также моделируют отношения между словами через векторную арифметику. Например, отношение «Рим - Италия, а Париж - Франция» описывается уравнением Италия - Рим + Париж = Франция.
В этом примере существует четыре шага обучения и использования сети LSTM:
Импортируйте и предварительно обработайте данные.
Преобразуйте слова в числовые последовательности с помощью кодировки слов.
Создайте и обучите сеть LSTM с слоем встраивания слов.
Классификация новых текстовых данных с помощью обученной сети LSTM.
Импортируйте данные заводских отчетов. Эти данные содержат маркированные текстовые описания заводских событий. Чтобы импортировать текстовые данные как строки, задайте тип текста, который будет 'string'
.
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
Цель этого примера состоит в том, чтобы классифицировать события по меткам в Category
столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальные.
data.Category = categorical(data.Category);
Просмотр распределения классов в данных с помощью гистограммы.
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
Следующий шаг - разбить его на наборы для обучения и валидации. Разделите данные на обучающий раздел и удерживаемый раздел для валидации и проверки. Задайте процент удержания 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
Извлеките текстовые данные и метки из секционированных таблиц.
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
Чтобы проверить, правильно ли вы импортировали данные, визуализируйте обучающие текстовые данные с помощью облака слов.
figure
wordcloud(textDataTrain);
title("Training Data")
Создайте функцию, которая токенизирует и предварительно обрабатывает текстовые данные. Функция preprocessText
, перечисленный в конце примера, выполняет следующие шаги:
Токенизация текста с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Удалите пунктуацию с помощью erasePunctuation
.
Предварительно обработайте обучающие данные и данные валидации с помощью preprocessText
функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных обучающих документов.
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 10 tokens: there are cuts to the power when starting the plant 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses
Для ввода документов в сеть LSTM используйте кодировку слов, чтобы преобразовать документы в последовательности числовых индексов.
Чтобы создать кодировку слов, используйте wordEncoding
функция.
enc = wordEncoding(documentsTrain);
Следующий шаг преобразования состоит в том, чтобы заполнить и обрезать документы, так что они имеют одинаковую длину. The trainingOptions
функция предоставляет опции для автоматического заполнения и усечения входных последовательностей. Однако эти опции плохо подходят для последовательностей векторов слов. Вместо этого дополните и обрезайте последовательности вручную. Если вы оставляете и обрезаете последовательности векторов слов, то обучение может улучшиться.
Чтобы дополнить и обрезать документы, сначала выберите целевую длину, а затем обрезайте документы, которые длиннее его, и документы на левой панели, которые короче его. Для наилучших результатов целевая длина должна быть короткой, не отбрасывая больших объемов данных. Чтобы найти подходящую целевую длину, просмотрите гистограмму длин обучающего документа.
documentLengths = doclength(documentsTrain); figure histogram(documentLengths) title("Document Lengths") xlabel("Length") ylabel("Number of Documents")
Большинство обучающих документов имеют менее 10 лексемы. Используйте это в качестве целевой длины для усечения и заполнения.
Преобразуйте документы в последовательности числовых индексов с помощью doc2sequence
. Чтобы обрезать или дополнить слева последовательности длиной 10, установите 'Length'
опция до 10.
sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
Преобразуйте документы валидации в последовательности с помощью тех же опций.
XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);
Определите сетевую архитектуру LSTM. Чтобы ввести данные последовательности в сеть, включите входной слой последовательности и установите размер входа равным 1. Затем включают слой встраивания слов размерности 50 и такое же количество слов, как и в кодировании слов. Затем включите слой LSTM и установите количество скрытых модулей равным 80. Чтобы использовать слой LSTM для задачи классификации между последовательностями и метками, установите режим выхода равным 'last'
. Наконец, добавьте полносвязный слой с таким же размером, как и количество классов, слой softmax и слой классификации.
inputSize = 1; embeddingDimension = 50; numHiddenUnits = 80; numWords = enc.NumWords; numClasses = numel(categories(YTrain)); layers = [ ... sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 6x1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 50 dimensions and 423 unique words 3 '' LSTM LSTM with 80 hidden units 4 '' Fully Connected 4 fully connected layer 5 '' Softmax softmax 6 '' Classification Output crossentropyex
Задайте опции обучения:
Обучите с помощью решателя Адама.
Задайте мини-пакет размером 16.
Перетасовывайте данные каждую эпоху.
Отслеживайте процесс обучения путем установки 'Plots'
опция для 'training-progress'
.
Укажите данные валидации с помощью 'ValidationData'
опция.
Подавить подробный выход путем установки 'Verbose'
опция для false
.
По умолчанию trainNetwork
использует графический процессор, если он доступен. В противном случае используется центральный процессор. Чтобы задать окружение выполнения вручную, используйте 'ExecutionEnvironment'
Аргумент пары "имя-значение" из trainingOptions
. Обучение на центральном процессоре может занять значительно больше времени, чем обучение на графическом процессоре. Для обучения с графическим процессором требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
options = trainingOptions('adam', ... 'MiniBatchSize',16, ... 'GradientThreshold',2, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть LSTM с помощью trainNetwork
функция.
net = trainNetwork(XTrain,YTrain,layers,options);
Классифицируйте тип события трех новых отчетов. Создайте строковые массивы, содержащий новые отчеты.
reportsNew = [ ... "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
Предварительно обработайте текстовые данные, используя шаги предварительной обработки в качестве обучающих документов.
documentsNew = preprocessText(reportsNew);
Преобразуйте текстовые данные в последовательности с помощью doc2sequence
с теми же опциями, что и при создании обучающих последовательностей.
XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);
Классификация новых последовательностей с помощью обученной сети LSTM.
labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
Функция preprocessText
выполняет следующие шаги:
Токенизация текста с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Удалите пунктуацию с помощью erasePunctuation
.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
lstmLayer
| sequenceInputLayer
| trainingOptions
| trainNetwork
| doc2sequence
(Symbolic Math Toolbox) | fastTextWordEmbedding
(Symbolic Math Toolbox) | tokenizedDocument
(Symbolic Math Toolbox) | wordcloud
(Symbolic Math Toolbox) | wordEmbeddingLayer
(Symbolic Math Toolbox)