exponenta event banner

sgdmupdate

Обновление параметров с помощью стохастического градиентного спуска с импульсом (SGDM)

Описание

Обновление сетевых обучаемых параметров в пользовательском учебном цикле с использованием алгоритма стохастического градиентного спуска с импульсом (SGDM).

Примечание

Эта функция применяет алгоритм оптимизации SGDM для обновления параметров сети в пользовательских контурах обучения, в которых используются сети, определенные как dlnetwork объекты или функции модели. При необходимости обучения сети, определенной как Layer массив или как LayerGraph, используйте следующие функции:

пример

[dlnet,vel] = sgdmupdate(dlnet,grad,vel) обновляет обучаемые параметры сети dlnet с использованием алгоритма SGDM. Используйте этот синтаксис в учебном цикле для итеративного обновления сети, определенной как dlnetwork объект.

пример

[params,vel] = sgdmupdate(params,grad,vel) обновляет обучаемые параметры в params с использованием алгоритма SGDM. Используйте этот синтаксис в учебном цикле для итеративного обновления обучаемых параметров сети, определенной с помощью функций.

пример

[___] = sgdmupdate(___learnRate,momentum) также указывает значения, используемые для глобальной скорости обучения и импульса, в дополнение к входным аргументам в предыдущих синтаксисах.

Примеры

свернуть все

Выполнение одного шага обновления SGDM с глобальной скоростью обучения 0.05 и импульс 0.95.

Создайте параметры и градиенты параметров в виде числовых массивов.

params = rand(3,3,4);
grad = ones(3,3,4);

Инициализируйте скорости параметров для первой итерации.

vel = [];

Укажите пользовательские значения для глобальной скорости обучения и импульса.

learnRate = 0.05;
momentum = 0.95;

Обновление обучаемых параметров с помощью sgdmupdate.

[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);

Использовать sgdmupdate для обучения сети с использованием алгоритма SGDM.

Загрузка данных обучения

Загрузите данные обучения по цифрам.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Определение сети

Определите сетевую архитектуру и укажите среднее значение изображения с помощью 'Mean' в слое ввода изображения.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создать dlnetwork объект из графа слоев.

dlnet = dlnetwork(lgraph);

Определение функции градиентов модели

Создание вспомогательной функции modelGradients, перечисленных в конце примера. Функция принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet.

Укажите параметры обучения

Укажите параметры для использования во время обучения.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализация хода обучения на графике.

plots = "training-progress";

Железнодорожная сеть

Обучение модели с помощью пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Обновление параметров сети с помощью sgdmupdate функция. В конце каждой эпохи отобразите ход обучения.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте параметр скорости.

vel = [];

Обучение сети.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to a dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,vel] = sgdmupdate(dlnet,gradients,vel);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестирование сети

Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразование данных в dlarray с форматом размера 'SSCB'. Для прогнозирования GPU также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классификация изображений с помощью dlnetwork объект, используйте predict и найти классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9916

Функция градиентов модели

modelGradients функция помощника принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet. Для автоматического вычисления градиентов используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,Y);
    
    gradients = dlgradient(loss,dlnet.Learnables);

end

Входные аргументы

свернуть все

Сеть, указанная как dlnetwork объект.

Функция обновляет dlnet.Learnables имущества dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:

  • Layer - имя слоя, указанное как строковый скаляр.

  • Parameter - имя параметра, указанное как строковый скаляр.

  • Value - значение параметра, указанное как массив ячеек, содержащий dlarray.

Входной аргумент grad должна быть таблица той же формы, что и dlnet.Learnables.

Сетевые обучаемые параметры, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.

При указании params в качестве таблицы она должна содержать следующие три переменные.

  • Layer - имя слоя, указанное как строковый скаляр.

  • Parameter - имя параметра, указанное как строковый скаляр.

  • Value - значение параметра, указанное как массив ячеек, содержащий dlarray.

Можно указать params в качестве контейнера обучаемых параметров для сети с использованием массива ячеек, структуры или таблицы или вложенных массивов или структур ячеек. Доступные для изучения параметры в массиве, структуре или таблице ячейки должны быть dlarray или числовые значения типа данных double или single.

Входной аргумент grad должны иметь точно такой же тип данных, порядок и поля (для структур) или переменные (для таблиц), как params.

Типы данных: single | double | struct | table | cell

Градиенты потерь, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.

Точная форма grad зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для grad для возможных входных данных sgdmupdate.

ВходОбучаемые параметрыГрадиенты
dlnetСтол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого обучаемого параметра.
paramsdlarraydlarray с тем же типом данных и порядком, что и params
Числовой массивЧисловой массив с тем же типом данных и порядком, что и params
Массив ячеекМассив ячеек с теми же типами данных, структурой и порядком, что и params
СтруктураСтруктура с теми же типами данных, полями и порядком, что и params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray.Таблица с теми же типами данных, переменными и порядком, что и params. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого обучаемого параметра.

Вы можете получить grad от вызова до dlfeval вычисляет функцию, содержащую вызов dlgradient. Дополнительные сведения см. в разделе Использование автоматической дифференциации в инструменте глубокого обучения.

Скорости параметров, указанные как пустой массив, a dlarray, числовой массив, массив ячеек, структуру или таблицу.

Точная форма vel зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для vel для возможных входных данных sgdmpdate.

ВходОбучаемые параметрыСкорости
dlnetСтол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. vel должен иметь Value переменная, состоящая из массивов ячеек, которые содержат скорость каждого обучаемого параметра.
paramsdlarraydlarray с тем же типом данных и порядком, что и params
Числовой массивЧисловой массив с тем же типом данных и порядком, что и params
Массив ячеекМассив ячеек с теми же типами данных, структурой и порядком, что и params
СтруктураСтруктура с теми же типами данных, полями и порядком, что и params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray.Таблица с теми же типами данных, переменными и порядком, что и params. vel должен иметь Value переменная, состоящая из массивов ячеек, которые содержат скорость каждого обучаемого параметра.

При указании vel как пустой массив, функция не предполагает предыдущих скоростей и выполняется так же, как для первого обновления в серии итераций. Для итеративного обновления обучаемых параметров используйте vel вывод предыдущего вызова sgdmupdate в качестве vel вход.

Скорость обучения, заданная как положительный скаляр. Значение по умолчанию learnRate является 0.01.

Если параметры сети указаны как dlnetwork объект, скорость обучения для каждого параметра является глобальной скоростью обучения, умноженной на соответствующее свойство коэффициента скорости обучения, определенное в сетевых уровнях.

Импульс, заданный как положительный скаляр между 0 и 1. Значение по умолчанию momentum является 0.9.

Выходные аргументы

свернуть все

Сеть, возвращенная как dlnetwork объект.

Функция обновляет dlnet.Learnables имущества dlnetwork объект.

Обновленные сетевые обучаемые параметры, возвращенные в виде dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные обучаемые параметры сети.

Обновленные скорости параметров, возвращаемые в виде dlarray, числовой массив, массив ячеек, структуру или таблицу.

Подробнее

свернуть все

Стохастический градиентный спуск с импульсом

Функция использует алгоритм стохастического градиентного спуска с импульсом для обновления обучаемых параметров. Для получения дополнительной информации см. определение стохастического градиентного спуска с алгоритмом импульса в разделе Стохастический градиентный спуск на trainingOptions справочная страница.

Расширенные возможности

Представлен в R2019b