В этом примере показано, как классифицировать текстовые описания прогнозов погоды с помощью сети долгой краткосрочной памяти (LSTM) глубокого обучения.
Текстовые данные естественно последовательны. Часть текста является последовательностью слов, которые могут иметь зависимости между ними. Чтобы изучить и использовать долгосрочные зависимости, чтобы классифицировать данные о последовательности, используйте нейронную сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), которая может изучить долгосрочные зависимости между временными шагами данных о последовательности.
Чтобы ввести текст к сети LSTM, сначала преобразуйте текстовые данные в числовые последовательности. Можно достигнуть этого использования кодирования слова, которое сопоставляет документы последовательностям числовых индексов. Для лучших результатов также включайте слой встраивания слова в сеть. Вложения Word сопоставляют слова в словаре к числовым векторам, а не скалярным индексам. Эти вложения получают семантические детали слов, так, чтобы слова с подобными значениями имели подобные векторы. Они также отношения модели между словами через векторную арифметику. Например, отношение "король королеве, как человек женщине", описан королем уравнения – человек + женщина = королева.
Существует четыре шага в обучении и использовании сети LSTM в этом примере:
Импортируйте и предварительно обработайте данные.
Преобразуйте слова в числовые последовательности с помощью кодирования слова.
Создайте и обучите сеть LSTM со слоем встраивания слова.
Классифицируйте новые текстовые данные с помощью обученной сети LSTM.
Импортируйте данные о прогнозах погоды. Эти данные содержат помеченные текстовые описания погодных явлений. Чтобы импортировать текстовые данные как строки, задайте тип текста, чтобы быть 'string'
.
filename = "weatherReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×16 table
Time event_id state event_type damage_property damage_crops begin_lat begin_lon end_lat end_lon event_narrative storm_duration begin_day end_day year end_timestamp
____________________ __________ ________________ ___________________ _______________ ____________ _________ _________ _______ _______ _________________________________________________________________________________________________________________________________________________________________________________________________ ______________ _________ _______ ____ ____________________
22-Jul-2016 16:10:00 6.4433e+05 "MISSISSIPPI" "Thunderstorm Wind" "" "0.00K" 34.14 -88.63 34.122 -88.626 "Large tree down between Plantersville and Nettleton." 00:05:00 22 22 2016 22-Jul-0016 16:15:00
15-Jul-2016 17:15:00 6.5182e+05 "SOUTH CAROLINA" "Heavy Rain" "2.00K" "0.00K" 34.94 -81.03 34.94 -81.03 "One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water." 00:00:00 15 15 2016 15-Jul-0016 17:15:00
15-Jul-2016 17:25:00 6.5183e+05 "SOUTH CAROLINA" "Thunderstorm Wind" "0.00K" "0.00K" 35.01 -80.93 35.01 -80.93 "NWS Columbia relayed a report of trees blown down along Tom Hall St." 00:00:00 15 15 2016 15-Jul-0016 17:25:00
16-Jul-2016 12:46:00 6.5183e+05 "NORTH CAROLINA" "Thunderstorm Wind" "0.00K" "0.00K" 35.64 -82.14 35.64 -82.14 "Media reported two trees blown down along I-40 in the Old Fort area." 00:00:00 16 16 2016 16-Jul-0016 12:46:00
15-Jul-2016 14:28:00 6.4332e+05 "MISSOURI" "Hail" "" "" 36.45 -89.97 36.45 -89.97 "" 00:07:00 15 15 2016 15-Jul-0016 14:35:00
15-Jul-2016 16:31:00 6.4332e+05 "ARKANSAS" "Thunderstorm Wind" "" "0.00K" 35.85 -90.1 35.838 -90.087 "A few tree limbs greater than 6 inches down on HWY 18 in Roseland." 00:09:00 15 15 2016 15-Jul-0016 16:40:00
15-Jul-2016 16:03:00 6.4343e+05 "TENNESSEE" "Thunderstorm Wind" "20.00K" "0.00K" 35.056 -89.937 35.05 -89.904 "Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins." 00:07:00 15 15 2016 15-Jul-0016 16:10:00
15-Jul-2016 17:27:00 6.4344e+05 "TENNESSEE" "Hail" "" "" 35.385 -89.78 35.385 -89.78 "Quarter size hail near Rosemark." 00:05:00 15 15 2016 15-Jul-0016 17:32:00
Удалите строки таблицы с пустыми отчетами.
idxEmpty = strlength(data.event_narrative) == 0; data(idxEmpty,:) = [];
Цель этого примера состоит в том, чтобы классифицировать события меткой в event_type
столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальный.
data.event_type = categorical(data.event_type);
Просмотрите распределение классов в данных с помощью гистограммы. Чтобы сделать метки легче читать, увеличьте ширину фигуры.
f = figure; f.Position(3) = 1.5*f.Position(3); h = histogram(data.event_type); xlabel("Class") ylabel("Frequency") title("Class Distribution")
Классы данных являются неустойчивыми со многими классами, содержащими немного наблюдений. Когда классы являются неустойчивыми таким образом, сетевая сила сходятся к менее точной модели. Чтобы предотвратить эту проблему, удалите любые классы, которые появляются меньше чем десять раз.
Получите подсчет частот классов и имен классов от гистограммы.
classCounts = h.BinCounts; classNames = h.Categories;
Найдите классы, содержащие меньше чем десять наблюдений.
idxLowCounts = classCounts < 10; infrequentClasses = classNames(idxLowCounts)
infrequentClasses = 1×8 cell array
{'Freezing Fog'} {'Hurricane'} {'Lakeshore Flood'} {'Marine Dense Fog'} {'Marine Strong Wind'} {'Marine Tropical Depression'} {'Seiche'} {'Sneakerwave'}
Удалите эти нечастые классы из данных. Используйте removecats
удалить неиспользованные категории из категориальных данных.
idxInfrequent = ismember(data.event_type,infrequentClasses); data(idxInfrequent,:) = []; data.event_type = removecats(data.event_type);
Теперь данные сортируются в классы разумного размера. Следующий шаг должен разделить его в наборы для обучения, валидации и тестирования. Разделите данные в учебный раздел и протянутый раздел для валидации и тестирования. Задайте процент затяжки, чтобы быть 30%.
cvp = cvpartition(data.event_type,'Holdout',0.3);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);
Разделите протянутый набор снова, чтобы установить валидацию. Задайте процент затяжки, чтобы быть 50%. Это приводит к разделению 70%-х учебных наблюдений, 15% наблюдений валидации и 15%-х тестовых наблюдений.
cvp = cvpartition(dataHeldOut.event_type,'HoldOut',0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);
Извлеките текстовые данные и метки из разделенных таблиц.
textDataTrain = dataTrain.event_narrative; textDataValidation = dataValidation.event_narrative; textDataTest = dataTest.event_narrative; YTrain = dataTrain.event_type; YValidation = dataValidation.event_type; YTest = dataTest.event_type;
Чтобы проверять, что вы импортировали данные правильно, визуализируйте учебные текстовые данные с помощью облака слова.
figure
wordcloud(textDataTrain);
title("Training Data")
Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функциональный preprocessText
, перечисленный в конце примера, выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
Предварительно обработайте обучающие данные и данные о валидации с помощью preprocessText
функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных учебных материалов.
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 7 tokens: large tree down between plantersville and nettleton 37 tokens: one to two feet of deep standing water developed on a street on the winthrop university campus after more than an inch of rain fell in less than an hour one vehicle was stalled in the water 13 tokens: nws columbia relayed a report of trees blown down along tom hall st 13 tokens: media reported two trees blown down along i40 in the old fort area 14 tokens: a few tree limbs greater than 6 inches down on hwy 18 in roseland
Чтобы ввести документы в сеть LSTM, используйте кодирование слова, чтобы преобразовать документы в последовательности числовых индексов.
Чтобы создать кодирование слова, используйте wordEncoding
функция.
enc = wordEncoding(documentsTrain);
Следующий шаг преобразования должен заполнить и обрезать документы, таким образом, они являются всеми одинаковыми длина. trainingOptions
функция предоставляет возможности заполнять и обрезать входные последовательности автоматически. Однако эти опции не хорошо подходят для последовательностей векторов слова. Вместо этого клавиатура и усеченный последовательности вручную. Если вы лево-заполняете и обрезаете последовательности векторов слова, то учебная сила улучшается.
Чтобы заполнить и обрезать документы, сначала выберите целевую длину, и затем обрежьте документы, которые более длинны, чем она и лево-заполняют документы, которые короче, чем она. Для лучших результатов целевая длина должна быть короткой, не отбрасывая большие объемы данных. Чтобы найти подходящую целевую длину, просмотрите гистограмму длин учебного материала.
documentLengths = doclength(documentsTrain); figure histogram(documentLengths) title("Document Lengths") xlabel("Length") ylabel("Number of Documents")
Большинство учебных материалов имеет меньше чем 75 лексем. Используйте это в качестве своей целевой длины для усечения и дополнения.
Преобразуйте документы последовательностям числовых индексов с помощью doc2sequence
. Чтобы обрезать или лево-заполнить последовательности, чтобы иметь длину 75, установите 'Length'
опция к 75.
XTrain = doc2sequence(enc,documentsTrain,'Length',75);
XTrain(1:5)
ans=5×1 cell
{1×75 double}
{1×75 double}
{1×75 double}
{1×75 double}
{1×75 double}
Преобразуйте документы валидации последовательностям с помощью тех же опций.
XValidation = doc2sequence(enc,documentsValidation,'Length',75);
Задайте архитектуру сети LSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на 1. Затем включайте слой встраивания слова размерности 100 и то же количество слов как кодирование слова. Затем включайте слой LSTM и определите номер скрытых модулей к 180. Чтобы использовать слой LSTM в проблеме классификации последовательностей к метке, установите режим вывода на 'last'
. Наконец, добавьте полносвязный слой с тем же размером как количество классов, softmax слоя и слоя классификации.
inputSize = 1; embeddingDimension = 100; numWords = enc.NumWords; numHiddenUnits = 180; numClasses = numel(categories(YTrain)); layers = [ ... sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 6x1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 100 dimensions and 16954 unique words 3 '' LSTM LSTM with 180 hidden units 4 '' Fully Connected 39 fully connected layer 5 '' Softmax softmax 6 '' Classification Output crossentropyex
Задайте опции обучения. Установите решатель на 'adam'
, обучайтесь в течение 10 эпох и установите порог градиента к 1. Установите начальную букву, изучают уровень 0,01. Чтобы контролировать процесс обучения, установите 'Plots'
опция к 'training-progress'
. Задайте данные о валидации с помощью 'ValidationData'
опция. Чтобы подавить многословный выход, установите 'Verbose'
к false
.
По умолчанию, trainNetwork
использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Чтобы задать среду выполнения вручную, используйте 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
. Обучение на центральном процессоре может взять значительно дольше, чем обучение на графическом процессоре.
options = trainingOptions('adam', ... 'MaxEpochs',10, ... 'GradientThreshold',1, ... 'InitialLearnRate',0.01, ... 'ValidationData',{XValidation,YValidation}, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть LSTM с помощью trainNetwork
функция.
net = trainNetwork(XTrain,YTrain,layers,options);
Чтобы протестировать сеть LSTM, сначала подготовьте тестовые данные таким же образом как обучающие данные. Затем сделайте прогнозы на предварительно обработанных тестовых данных с помощью обученной сети LSTM net
.
Предварительно обработайте тестовые данные с помощью тех же шагов в качестве учебных материалов.
textDataTest = lower(textDataTest); documentsTest = tokenizedDocument(textDataTest); documentsTest = erasePunctuation(documentsTest);
Преобразуйте тестовые документы последовательностям с помощью doc2sequence
с теми же опциями, создавая обучающие последовательности.
XTest = doc2sequence(enc,documentsTest,'Length',75);
XTest(1:5)
ans=5×1 cell
{1×75 double}
{1×75 double}
{1×75 double}
{1×75 double}
{1×75 double}
Классифицируйте тестовые документы с помощью обученной сети LSTM.
YPred = classify(net,XTest);
Вычислите точность классификации. Точность является пропорцией меток, которые сеть предсказывает правильно.
accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8684
Классифицируйте тип события трех новых прогнозов погоды. Создайте массив строк, содержащий новые прогнозы погоды.
reportsNew = [ ... "Lots of water damage to computer equipment inside the office." "A large tree is downed and blocking traffic outside Apple Hill." "Damage to many car windshields in parking lot."];
Предварительно обработайте текстовые данные с помощью шагов предварительной обработки в качестве учебных материалов.
documentsNew = preprocessText(reportsNew);
Преобразуйте текстовые данные в последовательности с помощью doc2sequence
с теми же опциями, создавая обучающие последовательности.
XNew = doc2sequence(enc,documentsNew,'Length',75);
Классифицируйте новые последовательности с помощью обученной сети LSTM.
[labelsNew,score] = classify(net,XNew);
Покажите прогнозы погоды с их предсказанными метками.
[reportsNew string(labelsNew)]
ans = 3×2 string array
"Lots of water damage to computer equipment inside the office." "Flash Flood"
"A large tree is downed and blocking traffic outside Apple Hill." "Thunderstorm Wind"
"Damage to many car windshields in parking lot." "Hail"
Функциональный preprocessText
выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
doc2sequence
| fastTextWordEmbedding
| lstmLayer
| sequenceInputLayer
| tokenizedDocument
| trainNetwork
| trainingOptions
| wordEmbeddingLayer
| wordcloud