exponenta event banner

Обновление статистики нормализации пакетов в индивидуальном цикле обучения

В этом примере показано, как обновить состояние сети в индивидуальном цикле обучения.

Уровень нормализации пакета нормализует каждый входной канал в мини-пакете. Чтобы ускорить обучение сверточных нейронных сетей и снизить чувствительность к инициализации сети, используйте пакетные уровни нормализации между сверточными слоями и нелинейностями, такими как уровни ReLU.

Во время обучения слои нормализации партии сначала нормализуют активации каждого канала путем вычитания среднего значения мини-партии и деления на стандартное отклонение мини-партии. Затем слой сдвигает входной сигнал на обучаемое смещение β и масштабирует его на обучаемый масштабный коэффициент γ.

После завершения обучения по сети уровни пакетной нормализации вычисляют среднее значение и отклонение по полному набору обучения и сохраняют значения в TrainedMean и TrainedVariance свойства. При использовании обученной сети для прогнозирования на новых изображениях слои нормализации партий используют среднее и дисперсию, а не среднее и дисперсию мини-партии для нормализации активаций.

Чтобы вычислить статистику набора данных, слои нормализации пакета отслеживают статистику мини-пакета, используя постоянно обновляемое состояние. При реализации пользовательского цикла обучения необходимо обновить состояние сети между мини-пакетами.

Загрузка данных обучения

digitTrain4DArrayData функция загружает изображения рукописных цифр и их цифровые метки. Создание arrayDatastore объект для изображений и углов, а затем используйте combine для создания единого хранилища данных, содержащего все данные обучения. Извлеките имена классов.

[XTrain,YTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);

dsTrain = combine(dsXTrain,dsYTrain);

classNames = categories(YTrain);
numClasses = numel(classNames);

Определение сети

Определите сеть и укажите среднее изображение с помощью 'Mean' в слое ввода изображения.

layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Mean', mean(XTrain,4))
    convolution2dLayer(5, 20, 'Name', 'conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name', 'relu2')
    convolution2dLayer(3, 20, 'Padding', 1, 'Name', 'conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создать dlnetwork объект из графа слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Просмотр состояния сети. Каждый уровень нормализации партий имеет TrainedMean параметр и TrainedVariance параметр, содержащий среднее значение набора данных и дисперсию соответственно.

dlnet.State
ans=6×3 table
    Layer        Parameter             Value     
    _____    _________________    _______________

    "bn1"    "TrainedMean"        {1×1×20 single}
    "bn1"    "TrainedVariance"    {1×1×20 single}
    "bn2"    "TrainedMean"        {1×1×20 single}
    "bn2"    "TrainedVariance"    {1×1×20 single}
    "bn3"    "TrainedMean"        {1×1×20 single}
    "bn3"    "TrainedVariance"    {1×1×20 single}

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в конце примера, который принимает в качестве входных данных dlnetwork объект dlnetи мини-пакет входных данных dlX с соответствующими метками Yи возвращает градиенты потерь относительно обучаемых параметров в dlnet и соответствующие потери.

Укажите параметры обучения

Поезд на пять эпох с использованием мини-партии размером 128. Для оптимизации SGDM укажите скорость обучения 0,01 и импульс 0,9.

numEpochs = 5;
miniBatchSize = 128;

learnRate = 0.01;
momentum = 0.9;

Визуализация хода обучения на графике.

plots = "training-progress";

Модель поезда

Использовать minibatchqueue для обработки и управления мини-партиями изображений. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в конце этого примера) для одноконтактного кодирования меток класса.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам класса.

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Обучение модели с помощью пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. В конце каждой итерации просмотрите ход обучения. Для каждой мини-партии:

  • Оценка градиентов, состояния и потерь модели с помощью dlfeval и modelGradients и обновить состояние сети.

  • Обновление параметров сети с помощью sgdmupdate функция.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Обучение сети.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)
    
    % Loop over mini-batches.
    while hasdata(mbq)
  
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        [dlX,dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
                
        % Update the network parameters using the SGDM optimizer.
        [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестовая модель

Проверка точности классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками и углами. Управление набором тестовых данных с помощью minibatchqueue объект с той же настройкой, что и данные обучения.

[XTest,YTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsYTest = arrayDatastore(YTest);

dsTest = combine(dsXTest,dsYTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Классифицируйте изображения с помощью modelPredictions функция, перечисленная в конце примера. Функция возвращает предсказанные классы и сравнение с истинными значениями.

[classesPredictions,classCorr] = modelPredictions(dlnet,mbqTest,classNames);

Оцените точность классификации.

accuracy = mean(classCorr)
accuracy = 0.9946

Функция градиентов модели

modelGradients функция принимает в качестве входного значения a dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

    [dlYPred,state] = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Функция прогнозирования модели

modelPredictions функция принимает в качестве входного значения a dlnetwork объект dlnet, a minibatchqueue входных данных mbqи вычисляет прогнозы модели путем итерации всех данных в minibatchqueue. Функция использует onehotdecode функция, чтобы найти прогнозируемый класс с наивысшим баллом, а затем сравнивает прогноз с истинным классом. Функция возвращает предсказания и вектор единиц и нулей, который представляет правильные и неправильные предсказания.

function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes)

    classesPredictions = [];
    classCorr = [];
    
    while hasdata(mbq)
        [dlX,dlY] = next(mbq);
        
        % Make predictions using the model function.
        dlYPred = predict(dlnet,dlX);
        
        % Determine predicted classes.
        YPredBatch = onehotdecode(dlYPred,classes,1);
        classesPredictions = [classesPredictions YPredBatch];
        
        % Compare predicted and true classes
        Y = onehotdecode(dlY,classes,1);
        classCorr = [classCorr YPredBatch == Y];
        
    end

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в числовой массив. Конкатенация данных изображения над четвертым размером добавляет к каждому изображению третий размер, который будет использоваться в качестве размера одиночного канала.

  2. Извлеките данные метки из входящих массивов ячеек и объедините их в категориальный массив вдоль второго измерения.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,Y] = preprocessMiniBatch(XCell,YCell)
    
    % Extract image data from cell and concatenate
    X = cat(4,XCell{:});
    % Extract label data from cell and concatenate
    Y = cat(2,YCell{:});

    % One-hot encode labels
    Y = onehotencode(Y,1);
            
end

См. также

| | | | | | | | |

Связанные темы