Этот пример показывает, как классифицировать текстовые данные из памяти с нейронной сетью для глубокого обучения с помощью пользовательского мини-пакетного datastore.
Мини-пакетный datastore является реализацией datastore с поддержкой чтения данных в пакетах. Можно использовать мини-пакетный datastore в качестве источника обучения, валидации, теста и наборов данных прогноза для применения глубокого обучения. Используйте мини-пакетные хранилища данных, чтобы считать данные, которые не помещаются в память, или выполнить определенные операции предварительной обработки при чтении пакетов данных.
При обучении сети программное обеспечение создает мини-пакеты последовательностей той же длины путем дополнения, обрезая или разделяя входные данные. Функция trainingOptions
предоставляет возможности заполнять и обрезать входные последовательности, однако, эти опции не хорошо подходят для последовательностей векторов слова. Кроме того, эта функция не поддерживает дополнительные данные в пользовательском datastore. Вместо этого необходимо заполнить и обрезать последовательности вручную. Если вы лево-заполняете и обрезаете последовательности векторов слова, то учебная сила улучшается.
Классифицировать текстовые Данные Используя пример Глубокого обучения вручную обрезают и заполняют все документы той же длине. Этот процесс добавляет большое дополнение к очень коротким документам и отбрасывает много данных из очень длинных документов.
Также, чтобы предотвратить добавление слишком большого дополнения или отбрасывание слишком большого количества данных, создайте пользовательский мини-пакетный datastore, который вводит мини-пакеты в сеть. Пользовательский мини-пакетный datastore textDatastore.m
преобразовывает мини-пакеты документов последовательностям или словарей и лево-заполняет каждый мини-пакет к длине самого длинного документа в мини-пакете. Для отсортированных данных этот datastore может помочь уменьшать объем дополнения добавленного к данным, поскольку документы не дополнены к фиксированной длине. Точно так же datastore не отбрасывает данных из документов.
Этот пример использует пользовательский мини-пакетный datastore textDatastore.m
. Можно адаптировать этот datastore к данным путем настройки функций. Для примера, показывающего, как создать ваш собственный мини-пакетный datastore, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox).
Datastore textDatastore
требует, чтобы встраивание слова преобразовало документы последовательностям векторов. Загрузите предварительно обученное встраивание слова с помощью fastTextWordEmbedding
. Эта функция требует Модели Text Analytics Toolbox™ для fastText английских 16 миллиардов Лексем пакет поддержки Word Embedding. Если этот пакет поддержки не установлен, то функция обеспечивает ссылку на загрузку.
emb = fastTextWordEmbedding;
Создайте datastore, который содержит данные для обучения. Пользовательский мини-пакетный datastore textDatastore
читает предикторы и метки из файла CSV. Для предикторов datastore преобразовывает документы в последовательности словарей и для ответов, datastore возвращает категориальную метку для каждого документа.
Чтобы создать datastore, сначала сохраните пользовательский мини-пакетный datastore textDatastore.m
в путь. Для получения дополнительной информации о создании пользовательских мини-пакетных хранилищ данных, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox).
Для данных тренировки задайте файл CSV "weatherReportsTrain.csv"
и что текст и метки находятся в столбцах "event_narrative"
и "event_type"
соответственно.
filenameTrain = "weatherReportsTrain.csv"; textName = "event_narrative"; labelName = "event_type"; dsTrain = textDatastore(filenameTrain,textName,labelName,emb)
dsTrain = textDatastore with properties: ClassNames: [1×39 string] Datastore: [1×1 matlab.io.datastore.TransformedDatastore] EmbeddingDimension: 300 LabelName: "event_type" MiniBatchSize: 128 NumClasses: 39 NumObservations: 19683
Создайте datastore, содержащий данные о валидации из файла CSV "weatherReportsValidation.csv"
с помощью тех же шагов.
filenameValidation = "weatherReportsValidation.csv";
dsValidation = textDatastore(filenameValidation,textName,labelName,emb)
dsValidation = textDatastore with properties: ClassNames: [1×39 string] Datastore: [1×1 matlab.io.datastore.TransformedDatastore] EmbeddingDimension: 300 LabelName: "event_type" MiniBatchSize: 128 NumClasses: 39 NumObservations: 4218
Задайте архитектуру сети LSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на размерность встраивания. Затем, включайте слой LSTM с 180 скрытыми модулями. Чтобы использовать слой LSTM для проблемы классификации последовательностей к метке, установите режим вывода на 'last'
. Наконец, добавьте полносвязный слой с выходным размером, равным количеству классов, softmax слоя и слоя классификации.
numFeatures = dsTrain.EmbeddingDimension; numHiddenUnits = 180; numClasses = dsTrain.NumClasses; layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции обучения. Задайте решатель, чтобы быть 'adam'
и порогом градиента, чтобы быть 2. Datastore textDatastore.m
не поддерживает перестановку, так установил 'Shuffle'
к 'never'
. Для примера, показывающего, как реализовать datastore с поддержкой перестановки, смотрите, Разрабатывают Пользовательский Мини-пакетный Datastore (Deep Learning Toolbox). Подтвердите сеть однажды в эпоху. Чтобы контролировать учебный прогресс, установите опцию 'Plots'
на 'training-progress'
. Чтобы подавить многословный вывод, установите 'Verbose'
на false
.
По умолчанию trainNetwork
использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Чтобы задать среду выполнения вручную, используйте аргумент пары "имя-значение" 'ExecutionEnvironment'
trainingOptions
. Обучение на центральном процессоре может взять значительно дольше, чем обучение на графическом процессоре.
miniBatchSize = 128; numObservations = dsTrain.NumObservations; numIterationsPerEpoch = floor(numObservations / miniBatchSize); options = trainingOptions('adam', ... 'MaxEpochs',15, ... 'MiniBatchSize',miniBatchSize, ... 'GradientThreshold',2, ... 'Shuffle','never', ... 'ValidationData',dsValidation, ... 'ValidationFrequency',numIterationsPerEpoch, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть LSTM с помощью функции trainNetwork
.
net = trainNetwork(dsTrain,layers,options);
Создайте datastore, содержащий документы и метки
filenameTest = "weatherReportsTest.csv";
dsTest = textDatastore(filenameTest,textName,labelName,emb)
dsTest = textDatastore with properties: ClassNames: [1×39 string] Datastore: [1×1 matlab.io.datastore.TransformedDatastore] EmbeddingDimension: 300 LabelName: "event_type" MiniBatchSize: 128 NumClasses: 39 NumObservations: 4217
Считайте метки из datastore с помощью функции readLabels
пользовательского datastore.
YTest = readLabels(dsTest);
Классифицируйте тестовые документы с помощью обученной сети LSTM.
YPred = classify(net,dsTest);
Вычислите точность классификации. Точность является пропорцией меток, которые сеть предсказывает правильно.
accuracy = mean(YPred == YTest)
accuracy = 0.8084
doc2sequence
| extractHTMLText
| findElement
| htmlTree
| lstmLayer
| sequenceInputLayer
| tokenizedDocument
| trainNetwork
| trainingOptions
| wordEmbeddingLayer
| wordcloud