exponenta event banner

Обучение сети с использованием пользовательского мини-хранилища пакетных данных для данных последовательности

В этом примере показано, как обучить сеть глубокого обучения данным последовательности без памяти с помощью пользовательского мини-хранилища пакетных данных.

Мини-хранилище пакетных данных представляет собой реализацию хранилища данных с поддержкой считывания данных в пакетах. Использование хранилищ данных мини-пакетов для считывания данных из памяти или выполнения определенных операций предварительной обработки при считывании пакетов данных. Хранилище данных мини-пакета можно использовать в качестве источника наборов данных обучения, проверки, тестирования и прогнозирования для приложений глубокого обучения.

В этом примере используется пользовательское хранилище данных мини-пакета sequenceDatastore.m. Это хранилище данных можно адаптировать к данным путем настройки функций хранилища данных. Пример создания собственного хранилища данных мини-пакета см. в разделе Разработка хранилища данных мини-пакета.

Загрузка данных обучения

Загрузите набор данных японских гласных, как описано в [1] и [2]. Zip-файл japaneseVowels.zip содержит последовательности различной длины. Последовательности разделены на две папки, Train и Test, которые содержат обучающие последовательности и тестовые последовательности соответственно. В каждой из этих папок последовательности делятся на подпапки, которые нумеруются из 1 кому 9. Имена этих подпапок являются именами меток. MAT-файл представляет каждую последовательность. Каждая последовательность является матрицей с 12 строками, с одной строкой для каждого элемента и различным количеством столбцов, с одним столбцом для каждого временного шага. Число строк - это измерение последовательности, а число столбцов - длина последовательности.

Распакуйте данные последовательности.

filename = "japaneseVowels.zip";
outputFolder = fullfile(tempdir,"japaneseVowels");
unzip(filename,outputFolder);

Создание пользовательского мини-хранилища пакетных данных

Создайте пользовательское хранилище данных мини-пакета. Хранилище данных мини-пакета sequenceDatastore считывает данные из папки и получает метки из имен подпапок. Чтобы использовать это хранилище данных, сначала сохраните файл sequenceDatastore.m к пути.

Создание хранилища данных, содержащего данные последовательности, с помощью sequenceDatastore.

folderTrain = fullfile(outputFolder,"Train");
dsTrain = sequenceDatastore(folderTrain)
dsTrain = 
  sequenceDatastore with properties:

            Datastore: [1×1 matlab.io.datastore.FileDatastore]
               Labels: [270×1 categorical]
           NumClasses: 9
    SequenceDimension: 12
        MiniBatchSize: 128
      NumObservations: 270

Определение сетевой архитектуры LSTM

Определите архитектуру сети LSTM. Укажите размер последовательности входных данных в качестве размера входных данных. Укажите уровень LSTM со 100 скрытыми единицами измерения и для вывода последнего элемента последовательности. Наконец, укажите полностью подключенный уровень с размером выхода, равным числу классов, за которым следуют уровень softmax и уровень классификации.

inputSize = dsTrain.SequenceDimension;
numClasses = dsTrain.NumClasses;
numHiddenUnits = 100;
layers = [
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения. Определить 'adam' в качестве решателя и 'GradientThreshold' как 1. Установите размер мини-партии равным 27 и максимальное количество периодов равным 75. Чтобы убедиться, что хранилище данных создает мини-пакеты такого размера, как trainNetwork ожидается функция, также установите размер мини-пакета хранилища данных на то же значение.

Поскольку мини-пакеты малы с короткими последовательностями, ЦП лучше подходит для обучения. Набор 'ExecutionEnvironment' кому 'cpu'. Обучение на GPU, если доступно, установить 'ExecutionEnvironment' кому 'auto' (значение по умолчанию).

miniBatchSize = 27;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',75, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',0, ...
    'Plots','training-progress');
dsTrain.MiniBatchSize = miniBatchSize;

Обучение сети LSTM с указанными вариантами обучения.

net = trainNetwork(dsTrain,layers,options);

Тестирование сети

Создайте хранилище данных последовательности из тестовых данных.

folderTest = fullfile(outputFolder,"Test");
dsTest = sequenceDatastore(folderTest);

Классифицируйте данные теста. Укажите тот же размер мини-партии, что и для данных обучения. Чтобы убедиться, что хранилище данных создает мини-пакеты такого размера, как classify ожидается функция, также установите размер мини-пакета хранилища данных на то же значение.

dsTest.MiniBatchSize = miniBatchSize;
YPred = classify(net,dsTest,'MiniBatchSize',miniBatchSize);

Вычислите точность классификации прогнозов.

YTest = dsTest.Labels;
acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9432

Ссылки

[1] Кудо, М., Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.

[2] Кудо, М., Дж. Тояма и М. Симбо. Набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

| | |

Связанные темы