В этом примере показано, как создать простую сеть классификации долгой краткосрочной памяти (LSTM) использование Deep Network Designer.
Чтобы обучить глубокую нейронную сеть классифицировать данные о последовательности, можно использовать сеть LSTM. Сеть LSTM является типом рекуррентной нейронной сети (RNN), которая изучает долгосрочные зависимости между временными шагами данных о последовательности.
Пример демонстрирует как:
Загрузите данные о последовательности.
Создайте сетевую архитектуру.
Задайте опции обучения.
Обучите сеть.
Предскажите метки новых данных и вычислите точность классификации.
Загрузите японский набор данных Гласных, как описано в [1] и [2]. Предикторы являются массивами ячеек, содержащими последовательности различной длины с размерностью признаков 12. Метки являются категориальными векторами из меток 1,2..., 9.
[XTrain,YTrain] = japaneseVowelsTrainData; [XValidation,YValidation] = japaneseVowelsTestData;
Просмотрите размеры первых нескольких обучающих последовательностей. Последовательности являются матрицами с 12 строками (одна строка для каждого признака) и различным количеством столбцов (один столбец для каждого временного шага).
XTrain(1:5)
ans=5×1 cell array
{12×20 double}
{12×26 double}
{12×22 double}
{12×20 double}
{12×21 double}
Открытый Deep Network Designer.
deepNetworkDesigner
Сделайте паузу на Последовательности к метке и нажмите Open. Это открывает предварительно созданную сеть, подходящую для проблем классификации последовательностей.
Deep Network Designer отображает предварительно созданную сеть.
Можно легко адаптировать эту сеть последовательности к японскому набору данных Гласных.
Выберите sequenceInputLayer и проверяйте, что InputSize собирается в 12 совпадать с размерностью признаков.
Выберите lstmLayer и установите NumHiddenUnits на 100.
Выберите fullyConnectedLayer и проверяйте, что OutputSize установлен в 9, количество классов.
Чтобы проверять сеть и исследовать больше деталей слоев, нажмите Analyze.
Чтобы экспортировать сетевую архитектуру в рабочую область, на вкладке Designer, нажимают Export. Deep Network Designer сохраняет сеть как переменную layers_1
.
Можно также сгенерировать код, чтобы создать сетевую архитектуру путем выбора Export> Generate Code.
Задайте опции обучения и обучите сеть.
Поскольку мини-пакеты малы с короткими последовательностями, центральный процессор лучше подходит для обучения. Установите 'ExecutionEnvironment'
к 'cpu'
. Чтобы обучаться на графическом процессоре, при наличии, устанавливает 'ExecutionEnvironment'
к 'auto'
(значение по умолчанию).
miniBatchSize = 27; options = trainingOptions('adam', ... 'ExecutionEnvironment','cpu', ... 'MaxEpochs',100, ... 'MiniBatchSize',miniBatchSize, ... 'ValidationData',{XValidation,YValidation}, ... 'GradientThreshold',2, ... 'Shuffle','every-epoch', ... 'Verbose',false, ... 'Plots','training-progress');
Обучите сеть.
net = trainNetwork(XTrain,YTrain,layers_1,options);
Можно также обучить эту сеть с помощью объектов datastore и Deep Network Designer. Для примера, показывающего, как обучить сеть регрессии от последовательности к последовательности в Deep Network Designer, смотрите, Обучат сеть для Прогнозирования Временных рядов Используя Deep Network Designer.
Классифицируйте тестовые данные и вычислите точность классификации. Задайте тот же мини-пакетный размер что касается обучения.
YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc = 0.9405
Для следующих шагов можно попытаться улучшить точность при помощи двунаправленных слоев LSTM (BiLSTM) или путем создания более глубокой сети. Для получения дополнительной информации смотрите Длинные Краткосрочные Сети Памяти.
Для примера, показывающего, как использовать сверточные сети, чтобы классифицировать данные о последовательности, смотрите Распознание речевых команд с использованием глубокого обучения.
[1] Kudo, Mineichi, Юн Тояма и Масару Шимбо. “Многомерная Классификация Кривых Используя Прохождение через области”. Буквы Распознавания образов 20, № 11-13 (ноябрь 1999): 1103–11. https://doi.org/10.1016/S0167-8655 (99) 00077-X.
[2] Kudo, Mineichi, Юн Тояма и Масару Шимбо. Японский Набор данных Гласных. Распределенный Репозиторием Машинного обучения UCI. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
trainingOptions
| trainNetwork
| lstmLayer