Классифицируйте текстовые данные из памяти Используя пользовательский мини-пакетный Datastore

Этот пример показывает, как классифицировать текстовые данные из памяти с нейронной сетью для глубокого обучения с помощью пользовательского мини-пакетного datastore.

Мини-пакетный datastore является реализацией datastore с поддержкой чтения данных в пакетах. Можно использовать мини-пакетный datastore в качестве источника обучения, валидации, теста и наборов данных прогноза для применения глубокого обучения. Используйте мини-пакетные хранилища данных, чтобы считать данные, которые не помещаются в память, или выполнить определенные операции предварительной обработки при чтении пакетов данных.

При обучении сети программное обеспечение создает мини-пакеты последовательностей той же длины путем дополнения, обрезая или разделяя входные данные. Функция trainingOptions предоставляет возможности заполнять и обрезать входные последовательности, однако, эти опции не хорошо подходят для последовательностей векторов слова. Кроме того, эта функция не поддерживает дополнительные данные в пользовательском datastore. Вместо этого необходимо заполнить и обрезать последовательности вручную. Если вы лево-заполняете и обрезаете последовательности векторов слова, то учебная сила улучшается.

Классифицировать текстовые Данные Используя пример Глубокого обучения вручную обрезают и заполняют все документы той же длине. Этот процесс добавляет большое дополнение к очень коротким документам и отбрасывает много данных из очень длинных документов.

Также, чтобы предотвратить добавление слишком большого дополнения или отбрасывание слишком большого количества данных, создайте пользовательский мини-пакетный datastore, который вводит мини-пакеты в сеть. Пользовательский мини-пакетный datastore textDatastore.m преобразовывает мини-пакеты документов последовательностям или словарей и лево-заполняет каждый мини-пакет к длине самого длинного документа в мини-пакете. Для отсортированных данных этот datastore может помочь уменьшать объем дополнения добавленного к данным, поскольку документы не дополнены к фиксированной длине. Точно так же datastore не отбрасывает данных из документов.

Этот пример использует пользовательский мини-пакетный datastore textDatastore.m. Можно адаптировать этот datastore к данным путем настройки функций. Для примера, показывающего, как создать ваш собственный мини-пакетный datastore, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox).

Загрузите предварительно обученный Word Embedding

Datastore textDatastore требует, чтобы встраивание слова преобразовало документы последовательностям векторов. Загрузите предварительно обученное встраивание слова с помощью fastTextWordEmbedding. Эта функция требует Модели Text Analytics Toolbox™ для fastText английских 16 миллиардов Лексем пакет поддержки Word Embedding. Если этот пакет поддержки не установлен, то функция обеспечивает ссылку на загрузку.

emb = fastTextWordEmbedding;

Создайте мини-пакетный Datastore документов

Создайте datastore, который содержит данные для обучения. Пользовательский мини-пакетный datastore textDatastore читает предикторы и метки из файла CSV. Для предикторов datastore преобразовывает документы в последовательности словарей и для ответов, datastore возвращает категориальную метку для каждого документа.

Чтобы создать datastore, сначала сохраните пользовательский мини-пакетный datastore textDatastore.m в путь. Для получения дополнительной информации о создании пользовательских мини-пакетных хранилищ данных, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox).

Для данных тренировки задайте файл CSV "weatherReportsTrain.csv" и что текст и метки находятся в столбцах "event_narrative" и "event_type" соответственно.

filenameTrain = "weatherReportsTrain.csv";
textName = "event_narrative";
labelName = "event_type";
dsTrain = textDatastore(filenameTrain,textName,labelName,emb)
dsTrain = 
  textDatastore with properties:

            ClassNames: [1×39 string]
             Datastore: [1×1 matlab.io.datastore.TransformedDatastore]
    EmbeddingDimension: 300
             LabelName: "event_type"
         MiniBatchSize: 128
            NumClasses: 39
       NumObservations: 19683

Создайте datastore, содержащий данные о валидации из файла CSV "weatherReportsValidation.csv" с помощью тех же шагов.

filenameValidation = "weatherReportsValidation.csv";
dsValidation = textDatastore(filenameValidation,textName,labelName,emb)
dsValidation = 
  textDatastore with properties:

            ClassNames: [1×39 string]
             Datastore: [1×1 matlab.io.datastore.TransformedDatastore]
    EmbeddingDimension: 300
             LabelName: "event_type"
         MiniBatchSize: 128
            NumClasses: 39
       NumObservations: 4218

Создайте и обучите сеть LSTM

Задайте архитектуру сети LSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на размерность встраивания. Затем, включайте слой LSTM с 180 скрытыми модулями. Чтобы использовать слой LSTM для проблемы классификации последовательностей к метке, установите режим вывода на 'last'. Наконец, добавьте полносвязный слой с выходным размером, равным количеству классов, softmax слоя и слоя классификации.

numFeatures = dsTrain.EmbeddingDimension;
numHiddenUnits = 180;
numClasses = dsTrain.NumClasses;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения. Задайте решатель, чтобы быть 'adam' и порогом градиента, чтобы быть 2. Datastore textDatastore.m не поддерживает перестановку, так установил 'Shuffle' к 'never'. Для примера, показывающего, как реализовать datastore с поддержкой перестановки, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox). Подтвердите сеть однажды в эпоху. Чтобы контролировать учебный прогресс, установите опцию 'Plots' на 'training-progress'. Чтобы подавить многословный вывод, установите 'Verbose' на false.

По умолчанию trainNetwork использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Чтобы задать среду выполнения вручную, используйте аргумент пары "имя-значение" 'ExecutionEnvironment' trainingOptions. Обучение на центральном процессоре может взять значительно дольше, чем обучение на графическом процессоре.

miniBatchSize = 128;
numObservations = dsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MaxEpochs',15, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never', ...
    'ValidationData',dsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть LSTM с помощью функции trainNetwork.

net = trainNetwork(dsTrain,layers,options);

Протестируйте сеть LSTM

Создайте datastore, содержащий документы и метки

filenameTest = "weatherReportsTest.csv";
dsTest = textDatastore(filenameTest,textName,labelName,emb)
dsTest = 
  textDatastore with properties:

            ClassNames: [1×39 string]
             Datastore: [1×1 matlab.io.datastore.TransformedDatastore]
    EmbeddingDimension: 300
             LabelName: "event_type"
         MiniBatchSize: 128
            NumClasses: 39
       NumObservations: 4217

Считайте метки из datastore с помощью функции readLabels пользовательского datastore.

YTest = readLabels(dsTest);

Классифицируйте тестовые документы с помощью обученной сети LSTM.

YPred = classify(net,dsTest);

Вычислите точность классификации. Точность является пропорцией меток, которые сеть предсказывает правильно.

accuracy = mean(YPred == YTest)
accuracy = 0.8084

Смотрите также

| | | | | | | | | |

Похожие темы