Обучите сеть с помощью пользовательского цикла обучения

В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с пользовательским расписанием скорости обучения.

Если trainingOptions не предоставляет необходимые опции (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью автоматической дифференциации.

Этот пример обучает сеть классифицировать рукописные цифры с основанным на времени расписанием скорости обучения с распадом: для каждой итерации решатель использует скорость обучения, заданную как ρt=ρ01+kt, где t - число итерации, ρ0 является начальной скоростью обучения, и k является распадом.

Загрузка обучающих данных

Загрузите данные цифр в виде datastore изображений с помощью imageDatastore и укажите папку, содержащую данные изображения.

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ....
    'LabelSource','foldernames');

Разделите данные на наборы для обучения и валидации. Отложите 10% данных для валидации с помощью splitEachLabel функция.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');

Сеть, используемая в этом примере, требует изображений входа размера 28 на 28 на 1. Чтобы автоматически изменить размер обучающих изображений, используйте дополненный image datastore. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом переведите изображения до 5 пикселей в горизонтальной и вертикальной осях. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Определите количество классов в обучающих данных.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Определение сети

Определите сеть для классификации изображений.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в конце примера, который принимает dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.

Настройка опций обучения

Обучайте на десять эпох с мини-партией размером 128.

numEpochs = 10;
miniBatchSize = 128;

Задайте опции для оптимизации SGDM. Задайте начальную скорость обучения 0,01 с распадом 0,01 и импульсом 0,9.

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Обучите модель

Создайте minibatchqueue объект, который обрабатывает и управляет мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы преобразовать метки в переменные с кодировкой с одним горячим контактом.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам классов.

  • Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue объект преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

mbq = minibatchqueue(augimdsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Обучите сеть с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients функционирует и обновляет состояние сети.

  • Определите скорость обучения для основанного на времени расписания скорости обучения с распадом.

  • Обновляйте параметры сети с помощью sgdmupdate функция.

  • Отображение процесса обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Экспериментальная модель

Протестируйте классификационную точность модели путем сравнения предсказаний на наборе валидации с истинными метками.

После обучения создание предсказаний по новым данным не требует меток. Создание minibatchqueue объект, содержащий только предикторы тестовых данных:

  • Чтобы игнорировать метки для проверки, установите количество выходов мини-очереди пакетов равным 1.

  • Укажите тот же размер мини-пакета, что и для обучения.

  • Предварительно обработайте предикторы, используя preprocessMiniBatchPredictors функции, перечисленной в конце примера.

  • Для одинарного выхода datastore задайте формат пакета 'SSCB' (пространственный, пространственный, канальный, пакетный).

numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Закольцовывайте мини-пакеты и классифицируйте изображения с помощью modelPredictions функции, перечисленной в конце примера.

predictions = modelPredictions(dlnet,mbqTest,classes);

Оцените точность классификации.

YTest = imdsValidation.Labels;
accuracy = mean(predictions == YTest)
accuracy = 0.9530

Функция градиентов модели

The modelGradients функция принимает dlnetwork dlnet объектамини-пакет входных данных dlX с соответствующими метками Y и возвращает градиенты потерь относительно настраиваемых параметров в dlnet, состояние сети и потери. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция предсказаний модели

The modelPredictions функция принимает dlnetwork dlnet объекта, а minibatchqueue входных данных mbq, и сетевых классов, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска предсказанного класса с самым высоким счетом.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток с помощью следующих шагов:

  1. Предварительно обработайте изображения с помощью preprocessMiniBatchPredictors функция.

  2. Извлеките данные метки из входящего массива ячеек и сгруппируйте в категориальный массив по второму измерению.

  3. Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

Функция предварительной обработки мини-пакетных предикторов

The preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входа ячеек и конкатенации в числовой массив. Для входов полутонового цвета, конкатенация по четвертому измерению добавляет третье измерение каждому изображению, чтобы использовать в качестве размерности синглтонного канала.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

См. также

| | | | | | | | |

Похожие темы