sgdmupdate

Обновите параметры с помощью стохастического градиентного спуска с импульсом (SGDM)

Описание

Обновите сетевые настраиваемые параметры в пользовательском учебном цикле с помощью стохастического градиентного спуска с импульсом (SGDM) алгоритм.

Примечание

Эта функция применяет алгоритм оптимизации SGDM, чтобы обновить сетевые параметры в пользовательских учебных циклах, которые используют сети, заданные в качестве dlnetwork объекты или функции модели. Если вы хотите обучить сеть, заданную как Layer массив или как LayerGraph, используйте следующие функции:

пример

[dlnet,vel] = sgdmupdate(dlnet,grad,vel) обновляет настраиваемые параметры сети dlnet использование алгоритма SGDM. Используйте этот синтаксис в учебном цикле, чтобы итеративно обновить сеть, заданную как dlnetwork объект.

пример

[params,vel] = sgdmupdate(params,grad,vel) обновляет настраиваемые параметры в params использование алгоритма SGDM. Используйте этот синтаксис в учебном цикле, чтобы итеративно обновить настраиваемые параметры сети, заданной с помощью функций.

пример

[___] = sgdmupdate(___learnRate,momentum) также задает значения, чтобы использовать для глобальной скорости обучения и импульса, в дополнение к входным параметрам в предыдущих синтаксисах.

Примеры

свернуть все

Выполните один шаг обновления SGDM с глобальной скоростью обучения 0.05 и импульс 0.95.

Создайте параметры и градиенты параметра как числовые массивы.

params = rand(3,3,4);
grad = ones(3,3,4);

Инициализируйте скорости параметра для первой итерации.

vel = [];

Задайте пользовательские значения для глобальной скорости обучения и импульса.

learnRate = 0.05;
momentum = 0.95;

Обновите настраиваемые параметры с помощью sgdmupdate.

[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);

Используйте sgdmupdate обучать сеть с помощью алгоритма SGDM.

Загрузите обучающие данные

Загрузите обучающие данные цифр.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Сеть Define

Задайте сетевую архитектуру и задайте среднее значение изображений с помощью 'Mean' опция в изображении ввела слой.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект от графика слоев.

dlnet = dlnetwork(lgraph);

Функция градиентов модели Define

Создайте функцию помощника modelGradients, перечисленный в конце примера. Функция берет dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet.

Задайте опции обучения

Задайте опции, чтобы использовать во время обучения.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализируйте процесс обучения в графике.

plots = "training-progress";

Обучение сети

Обучите модель с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Обновите сетевые параметры с помощью sgdmupdate функция. В конце каждой эпохи отобразите прогресс обучения.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте скоростной параметр.

vel = [];

Обучите сеть.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to a dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,vel] = sgdmupdate(dlnet,gradients,vel);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Протестируйте сеть

Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray с форматом размерности 'SSCB'. Для предсказания графического процессора также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классифицировать изображения с помощью dlnetwork объект, используйте predict функционируйте и найдите классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9916

Функция градиентов модели

modelGradients функция помощника берет dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,Y);
    
    gradients = dlgradient(loss,dlnet.Learnables);

end

Входные параметры

свернуть все

Сеть в виде dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект. dlnet.Learnables таблица с тремя переменными:

  • Layer — Имя слоя в виде строкового скаляра.

  • Parameter — Название параметра в виде строкового скаляра.

  • Value — Значение параметра в виде массива ячеек, содержащего dlarray.

Входной параметр grad должна быть таблица той же формы как dlnet.Learnables.

Сетевые настраиваемые параметры в виде dlarray, числовой массив, массив ячеек, структура или таблица.

Если вы задаете params как таблица, это должно содержать следующие три переменные.

  • Layer — Имя слоя в виде строкового скаляра.

  • Parameter — Название параметра в виде строкового скаляра.

  • Value — Значение параметра в виде массива ячеек, содержащего dlarray.

Можно задать params как контейнер настраиваемых параметров для вашей сети с помощью массива ячеек, структуры, или таблицы, или вложенных массивов ячеек или структур. Настраиваемыми параметрами в массиве ячеек, структуре или таблице должен быть dlarray или числовые значения типа данных double или single.

Входной параметр grad должен быть обеспечен точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменные (для таблиц) как params.

Типы данных: single | double | struct | table | cell

Градиенты потери в виде dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма grad зависит от входной сети или настраиваемых параметров. Следующая таблица показывает требуемый формат для grad для возможных входных параметров к sgdmupdate.

Входной параметрНастраиваемые параметрыГрадиенты
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого настраиваемого параметра.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray.Таблица с совпадающими типами данных, переменными, и заказывающий как params. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого настраиваемого параметра.

Можно получить grad от вызова до dlfeval это выполняет функцию, которая содержит вызов dlgradient. Для получения дополнительной информации смотрите Использование Автоматическое Дифференцирование В Deep Learning Toolbox.

Скорости параметра в виде пустого массива, dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма vel зависит от входной сети или настраиваемых параметров. Следующая таблица показывает требуемый формат для vel для возможных входных параметров к sgdmpdate.

Входной параметрНастраиваемые параметрыСкорости
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. vel должен иметь Value переменная, состоящая из массивов ячеек, которые содержат скорость каждого настраиваемого параметра.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray.Таблица с совпадающими типами данных, переменными, и заказывающий как params. vel должен иметь Value переменная, состоящая из массивов ячеек, которые содержат скорость каждого настраиваемого параметра.

Если вы задаете vel как пустой массив, функция не принимает предыдущих скоростей и запусков таким же образом что касается первого обновления в серии итераций. Чтобы обновить настраиваемые параметры итеративно, используйте vel выход предыдущего вызова sgdmupdate как vel входной параметр.

Скорость обучения в виде положительной скалярной величины. Значение по умолчанию learnRate 0.01.

Если вы задаете сетевые параметры как dlnetwork объект, скорость обучения для каждого параметра является глобальной скоростью обучения, умноженной на соответствующее свойство фактора скорости обучения, заданное в слоях сети.

Импульс в виде положительной скалярной величины между 0 и 1. Значение по умолчанию momentum 0.9.

Выходные аргументы

свернуть все

Сеть, возвращенная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект.

Обновленные сетевые настраиваемые параметры, возвращенные как dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные настраиваемые параметры сети.

Обновленные скорости параметра, возвращенные как dlarray, числовой массив, массив ячеек, структура или таблица.

Больше о

свернуть все

Стохастический градиентный спуск с импульсом

Функция использует стохастический градиентный спуск с алгоритмом импульса, чтобы обновить настраиваемые параметры. Для получения дополнительной информации см. определение стохастического градиентного спуска с алгоритмом импульса под Стохастическим Градиентным спуском на trainingOptions страница с описанием.

Расширенные возможности

Введенный в R2019b