exponenta event banner

Классификация текстовых данных с помощью глубокого обучения

В этом примере показано, как классифицировать текстовые данные с помощью сети LSTM.

Текстовые данные, естественно, являются последовательными. Фрагмент текста - это последовательность слов, которая может иметь зависимости между ними. Чтобы узнать и использовать долгосрочные зависимости для классификации данных последовательности, используйте нейронную сеть LSTM. Сеть LSTM - это тип рекуррентной нейронной сети (RNN), которая может изучать долгосрочные зависимости между временными шагами данных последовательности.

Для ввода текста в сеть LSTM сначала преобразуйте текстовые данные в числовые последовательности. Для этого можно использовать кодировку слов, которая сопоставляет документы с последовательностями числовых индексов. Для получения лучших результатов также включите в сеть слой внедрения слов. Слово встраивает слова в словарь в числовые векторы, а не в скалярные индексы. Эти вложения захватывают семантические детали слов, так что слова со сходными значениями имеют схожие векторы. Они также моделируют отношения между словами через векторную арифметику. Например, отношение «Рим - к Италии, как Париж - к Франции» описывается уравнением Италия - Рим + Париж = Франция.

В этом примере существует четыре этапа обучения и использования сети LSTM:

  • Импорт и предварительная обработка данных.

  • Преобразование слов в числовые последовательности с использованием кодирования слов.

  • Создание и обучение сети LSTM с уровнем внедрения слов.

  • Классифицируйте новые текстовые данные с использованием обученной сети LSTM.

Импорт данных

Импортируйте данные производственных отчетов. Эти данные содержат помеченные текстовые описания заводских событий. Чтобы импортировать текстовые данные в виде строк, укажите тип текста 'string'.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Целью этого примера является классификация событий по метке в Category столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальные.

data.Category = categorical(data.Category);

Просмотрите распределение классов в данных с помощью гистограммы.

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

Следующим шагом является разделение его на наборы для обучения и проверки. Разбиение данных на раздел обучения и раздел удержания для проверки и тестирования. Укажите процент удержания равным 20%.

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

Извлеките текстовые данные и метки из секционированных таблиц.

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

Чтобы убедиться, что данные импортированы правильно, визуализируйте текстовые данные обучения с помощью облака слов.

figure
wordcloud(textDataTrain);
title("Training Data")

Предварительная обработка текстовых данных

Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функция preprocessText, перечисленное в конце примера, выполняет следующие шаги:

  1. Маркировка текста с помощью tokenizedDocument.

  2. Преобразование текста в нижний регистр с помощью lower.

  3. Стереть пунктуацию с помощью erasePunctuation.

Предварительная обработка данных обучения и данных проверки с помощью preprocessText функция.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Просмотрите первые несколько предварительно обработанных учебных документов.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses

Преобразовать документ в последовательности

Для ввода документов в сеть LSTM используйте кодировку слов для преобразования документов в последовательности числовых индексов.

Чтобы создать кодировку слова, используйте wordEncoding функция.

enc = wordEncoding(documentsTrain);

Следующий этап преобразования состоит в заполнении и усечении документов таким образом, что они имеют одинаковую длину. trainingOptions функция предоставляет опции для автоматической вставки и усечения входных последовательностей. Однако эти варианты плохо подходят для последовательностей векторов слов. Вместо этого установите и сократите последовательности вручную. Если вы оставляете и усекаете последовательности векторов слов, то тренировка может улучшиться.

Чтобы добавить и обрезать документы, сначала выберите целевую длину, а затем усечение документов, которые длиннее, чем она, и документов, которые короче, чем она. Для достижения наилучших результатов целевая длина должна быть короткой без отбрасывания больших объемов данных. Чтобы найти подходящую целевую длину, просмотрите гистограмму длины учебного документа.

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

Большинство учебных документов имеют менее 10 маркеров. Используйте его в качестве целевой длины для усечения и заполнения.

Преобразование документов в последовательности числовых индексов с помощью doc2sequence. Чтобы усечь или оставить последовательности длиной 10, установите 'Length' опция до 10.

sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}

Преобразование документов проверки в последовательности с использованием тех же параметров.

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

Создание и обучение сети LSTM

Определите архитектуру сети LSTM. Чтобы ввести данные последовательности в сеть, включите входной уровень последовательности и установите размер ввода равным 1. Затем включают в себя слой внедрения слова размерности 50 и то же количество слов, что и кодирование слова. Затем включите уровень LSTM и установите количество скрытых блоков равным 80. Чтобы использовать уровень LSTM для проблемы классификации «последовательность-метка», установите режим вывода в значение 'last'. Наконец, добавьте полностью подключенный слой того же размера, что и количество классов, слой softmax и классификационный слой.

inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;

numWords = enc.NumWords;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                    LSTM with 80 hidden units
     4   ''   Fully Connected         4 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

Укажите параметры обучения

Укажите параметры обучения:

  • Выполните обучение с помощью решателя Adam.

  • Укажите размер мини-пакета 16.

  • Тасуйте данные каждую эпоху.

  • Контролировать ход обучения, установив 'Plots' опция для 'training-progress'.

  • Укажите данные проверки с помощью 'ValidationData' вариант.

  • Подавление подробных выходных данных путем установки параметра 'Verbose' опция для false.

По умолчанию trainNetwork использует графический процессор, если он доступен. В противном случае используется ЦП. Чтобы указать среду выполнения вручную, используйте 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions. Обучение на CPU может занять значительно больше времени, чем обучение на GPU. Для обучения с помощью графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучение сети LSTM с помощью trainNetwork функция.

net = trainNetwork(XTrain,YTrain,layers,options);

Прогнозирование с использованием новых данных

Классифицируйте тип события для трех новых отчетов. Создайте строковый массив, содержащий новые отчеты.

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Предварительная обработка текстовых данных с использованием шагов предварительной обработки в качестве учебных документов.

documentsNew = preprocessText(reportsNew);

Преобразование текстовых данных в последовательности с помощью doc2sequence с теми же опциями, что и при создании обучающих последовательностей.

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

Классифицируйте новые последовательности с помощью обученной сети LSTM.

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Функция предварительной обработки

Функция preprocessText выполняет следующие шаги:

  1. Маркировка текста с помощью tokenizedDocument.

  2. Преобразование текста в нижний регистр с помощью lower.

  3. Стереть пунктуацию с помощью erasePunctuation.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

См. также

| | | | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста) | (Панель инструментов для анализа текста)

Связанные темы