Создайте простую сеть классификации последовательностей

Этот пример показывает, как создать простую сеть классификации долгой краткосрочной памяти (LSTM).

Чтобы обучить глубокую нейронную сеть классифицировать данные о последовательности, можно использовать сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), которая изучает долгосрочные зависимости между временными шагами данных о последовательности.

Пример демонстрирует как:

  • Загрузите данные о последовательности.

  • Определить сетевую архитектуру.

  • Задайте опции обучения.

  • Обучите сеть.

  • Предскажите метки новых данных и вычислите точность классификации.

Загрузка данных

Загрузите японский набор данных Гласных, как описано в [1] и [2]. Предикторы являются массивами ячеек, содержащими последовательности переменной длины с размерностью признаков 12. Метки являются категориальными векторами меток 1,2..., 9.

[XTrain,YTrain] = japaneseVowelsTrainData;
[XValidation,YValidation] = japaneseVowelsTestData;

Просмотрите размеры первых нескольких обучающих последовательностей. Последовательности являются матрицами с 12 строками (одна строка для каждого признака) и переменным количеством столбцов (один столбец для каждого временного шага).

XTrain(1:5)
ans = 5x1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

Архитектура сети Define

Задайте архитектуру сети LSTM. Задайте количество признаков во входном уровне и количество классов в полносвязном слое.

numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Обучение сети

Задайте опции обучения и обучите сеть.

Поскольку мини-пакеты являются маленькими с короткими последовательностями, центральный процессор лучше подходит для обучения. Установите 'ExecutionEnvironment' на 'cpu'. Чтобы обучаться на графическом процессоре, при наличии, устанавливает 'ExecutionEnvironment' на 'auto' (значение по умолчанию).

miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XValidation,YValidation}, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');

net = trainNetwork(XTrain,YTrain,layers,options);

Для получения дополнительной информации об определении опций обучения, смотрите Настроенные Параметры и Обучите Сверточную Нейронную сеть.

Тестирование сети

Классифицируйте тестовые данные и вычислите точность классификации. Задайте тот же мини-пакетный размер, используемый для обучения.

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9541

Для следующих шагов можно попытаться улучшить точность при помощи двунаправленных слоев LSTM (BiLSTM) или путем создания более глубокой сети. Для получения дополнительной информации смотрите Длинные Краткосрочные Сети Памяти.

Для примера, показывающего, как использовать сверточные сети, чтобы классифицировать данные о последовательности, смотрите Распознание речевых команд с использованием глубокого обучения.

Ссылки

  1. М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

  2. Репозиторий машинного обучения UCI: японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Смотрите также

| |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте