exponenta event banner

Классификация текстовых данных из памяти с помощью настраиваемого мини-хранилища пакетных данных

В этом примере показано, как классифицировать текстовые данные из памяти с помощью сети глубокого обучения с помощью настраиваемого мини-хранилища пакетных данных.

Мини-хранилище пакетных данных представляет собой реализацию хранилища данных с поддержкой считывания данных в пакетах. Хранилище данных мини-пакета можно использовать в качестве источника наборов данных обучения, проверки, тестирования и прогнозирования для приложений глубокого обучения. Использование хранилищ данных мини-пакетов для считывания данных из памяти или выполнения определенных операций предварительной обработки при считывании пакетов данных.

При обучении сети программное обеспечение создает мини-пакеты последовательностей одинаковой длины путем заполнения, усечения или разделения входных данных. trainingOptions функция предоставляет опции для вставки и усечения входных последовательностей, однако эти опции плохо подходят для последовательностей векторов слов. Кроме того, эта функция не поддерживает заполнение данных в пользовательском хранилище данных. Вместо этого необходимо вручную выполнить подстройку и усечение последовательностей. Если вы оставляете и усекаете последовательности векторов слов, то тренировка может улучшиться.

В примере «Классифицировать текстовые данные с помощью глубокого обучения» вручную выполняется усечение и вставка всех документов на одинаковую длину. Этот процесс добавляет много дополнений к очень коротким документам и отбрасывает много данных из очень длинных документов.

Кроме того, чтобы предотвратить добавление слишком большого количества дополнений или удаление слишком большого количества данных, создайте пользовательское хранилище данных мини-пакета, которое вводит мини-пакеты в сеть. Пользовательское хранилище данных мини-пакета textDatastore.m преобразует мини-пакеты документов в последовательности или индексы слов и левые панели каждого мини-пакета в длину самого длинного документа в мини-пакете. Для отсортированных данных это хранилище данных может помочь уменьшить количество дополнений, добавляемых к данным, поскольку документы не заполняются до фиксированной длины. Аналогично, хранилище данных не отбрасывает данные из документов.

В этом примере используется пользовательское хранилище данных мини-пакета textDatastore.m. Это хранилище данных можно адаптировать к данным путем настройки функций. Пример создания собственного хранилища данных мини-пакета см. в разделе Разработка хранилища данных мини-пакета (Deep Learning Toolbox).

Загрузка предварительно обученного встраивания слов

Хранилище данных textDatastore требуется вложение слова для преобразования документов в последовательности векторов. Загрузить предварительно подготовленное вложение слов с помощью fastTextWordEmbedding. Для выполнения этой функции требуется модель Text Analytics Toolbox™ для пакета поддержки внедрения Token Word на английском языке на 16 миллиардов. Если этот пакет поддержки не установлен, функция предоставляет ссылку для загрузки.

emb = fastTextWordEmbedding;

Создание хранилища данных мини-партии документов

Создайте хранилище данных, содержащее данные для обучения. Пользовательское хранилище данных мини-пакета textDatastore считывает предикторы и метки из CSV-файла. Для предикторов хранилище данных преобразует документы в последовательности индексов слов, а для ответов хранилище данных возвращает категориальную метку для каждого документа.

Чтобы создать хранилище данных, сначала сохраните пользовательское мини-пакетное хранилище данных. textDatastore.m к пути. Дополнительные сведения о создании пользовательских мини-пакетных хранилищ данных см. в разделе Разработка пользовательского мини-пакетного хранилища данных (инструментарий для глубокого обучения).

Для учебных данных укажите CSV-файл "factoryReports.csv" и что текст и метки находятся в столбцах "Description" и "Category" соответственно.

filenameTrain = "factoryReports.csv";
textName = "Description";
labelName = "Category";
dsTrain = textDatastore(filenameTrain,textName,labelName,emb)
dsTrain = 
  textDatastore with properties:

            ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
             Datastore: [1×1 matlab.io.datastore.TransformedDatastore]
    EmbeddingDimension: 300
             LabelName: "Category"
         MiniBatchSize: 128
            NumClasses: 4
       NumObservations: 480

Создание и обучение сети LSTM

Определите архитектуру сети LSTM. Чтобы ввести данные последовательности в сеть, включите входной уровень последовательности и установите размер ввода в размерность встраивания. Далее следует включить уровень LSTM со 180 скрытыми единицами измерения. Чтобы использовать уровень LSTM для проблемы классификации «последовательность-метка», установите режим вывода в значение 'last'. Наконец, добавьте полностью подключенный уровень с размером выхода, равным количеству классов, уровень softmax и уровень классификации.

numFeatures = dsTrain.EmbeddingDimension;
numHiddenUnits = 180;
numClasses = dsTrain.NumClasses;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения. Укажите решатель для 'adam' и порог градиента должен быть равен 2. Хранилище данных textDatastore.m не поддерживает тасование, поэтому установите 'Shuffle'Кому 'never'. Пример внедрения хранилища данных с поддержкой тасования см. в разделе Разработка пользовательского мини-пакетного хранилища данных (Deep Learning Toolbox). Для контроля за ходом обучения установите 'Plots' опция для 'training-progress'. Для подавления подробных выходных данных установите 'Verbose' кому false.

По умолчанию trainNetwork использует графический процессор, если он доступен. Чтобы указать среду выполнения вручную, используйте 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions. Обучение на CPU может занять значительно больше времени, чем обучение на GPU. Для обучения с использованием графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе.

miniBatchSize = 128;
numObservations = dsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучение сети LSTM с помощью trainNetwork функция.

net = trainNetwork(dsTrain,layers,options);

Прогнозирование с использованием новых данных

Классифицируйте тип события для трех новых отчетов. Создайте строковый массив, содержащий новые отчеты.

reportsNew = [ 
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Предварительная обработка текстовых данных с использованием шагов предварительной обработки в качестве хранилища данных textDatastore.m.

documents = tokenizedDocument(reportsNew);
documents = lower(documents);
documents = erasePunctuation(documents);
predictors = doc2sequence(emb,documents);

Классифицируйте новые последовательности с помощью обученной сети LSTM.

labelsNew = classify(net,predictors)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

См. также

| | | | | | | (глубоко изучение комплекта инструментов) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения)

Связанные темы