В этом примере показано, как классифицировать текстовые данные с помощью сети долгой краткосрочной памяти (LSTM) глубокого обучения.
Текстовые данные естественно последовательны. Часть текста является последовательностью слов, которые могут иметь зависимости между ними. Чтобы изучить и использовать долгосрочные зависимости, чтобы классифицировать данные о последовательности, используйте нейронную сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), которая может изучить долгосрочные зависимости между временными шагами данных о последовательности.
Чтобы ввести текст к сети LSTM, сначала преобразуйте текстовые данные в числовые последовательности. Можно достигнуть этого использования кодирования слова, которое сопоставляет документы последовательностям числовых индексов. Для лучших результатов также включайте слой встраивания слова в сеть. Вложения Word сопоставляют слова в словаре к числовым векторам, а не скалярным индексам. Эти вложения получают семантические детали слов, так, чтобы слова с подобными значениями имели подобные векторы. Они также отношения модели между словами через векторную арифметику. Например, отношение "Рим в Италию, как Париж во Францию", описан уравнением Италия – Рим + Париж = Франция.
Существует четыре шага в обучении и использовании сети LSTM в этом примере:
Импортируйте и предварительно обработайте данные.
Преобразуйте слова в числовые последовательности с помощью кодирования слова.
Создайте и обучите сеть LSTM со слоем встраивания слова.
Классифицируйте новые текстовые данные с помощью обученной сети LSTM.
Импортируйте данные об отчетах фабрики. Эти данные содержат помеченные текстовые описания событий фабрики. Чтобы импортировать текстовые данные как строки, задайте тип текста, чтобы быть 'string'
.
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
Цель этого примера состоит в том, чтобы классифицировать события меткой в Category
столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальный.
data.Category = categorical(data.Category);
Просмотрите распределение классов в данных с помощью гистограммы.
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
Следующий шаг должен разделить его в наборы для обучения и валидации. Разделите данные в учебный раздел и протянутый раздел для валидации и тестирования. Задайте процент затяжки, чтобы быть 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
Извлеките текстовые данные и метки из разделенных таблиц.
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
Чтобы проверять, что вы импортировали данные правильно, визуализируйте учебные текстовые данные с помощью облака слова.
figure
wordcloud(textDataTrain);
title("Training Data")
Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функциональный preprocessText
, перечисленный в конце примера, выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
Предварительно обработайте обучающие данные и данные о валидации с помощью preprocessText
функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных учебных материалов.
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 10 tokens: there are cuts to the power when starting the plant 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses
Чтобы ввести документы в сеть LSTM, используйте кодирование слова, чтобы преобразовать документы в последовательности числовых индексов.
Чтобы создать кодирование слова, используйте wordEncoding
функция.
enc = wordEncoding(documentsTrain);
Следующий шаг преобразования должен заполнить и обрезать документы, таким образом, они являются всеми одинаковыми длина. trainingOptions
функция предоставляет возможности заполнять и обрезать входные последовательности автоматически. Однако эти опции не хорошо подходят для последовательностей векторов слова. Вместо этого клавиатура и усеченный последовательности вручную. Если вы лево-заполняете и обрезаете последовательности векторов слова, то учебная сила улучшается.
Чтобы заполнить и обрезать документы, сначала выберите целевую длину, и затем обрежьте документы, которые более длинны, чем она и лево-заполняют документы, которые короче, чем она. Для лучших результатов целевая длина должна быть короткой, не отбрасывая большие объемы данных. Чтобы найти подходящую целевую длину, просмотрите гистограмму длин учебного материала.
documentLengths = doclength(documentsTrain); figure histogram(documentLengths) title("Document Lengths") xlabel("Length") ylabel("Number of Documents")
Большинство учебных материалов имеет меньше чем 10 лексем. Используйте это в качестве своей целевой длины для усечения и дополнения.
Преобразуйте документы последовательностям числовых индексов с помощью doc2sequence
. Чтобы обрезать или лево-заполнить последовательности, чтобы иметь длину 10, установите 'Length'
опция к 10.
sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
Преобразуйте документы валидации последовательностям с помощью тех же опций.
XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);
Задайте архитектуру сети LSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на 1. Затем включайте слой встраивания слова размерности 50 и то же количество слов как кодирование слова. Затем включайте слой LSTM и определите номер скрытых модулей к 80. Чтобы использовать слой LSTM для проблемы классификации последовательностей к метке, установите режим вывода на 'last'
. Наконец, добавьте полносвязный слой с тем же размером как количество классов, softmax слоя и слоя классификации.
inputSize = 1; embeddingDimension = 50; numHiddenUnits = 80; numWords = enc.NumWords; numClasses = numel(categories(YTrain)); layers = [ ... sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 6x1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 50 dimensions and 423 unique words 3 '' LSTM LSTM with 80 hidden units 4 '' Fully Connected 4 fully connected layer 5 '' Softmax softmax 6 '' Classification Output crossentropyex
Задайте опции обучения:
Обучите использование решателя Адама.
Задайте мини-пакетный размер 16.
Переставьте данные каждая эпоха.
Контролируйте процесс обучения путем установки 'Plots'
опция к 'training-progress'
.
Задайте данные о валидации с помощью 'ValidationData'
опция.
Подавите многословный выход путем установки 'Verbose'
опция к false
.
По умолчанию, trainNetwork
использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Чтобы задать среду выполнения вручную, используйте 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
. Обучение на центральном процессоре может взять значительно дольше, чем обучение на графическом процессоре.
options = trainingOptions('adam', ... 'MiniBatchSize',16, ... 'GradientThreshold',2, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть LSTM с помощью trainNetwork
функция.
net = trainNetwork(XTrain,YTrain,layers,options);
Классифицируйте тип события трех новых отчетов. Создайте массив строк, содержащий новые отчеты.
reportsNew = [ ... "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
Предварительно обработайте текстовые данные с помощью шагов предварительной обработки в качестве учебных материалов.
documentsNew = preprocessText(reportsNew);
Преобразуйте текстовые данные в последовательности с помощью doc2sequence
с теми же опциями, создавая обучающие последовательности.
XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);
Классифицируйте новые последовательности с помощью обученной сети LSTM.
labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
Функциональный preprocessText
выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
lstmLayer
| sequenceInputLayer
| trainingOptions
| trainNetwork
| doc2sequence
(Text Analytics Toolbox) | fastTextWordEmbedding
(Text Analytics Toolbox) | tokenizedDocument
(Text Analytics Toolbox) | wordcloud
(Text Analytics Toolbox) | wordEmbeddingLayer
(Text Analytics Toolbox)