Создайте сеть классификации простых последовательностей с помощью Deep Network Designer

В этом примере показано, как создать простую сеть классификации долгой краткосрочной памяти (LSTM) с помощью Deep Network Designer.

Чтобы обучить глубокую нейронную сеть классифицировать данные последовательности, можно использовать сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), который изучает долгосрочные зависимости между временными шагами данных последовательности.

Пример демонстрирует, как:

  • Загрузите данные последовательности.

  • Создайте сетевую архитектуру.

  • Задайте опции обучения.

  • Обучите сеть.

  • Спрогнозируйте метки новых данных и вычислите точность классификации.

Загрузка данных

Загрузите набор данных японских гласных, как описано в [1] и [2]. Предикторы являются массивами ячеек, содержащими последовательности различной длины с размерностью признаков 12. Метки являются категориальными векторами меток 1,2,..., 9.

[XTrain,YTrain] = japaneseVowelsTrainData;
[XValidation,YValidation] = japaneseVowelsTestData;

Просмотрите размеры первых нескольких обучающих последовательностей. Последовательности являются матрицами с 12 строками (по одной строке для каждой функции) и меняющимся количеством столбцов (по одному столбцу для каждого временного шага).

XTrain(1:5)
ans=5×1 cell array
    {12×20 double}
    {12×26 double}
    {12×22 double}
    {12×20 double}
    {12×21 double}

Определение сетевой архитектуры

Откройте Deep Network Designer.

deepNetworkDesigner

Пауза на Sequence-to-Label и нажмите Open. Это открывает предварительно построенную сеть, подходящую для задач классификации последовательностей.

Deep Network Designer отображает предварительно построенную сеть.

Вы можете легко адаптировать эту сеть последовательности для набора данных японских гласных.

Выберите sequenceInputLayer и проверьте, что значение InputSize установлено на 12, чтобы соответствовать размерности признаков.

Выберите lstmLayer и установите значение NumHiddenUnits равным 100.

Выберите fullyConnectedLayer и проверьте, что значение OutputSize установлено равным 9, количеству классов.

Проверяйте сетевую архитектуру

Чтобы проверить сеть и изучить более подробную информацию о слоях, нажмите Анализировать.

Экспорт сетевой архитектуры

Чтобы экспортировать сетевую архитектуру в рабочую область, на вкладке Designer, нажмите Экспорт. Deep Network Designer сохраняет сеть как переменную layers_1.

Можно также сгенерировать код для создания сетевой архитектуры, выбрав Экспорт > Сгенерировать код.

Обучите сеть

Укажите опции обучения и обучите сеть.

Поскольку мини-пакеты являются маленькими с короткими последовательностями, центральный процессор лучше подходит для обучения. Задайте 'ExecutionEnvironment' на 'cpu'. Для обучения на графическом процессоре, при наличии, установите 'ExecutionEnvironment' на 'auto' (значение по умолчанию).

miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XValidation,YValidation}, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть.

net = trainNetwork(XTrain,YTrain,layers_1,options);

Тестирование сети

Классификация тестовых данных и вычисление точности классификации. Укажите тот же размер мини-пакета, что и для обучения.

YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9432

Для следующих шагов можно попробовать улучшить точность при помощи двунаправленных слоев LSTM (BiLSTM) или путем создания более глубокой сети. Для получения дополнительной информации см. Раздел «Длинные краткосрочные сети памяти»

Для примера, показывающего, как использовать сверточные сети для классификации данных последовательности, смотрите Распознание речевых команд с использованием глубокого обучения.

Ссылки

[1] Кудо, Минейчи, Цзюнь Тояма и Масару Симбо. «Многомерная классификация кривых с использованием областей». Pattern Recognition Letters 20, No. 11-13 (November 1999): 1103-11. https://doi.org/10.1016/S0167-8655 (99) 00077-X.

[2] Кудо, Минейчи, Цзюнь Тояма и Масару Симбо. Японский набор данных гласных. Распространяется UCI Machine Learning Repository. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

Похожие темы