exponenta event banner

Сеть поездов, использующая индивидуальный цикл обучения

В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с помощью пользовательского графика обучения.

Если trainingOptions не предоставляет необходимых опций (например, пользовательский график обучения), то можно определить собственный пользовательский цикл обучения с помощью автоматического дифференцирования.

Этот пример обучает сеть классифицировать рукописные цифры по расписанию скорости обучения затуханию, основанному на времени: для каждой итерации решатель использует скорость обучения, заданную, в частности, в отношении, где t - число итераций, α0 - начальная скорость обучения, а k - затухание.

Загрузка данных обучения

Загрузите данные цифр как хранилище данных изображения с помощью imageDatastore и укажите папку, содержащую данные изображения.

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ....
    'LabelSource','foldernames');

Разбейте данные на обучающие и проверочные наборы. Отложите 10% данных для проверки с помощью splitEachLabel функция.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');

Сеть, используемая в этом примере, требует входных изображений размера 28 на 28 на 1. Для автоматического изменения размеров учебных изображений используйте хранилище данных дополненного изображения. Укажите дополнительные операции увеличения, выполняемые на обучающих изображениях: случайное перемещение изображений до 5 пикселей в горизонтальной и вертикальной осях. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.

inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Определите количество занятий в данных обучения.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Определение сети

Определите сеть для классификации изображений.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создать dlnetwork объект из графа слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в конце примера, который принимает dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.

Укажите параметры обучения

Поезд на десять эпох с размером мини-партии 128.

numEpochs = 10;
miniBatchSize = 128;

Укажите параметры оптимизации SGDM. Укажите начальную скорость обучения 0,01 с распадом 0,01, а импульс 0,9.

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Модель поезда

Создать minibatchqueue объект, обрабатывающий и управляющий мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в конце этого примера) для преобразования меток в однокодированные переменные.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам класса.

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

mbq = minibatchqueue(augimdsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график хода обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Обучение сети с использованием пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:

  • Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients и обновить состояние сети.

  • Определите скорость обучения для графика скорости обучения с учетом времени.

  • Обновление параметров сети с помощью sgdmupdate функция.

  • Просмотрите ход обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестовая модель

Проверьте точность классификации модели путем сравнения прогнозов на наборе проверки с истинными метками.

После обучения составление прогнозов по новым данным не требует меток. Создать minibatchqueue объект, содержащий только предикторы тестовых данных:

  • Чтобы игнорировать метки для тестирования, установите число выходов мини-пакетной очереди равным 1.

  • Укажите размер мини-партии, используемый для обучения.

  • Предварительная обработка предикторов с помощью preprocessMiniBatchPredictors функция, перечисленная в конце примера.

  • Для одиночного вывода хранилища данных укажите формат мини-пакета. 'SSCB' (пространственный, пространственный, канальный, пакетный).

numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Закольцовывать мини-пакеты и классифицировать изображения с помощью modelPredictions функция, перечисленная в конце примера.

predictions = modelPredictions(dlnet,mbqTest,classes);

Оцените точность классификации.

YTest = imdsValidation.Labels;
accuracy = mean(predictions == YTest)
accuracy = 0.9530

Функция градиентов модели

modelGradients функция принимает dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствующими метками Y и возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция прогнозирования модели

modelPredictions функция принимает dlnetwork объект dlnet, a minibatchqueue входных данных mbqи сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска прогнозируемого класса с наивысшим баллом.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток, используя следующие шаги:

  1. Предварительная обработка изображений с помощью preprocessMiniBatchPredictors функция.

  2. Извлеките данные метки из входящего массива ячеек и объедините их в категориальный массив вдоль второго измерения.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

Функция предварительной обработки мини-пакетных предикторов

preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входных ячеек и конкатенации в числовой массив. Для ввода в оттенках серого при конкатенации над четвертым размером к каждому изображению добавляется третий размер для использования в качестве размера одиночного канала.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

См. также

| | | | | | | | |

Связанные темы