Классификация последовательностей Используя глубокое обучение

В этом примере показано, как классифицировать данные о последовательности с помощью сети долгой краткосрочной памяти (LSTM).

Чтобы обучить глубокую нейронную сеть классифицировать данные о последовательности, можно использовать сеть LSTM. Сеть LSTM позволяет вам ввести данные о последовательности в сеть и сделать предсказания на основе отдельных временных шагов данных о последовательности.

Этот пример использует японский набор данных Гласных как описано в [1] и [2]. Этот пример обучает сеть LSTM, чтобы распознать динамик, данный данные временных рядов, представляющие два японских гласные, на которых говорят по очереди. Обучающие данные содержат данные временных рядов для девяти динамиков. Каждая последовательность имеет 12 функций и варьируется по длине. Набор данных содержит 270 учебных наблюдений и 370 тестовых наблюдений.

Загрузите данные о последовательности

Загрузите японские обучающие данные Гласных. XTrain массив ячеек, содержащий 270 последовательностей размерности 12 из различной длины. Y категориальный вектор из меток "1", "2"..., "9", которые соответствуют этим девяти динамикам. Записи в XTrain матрицы с 12 строками (одна строка для каждого признака) и различным количеством столбцов (один столбец для каждого временного шага).

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)
ans=5×1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

Визуализируйте первые временные ряды в графике. Каждая линия соответствует функции.

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes object. The axes object with title Training Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Подготовка данных для дополнения

Во время обучения, по умолчанию, программное обеспечение разделяет обучающие данные в мини-пакеты и заполняет последовательности так, чтобы у них была та же длина. Слишком много дополнения может оказать негативное влияние на производительность сети.

Чтобы препятствовать тому, чтобы учебный процесс добавил слишком много дополнения, можно отсортировать обучающие данные по длине последовательности и выбрать мини-пакетный размер так, чтобы последовательности в мини-пакете имели подобную длину. Следующий рисунок показывает эффект дополнения последовательностей до и после сортировки данных.

Получите длины последовательности для каждого наблюдения.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Сортировка данных длиной последовательности.

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Просмотрите отсортированные длины последовательности в столбчатой диаграмме.

figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

Figure contains an axes object. The axes object with title Sorted Data contains an object of type bar.

Выберите мини-пакетный размер 27, чтобы разделить обучающие данные равномерно и уменьшать объем дополнения в мини-пакетах. Следующая фигура иллюстрирует дополнение, добавленное к последовательностям.

miniBatchSize = 27;

Задайте сетевую архитектуру LSTM

Задайте архитектуру сети LSTM. Задайте входной размер, чтобы быть последовательностями размера 12 (размерность входных данных). Задайте двунаправленный слой LSTM с 100 скрытыми модулями и выведите последний элемент последовательности. Наконец, задайте девять классов включением полносвязного слоя размера 9, сопровождаемый softmax слоем и слоем классификации.

Если у вас есть доступ к полным последовательностям во время предсказания, то можно использовать двунаправленный слой LSTM в сети. Двунаправленный слой LSTM извлекает уроки из полной последовательности на каждом временном шаге. Если у вас нет доступа к полной последовательности во время предсказания, например, если вы предсказываете значения или предсказываете один временной шаг за один раз, то используйте слой LSTM вместо этого.

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

Теперь задайте опции обучения. Задайте решатель, чтобы быть 'adam', порог градиента, чтобы быть 1, и максимальное количество эпох, чтобы быть 100. Чтобы уменьшать объем дополнения в мини-пакетах, выберите мини-пакетный размер 27. Чтобы заполнить данные, чтобы иметь ту же длину как самые длинные последовательности, задайте длину последовательности, чтобы быть 'longest'. Чтобы гарантировать, что данные остаются отсортированными по длине последовательности, задайте, чтобы никогда не переставить данные.

Поскольку мини-пакеты малы с короткими последовательностями, обучение лучше подходит для центрального процессора. Задайте 'ExecutionEnvironment' быть 'cpu'. Чтобы обучаться на графическом процессоре, при наличии, устанавливает 'ExecutionEnvironment' к 'auto' (это - значение по умолчанию).

maxEpochs = 100;
miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

Обучите сеть LSTM

Обучите сеть LSTM с заданными опциями обучения при помощи trainNetwork.

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (25-Aug-2021 07:28:01) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 14 objects of type patch, text, line. Axes object 2 contains 14 objects of type patch, text, line.

Протестируйте сеть LSTM

Загрузите набор тестов и классифицируйте последовательности в динамики.

Загрузите японские тестовые данные Гласных. XTest массив ячеек, содержащий 370 последовательностей размерности 12 из различной длины. YTest категориальный вектор из меток "1", "2"... "9", которые соответствуют этим девяти динамикам.

[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)
ans=3×1 cell array
    {12x19 double}
    {12x17 double}
    {12x19 double}

Сеть LSTM net был обучен с помощью мини-пакетов последовательностей подобной длины. Убедитесь, что тестовые данные организованы таким же образом. Сортировка тестовых данных длиной последовательности.

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

Классифицируйте тестовые данные. Чтобы уменьшать объем дополнения введенного процессом классификации, установите мини-пакетный размер на 27. Чтобы применить то же дополнение как обучающие данные, задайте длину последовательности, чтобы быть 'longest'.

miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');

Вычислите точность классификации предсказаний.

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9730

Ссылки

[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Смотрите также

| | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте