В этом примере показано, как классифицировать текстовые данные из памяти с нейронной сетью для глубокого обучения с помощью преобразованного datastore.
Преобразованный datastore преобразовывает или данные о процессах, считанные из базового datastore, можно использовать преобразованный datastore в качестве источника обучения, валидации, теста и наборов данных прогноза для применения глубокого обучения. Используйте преобразованные хранилища данных, чтобы считать данные, которые не помещаются в память, или выполнить определенные операции предварительной обработки при чтении пакетов данных.
При обучении сети программное обеспечение создает мини-пакеты последовательностей той же длины путем дополнения, обрезая или разделяя входные данные. trainingOptions
функция предоставляет возможности заполнять и обрезать входные последовательности, однако, эти опции не хорошо подходят для последовательностей векторов слова. Кроме того, эта функция не поддерживает дополнительные данные в пользовательском datastore. Вместо этого необходимо заполнить и обрезать последовательности вручную. Если вы лево-заполняете и обрезаете последовательности векторов слова, то учебная сила улучшается.
Классифицировать текстовые Данные Используя Глубокое обучение (Text Analytics Toolbox) пример вручную обрезает и заполняет все документы той же длине. Этот процесс добавляет большое дополнение к очень коротким документам и отбрасывает много данных из очень длинных документов.
В качестве альтернативы, чтобы предотвратить добавление слишком большого дополнения или отбрасывание слишком большого количества данных, создайте преобразованный datastore, который вводит мини-пакеты в сеть. Datastore, созданный в этом примере, преобразует мини-пакеты документов последовательностям или словарей и лево-заполняет каждый мини-пакет к длине самого длинного документа в мини-пакете.
Datastore требует, чтобы встраивание слова преобразовало документы последовательностям векторов. Загрузите предварительно обученное встраивание слова с помощью fastTextWordEmbedding
. Эта функция требует Модели Text Analytics Toolbox™ для fastText английских 16 миллиардов Лексем пакет поддержки Word Embedding. Если этот пакет поддержки не установлен, то функция обеспечивает ссылку на загрузку.
emb = fastTextWordEmbedding;
Создайте табличный текстовый datastore из данных в weatherReportsTrain.csv
. Задайте, чтобы считать данные из "event_narrative"
и "event_type"
столбцы только.
filenameTrain = "weatherReportsTrain.csv"; textName = "event_narrative"; labelName = "event_type"; ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);
Просмотрите предварительный просмотр datastore.
preview(ttdsTrain)
ans=8×2 table
event_narrative event_type
_________________________________________________________________________________________________________________________________________________________________________________________________ ___________________
'Large tree down between Plantersville and Nettleton.' 'Thunderstorm Wind'
'One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water.' 'Heavy Rain'
'NWS Columbia relayed a report of trees blown down along Tom Hall St.' 'Thunderstorm Wind'
'Media reported two trees blown down along I-40 in the Old Fort area.' 'Thunderstorm Wind'
'A few tree limbs greater than 6 inches down on HWY 18 in Roseland.' 'Thunderstorm Wind'
'Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins.' 'Thunderstorm Wind'
'Tin roof ripped off house on Old Memphis Road near Billings Drive. Several large trees down in the area.' 'Thunderstorm Wind'
'Powerlines down at Walnut Grove and Cherry Lane roads.' 'Thunderstorm Wind'
Создайте пользовательское, преобразовывают функцию, которая преобразует данные, считанные от datastore до таблицы, содержащей предикторы и ответы. transformTextData
функционируйте берет данные, считанные из tabularTextDatastore
возразите и возвращает таблицу предикторов и ответов. Предикторы являются C-by-S массивами векторов слова, данных словом, встраивающим emb
, где C является размерностью встраивания, и S является длиной последовательности. Ответы являются категориальными метками по классам.
Чтобы получить имена классов, считайте метки из обучающих данных с помощью readLabels
функция, перечисленная и конец примера, и, находит уникальные имена классов.
labels = readLabels(ttdsTrain,labelName); classNames = unique(labels); numObservations = numel(labels);
Поскольку tablular текстовые хранилища данных могут считать несколько строк данных в одном чтении, можно обработать полный мини-пакет данных в функции преобразования. Чтобы гарантировать, что функция преобразования обрабатывает полный мини-пакет данных, устанавливает размер чтения табличного текстового datastore к мини-пакетному размеру, который будет использоваться в обучении.
miniBatchSize = 128; ttdsTrain.ReadSize = miniBatchSize;
Чтобы преобразовать выход табличных текстовых данных к последовательностям для обучения, преобразуйте datastore с помощью transform
функция.
tdsTrain = transform(ttdsTrain, @(data) transformTextData(data,emb,classNames))
tdsTrain = TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] Transforms: {@(data)transformTextData(data,emb,classNames)} IncludeInfo: 0
Предварительный просмотр преобразованного datastore. Предикторы являются C-by-S массивами, где S является длиной последовательности, и C является количеством функций (размерность встраивания). Ответы являются категориальными метками.
preview(tdsTrain)
ans=8×2 table
predictors responses
________________ _________________
[300×164 single] Thunderstorm Wind
[300×164 single] Heavy Rain
[300×164 single] Thunderstorm Wind
[300×164 single] Thunderstorm Wind
[300×164 single] Thunderstorm Wind
[300×164 single] Thunderstorm Wind
[300×164 single] Thunderstorm Wind
[300×164 single] Thunderstorm Wind
Создайте преобразованный datastore, содержащий данные о валидации в weatherReportsValidation.csv
использование тех же шагов.
filenameValidation = "weatherReportsValidation.csv"; ttdsValidation = tabularTextDatastore(filenameValidation,'SelectedVariableNames',[textName labelName]); ttdsValidation.ReadSize = miniBatchSize; tdsValidation = transform(ttdsValidation, @(data) transformTextData(data,emb,classNames))
tdsValidation = TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] Transforms: {@(data)transformTextData(data,emb,classNames)} IncludeInfo: 0
Задайте архитектуру сети LSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на размерность встраивания. Затем включайте слой LSTM с 180 скрытыми модулями. Чтобы использовать слой LSTM в проблеме классификации последовательностей к метке, установите режим вывода на 'last'
. Наконец, добавьте полносвязный слой с выходным размером, равным количеству классов, softmax слоя и слоя классификации.
numFeatures = emb.Dimension; numHiddenUnits = 180; numClasses = numel(classNames); layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции обучения. Задайте решатель, чтобы быть 'adam'
и порог градиента, чтобы быть 2. Datastore не поддерживает перестановку, таким образом, устанавливает 'Shuffle'
, к 'never'
. Подтвердите сеть однажды в эпоху. Чтобы контролировать процесс обучения, установите 'Plots'
опция к 'training-progress'
. Чтобы подавить многословный выход, установите 'Verbose'
к false
.
По умолчанию, trainNetwork
использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Чтобы задать среду выполнения вручную, используйте 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
. Обучение на центральном процессоре может взять значительно дольше, чем обучение на графическом процессоре.
numIterationsPerEpoch = floor(numObservations / miniBatchSize); options = trainingOptions('adam', ... 'MaxEpochs',15, ... 'MiniBatchSize',miniBatchSize, ... 'GradientThreshold',2, ... 'Shuffle','never', ... 'ValidationData',tdsValidation, ... 'ValidationFrequency',numIterationsPerEpoch, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть LSTM с помощью trainNetwork
функция.
net = trainNetwork(tdsTrain,layers,options);
Создайте преобразованный datastore, содержащий протянутые тестовые данные в weatherReportsTest.csv
.
filenameTest = "weatherReportsTest.csv"; ttdsTest = tabularTextDatastore(filenameTest,'SelectedVariableNames',[textName labelName]); ttdsTest.ReadSize = miniBatchSize; tdsTest = transform(ttdsTest, @(data) transformTextData(data,emb,classNames))
tdsTest = TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] Transforms: {@(data)transformTextData(data,emb,classNames)} IncludeInfo: 0
Считайте метки из tabularTextDatastore
.
labelsTest = readLabels(ttdsTest,labelName); YTest = categorical(labelsTest,classNames);
Сделайте прогнозы на тестовых данных с помощью обучившего сеть.
YPred = classify(net,tdsTest,'MiniBatchSize',miniBatchSize);
Вычислите точность классификации на тестовые данные.
accuracy = mean(YPred == YTest)
accuracy = 0.8293
readLabels
функция создает копию tabularTextDatastore
объект ttds
и читает метки из labelName
столбец.
function labels = readLabels(ttds,labelName) ttdsNew = copy(ttds); ttdsNew.SelectedVariableNames = labelName; tbl = readall(ttdsNew); labels = tbl.(labelName); end
transformTextData
функционируйте берет данные, считанные из tabularTextDatastore
возразите и возвращает таблицу предикторов и ответов. Предикторы являются C-by-S массивами векторов слова, данных словом, встраивающим emb
, где C является размерностью встраивания, и S является длиной последовательности. Ответы являются категориальными метками по классам в classNames
.
function dataTransformed = transformTextData(data,emb,classNames) % Preprocess documents. textData = data{:,1}; textData = lower(textData); documents = tokenizedDocument(textData); % Convert to sequences. predictors = doc2sequence(emb,documents); % Read labels. labels = data{:,2}; responses = categorical(labels,classNames); % Convert data to table. dataTransformed = table(predictors,responses); end
doc2sequence
| fastTextWordEmbedding
| lstmLayer
| sequenceInputLayer
| tokenizedDocument
| trainNetwork
| trainingOptions
| transform
| wordEmbeddingLayer