Обновите статистику нормализации партии. в пользовательском учебном цикле

В этом примере показано, как обновить сетевое состояние в пользовательском учебном цикле.

Слой нормализации партии. нормирует каждый входной канал через мини-пакет. Чтобы ускорить обучение сверточных нейронных сетей и уменьшать чувствительность к сетевой инициализации, используйте слои нормализации партии. между сверточными слоями и нелинейностью, такой как слои ReLU.

Во время обучения слои нормализации партии. сначала нормируют активации каждого канала путем вычитания мини-среднего значения партии и деления на мини-пакетное стандартное отклонение. Затем слой переключает вход learnable смещением β и масштабирует его learnable масштабным коэффициентом γ.

Когда сетевое обучение заканчивается, слои нормализации партии. вычисляют среднее значение и отклонение по полному набору обучающих данных, и хранит значения в TrainedMean и TrainedVariance свойства. Когда вы используете обучивший сеть, чтобы сделать предсказания на новых изображениях, слои нормализации партии. используют обученное среднее значение и отклонение вместо мини-среднего значения партии и отклонение, чтобы нормировать активации.

Чтобы вычислить статистику набора данных, слои нормализации партии. отслеживают мини-пакетную статистику при помощи постоянно обновляющегося состояния. Если вы реализуете пользовательский учебный цикл, то необходимо обновить сетевое состояние между мини-пакетами.

Загрузите обучающие данные

digitTrain4DArrayData функционируйте загружает изображения рукописных цифр и их меток цифры. Создайте arrayDatastore объект для изображений и углов, и затем использует combine функция, чтобы сделать один datastore, который содержит все обучающие данные. Извлеките имена классов.

[XTrain,YTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);

dsTrain = combine(dsXTrain,dsYTrain);

classNames = categories(YTrain);
numClasses = numel(classNames);

Сеть Define

Задайте сеть и задайте среднее изображение с помощью 'Mean' опция в изображении ввела слой.

layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Mean', mean(XTrain,4))
    convolution2dLayer(5, 20, 'Name', 'conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name', 'relu2')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект от графика слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Просмотрите сетевое состояние. Каждый слой нормализации партии. имеет TrainedMean параметр и TrainedVariance параметр, содержащий среднее значение набора данных и отклонение, соответственно.

dlnet.State
ans=6×3 table
    Layer        Parameter             Value     
    _____    _________________    _______________

    "bn1"    "TrainedMean"        {1×1×20 single}
    "bn1"    "TrainedVariance"    {1×1×20 single}
    "bn2"    "TrainedMean"        {1×1×20 single}
    "bn2"    "TrainedVariance"    {1×1×20 single}
    "bn3"    "TrainedMean"        {1×1×20 single}
    "bn3"    "TrainedVariance"    {1×1×20 single}

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет в качестве входа dlnetwork объект dlnet, и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает градиенты потери относительно настраиваемых параметров в dlnet и соответствующая потеря.

Задайте опции обучения

Обучайтесь в течение пяти эпох с помощью мини-пакетного размера 128. Для оптимизации SGDM задайте скорость обучения 0,01 и импульс 0,9.

numEpochs = 5;
miniBatchSize = 128;

learnRate = 0.01;
momentum = 0.9;

Визуализируйте процесс обучения в графике.

plots = "training-progress";

Обучите модель

Используйте minibatchqueue обработать и управлять мини-пакетами изображений. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера) к одногорячему кодируют метки класса.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат в метки класса.

  • Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Обучите модель с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. В конце каждой итерации отобразите прогресс обучения. Для каждого мини-пакета:

  • Оцените градиенты модели, состояние и потерю с помощью dlfeval и modelGradients функционируйте и обновите сетевое состояние.

  • Обновите сетевые параметры с помощью sgdmupdate функция.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте скоростной параметр для решателя SGDM.

velocity = [];

Обучите сеть.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)
    
    % Loop over mini-batches.
    while hasdata(mbq)
  
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        [dlX,dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
                
        % Update the network parameters using the SGDM optimizer.
        [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестовая модель

Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками и углами. Управляйте набором тестовых данных с помощью minibatchqueue объект с той же установкой как обучающие данные.

[XTest,YTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsYTest = arrayDatastore(YTest);

dsTest = combine(dsXTest,dsYTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Классифицируйте изображения с помощью modelPredictions функция, перечисленная в конце примера. Функция возвращает предсказанные классы и сравнение с истинными значениями.

[classesPredictions,classCorr] = modelPredictions(dlnet,mbqTest,classNames);

Оцените точность классификации.

accuracy = mean(classCorr)
accuracy = 0.9946

Функция градиентов модели

modelGradients функционируйте берет в качестве входа dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает градиенты потери относительно настраиваемых параметров в dlnet, сетевое состояние и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

    [dlYPred,state] = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Функция предсказаний модели

modelPredictions функционируйте берет в качестве входа dlnetwork объект dlnet, minibatchqueue из входных данных mbq, и вычисляет предсказания модели путем итерации всех данных в minibatchqueue. Функция использует onehotdecode функционируйте, чтобы найти предсказанный класс с самым высоким счетом и затем сравнить предсказание с истинным классом. Функция возвращает предсказания и вектор из единиц и нулей, который представляет правильные и неправильные предсказания.

function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes)

    classesPredictions = [];
    classCorr = [];
    
    while hasdata(mbq)
        [dlX,dlY] = next(mbq);
        
        % Make predictions using the model function.
        dlYPred = predict(dlnet,dlX);
        
        % Determine predicted classes.
        YPredBatch = onehotdecode(dlYPred,classes,1);
        classesPredictions = [classesPredictions YPredBatch];
        
        % Compare predicted and true classes
        Y = onehotdecode(dlY,classes,1);
        classCorr = [classCorr YPredBatch == Y];
        
    end

end

Мини-функция предварительной обработки пакета

preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.

  2. Извлеките данные о метке из массивов входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения..

  3. Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.

function [X,Y] = preprocessMiniBatch(XCell,YCell)
    
    % Extract image data from cell and concatenate
    X = cat(4,XCell{:});
    % Extract label data from cell and concatenate
    Y = cat(2,YCell{:});

    % One-hot encode labels
    Y = onehotencode(Y,1);
            
end

Смотрите также

| | | | | | | | |

Похожие темы