Обучите сеть Используя пользовательский учебный цикл

В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с пользовательским расписанием темпа обучения.

Если trainingOptions не предоставляет возможности, в которых вы нуждаетесь (например, пользовательское расписание темпа обучения), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования.

Этот пример обучает сеть, чтобы классифицировать рукописные цифры с основанным на времени расписанием темпа обучения затухания: для каждой итерации решатель использует темп обучения, данный ρt=ρ01+kt, где t является номером итерации, ρ0 начальный темп обучения, и k является затуханием.

Загрузите обучающие данные

Загрузите данные о цифрах.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Сеть Define

Задайте сеть и задайте среднее изображение с помощью 'Mean' опция в изображении ввела слой.

layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Mean', mean(XTrain,4))
    convolution2dLayer(5, 20, 'Name', 'conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv2')
    reluLayer('Name', 'relu2')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv3')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(numClasses, 'Name', 'fc')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоя.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [8×1 nnet.cnn.layer.Layer]
    Connections: [7×2 table]
     Learnables: [8×3 table]
          State: [0×3 table]

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует Y и возвращает градиенты потери относительно learnable параметров в dlnet и соответствующая потеря.

Задайте опции обучения

Задайте опции обучения.

velocity = [];
numEpochs = 20;
miniBatchSize = 128;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
initialLearnRate = 0.01;
momentum = 0.9;
decay = 0.01;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. В конце каждой эпохи отобразите прогресс обучения.

Для каждого мини-пакета:

  • Преобразуйте метки в фиктивные переменные.

  • Преобразуйте данные в dlarray объекты с базовым одним типом и указывают, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Для обучения графического процессора преобразуйте в gpuArray объекты.

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функция.

  • Определите темп обучения для основанного на времени расписания темпа обучения затухания.

  • Обновите сетевые параметры с помощью sgdmupdate функция.

Инициализируйте график процесса обучения.

plots = "training-progress";
if plots == "training-progress"
    figure
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
end

Обучите сеть.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестовая модель

Протестируйте точность классификации модели путем сравнения прогнозов на наборе тестов с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray объект с форматом размерности 'SSCB'. Для прогноза графического процессора также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классифицировать изображения с помощью dlnetwork объект, используйте predict функционируйте и найдите классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9780

Функция градиентов модели

modelGradients функционируйте берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует Y и возвращает градиенты потери относительно learnable параметров в dlnet и соответствующая потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

dlYPred = forward(dlnet,dlX);
dlYPred = softmax(dlYPred);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

end

Смотрите также

| | | | | |

Похожие темы