exponenta event banner

matlab.io.datastore.MiniBatchable класс

Пакет: matlab.io.datastore

Добавление поддержки мини-пакетов в хранилище данных

Описание

matlab.io.datastore.MiniBatchable является абстрактным смешанным классом, который добавляет поддержку мини-пакетов в пользовательское хранилище данных для использования с Deep Learning Toolbox™. Мини-хранилище данных содержит обучающие и тестовые наборы данных для использования в обучении, прогнозировании и классификации Deep Learning Toolbox.

Чтобы использовать этот класс mixin, необходимо наследовать от matlab.io.datastore.MiniBatchable в дополнение к наследованию от matlab.io.Datastore базовый класс. Введите следующий синтаксис в качестве первой строки файла определения класса:

classdef MyDatastore < matlab.io.Datastore & ...
                       matlab.io.datastore.MiniBatchable
    ...
end

Чтобы добавить поддержку мини-пакетов в хранилище данных:

  • Наследовать от дополнительного класса matlab.io.datastore.MiniBatchable

  • Определите два дополнительных свойства: MiniBatchSize и NumObservations.

Дополнительные сведения и шаги по созданию собственного мини-хранилища данных для оптимизации производительности во время обучения, прогнозирования и классификации см. в разделе Разработка пользовательского мини-хранилища данных.

Свойства

развернуть все

Количество наблюдений, возвращаемых в каждом пакете, или вызов read функция. Для обучения, прогнозирования и классификации, MiniBatchSize свойство имеет размер мини-пакета, определенный в trainingOptions.

Атрибуты:

Abstracttrue
AccessPublic

Общее число наблюдений, содержащихся в хранилище данных. Такое количество наблюдений составляет продолжительность одной тренировочной эпохи.

Атрибуты:

Abstracttrue
SetAccessProtected
ReadAccessPublic

Признаки

Abstracttrue
Sealedfalse

Сведения об атрибутах класса см. в разделе Атрибуты класса.

Копирование семантики

Ручка. Сведения о том, как классы обработки влияют на операции копирования, см. в разделе Копирование объектов.

Примеры

свернуть все

В этом примере показано, как обучить сеть глубокого обучения данным последовательности без памяти путем преобразования и объединения хранилищ данных.

Преобразованное хранилище данных преобразует или обрабатывает данные, считанные из базового хранилища данных. Преобразованное хранилище данных можно использовать в качестве источника наборов данных обучения, проверки, тестирования и прогнозирования для приложений глубокого обучения. Использование преобразованных хранилищ данных для считывания данных из памяти или выполнения определенных операций предварительной обработки при считывании пакетов данных. При наличии отдельных хранилищ данных, содержащих предикторы и метки, их можно объединить для ввода данных в сеть глубокого обучения.

При обучении сети программное обеспечение создает мини-пакеты последовательностей одинаковой длины путем заполнения, усечения или разделения входных данных. Для данных в памяти, trainingOptions функция предоставляет опции для вставки и усечения входных последовательностей, однако для данных, не находящихся в памяти, необходимо вставлять и усекать последовательности вручную.

Загрузка данных обучения

Загрузите набор данных японских гласных, как описано в [1] и [2]. Zip-файл japaneseVowels.zip содержит последовательности различной длины. Последовательности разделены на две папки, Train и Test, которые содержат обучающие последовательности и тестовые последовательности соответственно. В каждой из этих папок последовательности делятся на подпапки, которые нумеруются из 1 кому 9. Имена этих подпапок являются именами меток. MAT-файл представляет каждую последовательность. Каждая последовательность является матрицей с 12 строками, с одной строкой для каждого элемента и различным количеством столбцов, с одним столбцом для каждого временного шага. Число строк - это измерение последовательности, а число столбцов - длина последовательности.

Распакуйте данные последовательности.

filename = "japaneseVowels.zip";
outputFolder = fullfile(tempdir,"japaneseVowels");
unzip(filename,outputFolder);

Для обучающих предикторов создайте хранилище данных файла и укажите функцию чтения, которая будет load функция. load загружает данные из MAT-файла в структурный массив. Для чтения файлов из подпапок в учебной папке установите 'IncludeSubfolders' опция для true.

folderTrain = fullfile(outputFolder,"Train");
fdsPredictorTrain = fileDatastore(folderTrain, ...
    'ReadFcn',@load, ...
    'IncludeSubfolders',true);

Предварительный просмотр хранилища данных. Возвращенная структура содержит одну последовательность из первого файла.

preview(fdsPredictorTrain)
ans = struct with fields:
    X: [12×20 double]

Для меток создайте хранилище данных файла и укажите функцию чтения, которая будет readLabel , определенной в конце примера. readLabel извлекает метку из имени подпапки.

classNames = string(1:9);
fdsLabelTrain = fileDatastore(folderTrain, ...
    'ReadFcn',@(filename) readLabel(filename,classNames), ...
    'IncludeSubfolders',true);

Предварительный просмотр хранилища данных. Выходные данные соответствуют метке первого файла.

preview(fdsLabelTrain)
ans = categorical
     1 

Преобразование и объединение хранилищ данных

Чтобы ввести данные последовательности из хранилища данных предикторов в сеть глубокого обучения, мини-пакеты последовательностей должны иметь одинаковую длину. Преобразование хранилища данных с помощью padSequence функция, определенная в конце хранилища данных, которая помещает или усекает последовательности, имеющие длину 20.

sequenceLength = 20;
tdsTrain = transform(fdsPredictorTrain,@(data) padSequence(data,sequenceLength));

Предварительный просмотр преобразованного хранилища данных. Выходные данные соответствуют дополненной последовательности из первого файла.

X = preview(tdsTrain)
X = 1×1 cell array
    {12×20 double}

Чтобы ввести как предикторы, так и метки из обоих хранилищ данных в сеть глубокого обучения, объедините их с помощью combine функция.

cdsTrain = combine(tdsTrain,fdsLabelTrain);

Предварительный просмотр объединенного хранилища данных. Хранилище данных возвращает массив ячеек 1 на 2. Первый элемент соответствует предикторам. Второй элемент соответствует метке.

preview(cdsTrain)
ans = 1×2 cell array
    {12×20 double}    {[1]}

Определение сетевой архитектуры LSTM

Определите архитектуру сети LSTM. Укажите количество элементов входных данных в качестве размера входных данных. Укажите уровень LSTM со 100 скрытыми единицами измерения и для вывода последнего элемента последовательности. Наконец, укажите полностью подключенный уровень с размером выхода, равным числу классов, за которым следуют уровень softmax и уровень классификации.

numFeatures = 12;
numClasses = numel(classNames);
numHiddenUnits = 100;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения. Задайте для решателя значение 'adam' и 'GradientThreshold' на 2. Установите размер мини-партии равным 27 и максимальное количество периодов равным 75. Хранилища данных не поддерживают тасование, поэтому установите 'Shuffle' кому 'never'.

Поскольку мини-пакеты малы с короткими последовательностями, ЦП лучше подходит для обучения. Набор 'ExecutionEnvironment' кому 'cpu'. Обучение на GPU, если доступно, установить 'ExecutionEnvironment' кому 'auto' (значение по умолчанию).

miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',75, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never',...
    'Verbose',0, ...
    'Plots','training-progress');

Обучение сети LSTM с указанными вариантами обучения.

net = trainNetwork(cdsTrain,layers,options);

Тестирование сети

Создайте преобразованное хранилище данных, содержащее задержанные тестовые данные, используя те же шаги, что и для учебных данных.

folderTest = fullfile(outputFolder,"Test");

fdsPredictorTest = fileDatastore(folderTest, ...
    'ReadFcn',@load, ...
    'IncludeSubfolders',true);
tdsTest = transform(fdsPredictorTest,@(data) padSequence(data,sequenceLength));

Составление прогнозов по данным тестирования с использованием обученной сети.

YPred = classify(net,tdsTest,'MiniBatchSize',miniBatchSize);

Рассчитайте точность классификации по данным теста. Чтобы получить метки тестового набора, создайте хранилище данных файла с функцией чтения readLabel и укажите, чтобы включить вложенные папки. Укажите, что выходы являются вертикально сцепляемыми, задав значение 'UniformRead' опция для true.

fdsLabelTest = fileDatastore(folderTest, ...
    'ReadFcn',@(filename) readLabel(filename,classNames), ...
    'IncludeSubfolders',true, ...
    'UniformRead',true);
YTest = readall(fdsLabelTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9351

Функции

readLabel извлекает метку из указанного имени файла по категориям в classNames.

function label = readLabel(filename,classNames)

filepath = fileparts(filename);
[~,label] = fileparts(filepath);

label = categorical(string(label),classNames);

end

padSequence панель функций или усечение последовательности в data.X чтобы иметь заданную длину последовательности и возвращает результат в ячейке 1 на 1.

function sequence = padSequence(data,sequenceLength)

sequence = data.X;
[C,S] = size(sequence);

if S < sequenceLength
    padding = zeros(C,sequenceLength-S);
    sequence = [sequence padding];
else
    sequence = sequence(:,1:sequenceLength);
end

sequence = {sequence};

end

Вопросы совместимости

развернуть все

Не рекомендуется начинать с R2019a

Ссылки

[1] Кудо, М., Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.

[2] Кудо, М., Дж. Тояма и М. Симбо. Набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Представлен в R2018a