adamupdate

Обновите параметры с помощью адаптивной оценки момента (Адам)

Описание

Обновите сетевые настраиваемые параметры в пользовательском цикле обучения с помощью алгоритма адаптивной оценки момента (Adam).

Примечание

Эта функция применяет алгоритм оптимизации Адама, чтобы обновить сетевые параметры в пользовательских циклах обучения, которые используют сети, определенные как dlnetwork объекты или функции модели. Если вы хотите обучить сеть, определенную как Layer массив или как LayerGraph, используйте следующие функции:

пример

[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration) обновляет настраиваемые параметры сети dlnet использование алгоритма Адама. Используйте этот синтаксис в цикле обучения, чтобы итерационно обновить сеть, заданную как dlnetwork объект.

пример

[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration) обновляет настраиваемые параметры в params использование алгоритма Адама. Используйте этот синтаксис в цикле обучения, чтобы итерационно обновить настраиваемые параметры сети, заданной с помощью функций.

пример

[___] = adamupdate(___learnRate,gradDecay,sqGradDecay,epsilon) также задает значения для глобальной скорости обучения, градиентного распада, квадратного градиента и малого постоянного эпсилона, в дополнение к входным параметрам в предыдущих синтаксисах.

Примеры

свернуть все

Выполните один шаг обновления адаптивной оценки момента с глобальной скоростью обучения 0.05, коэффициент градиентного распада 0.75, и квадратный коэффициент градиентного распада 0.95.

Создайте параметры и градиенты параметров как числовые массивы.

params = rand(3,3,4);
grad = ones(3,3,4);

Инициализируйте счетчик итерации, средний градиент и средний квадратный градиент для первой итерации.

iteration = 1;
averageGrad = [];
averageSqGrad = [];

Задайте пользовательские значения для глобальной скорости обучения, фактора градиентного распада и квадратного фактора градиентного распада.

learnRate = 0.05;
gradDecay = 0.75;
sqGradDecay = 0.95;

Обновляйте настраиваемые параметры с помощью adamupdate.

[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);

Обновите счетчик итерации.

iteration = iteration + 1;

Использование adamupdate обучить сеть с помощью алгоритма Адама.

Загрузка обучающих данных

Загрузите обучающие данные цифр.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Определение сети

Определите сеть и задайте среднее значение изображения с помощью 'Mean' опция на входном слое изображения.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоев.

dlnet = dlnetwork(lgraph);

Задайте функцию градиентов модели

Создайте вспомогательную функцию modelGradients, перечисленный в конце примера. Функция принимает dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками Y, и возвращает потери и градиенты потерь относительно настраиваемых параметров в dlnet.

Настройка опций обучения

Задайте опции, которые будут использоваться во время обучения.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализируйте процесс обучения на графике.

plots = "training-progress";

Обучите сеть

Обучите модель с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Обновляйте параметры сети с помощью adamupdate функция. В конце каждой эпохи отобразите процесс обучения.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте средние градиенты и квадратные средние градиенты.

averageGrad = [];
averageSqGrad = [];

Обучите сеть.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to a dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the Adam optimizer.
        [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестирование сети

Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray с форматом размерности 'SSCB'. Для предсказания графический процессор также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Чтобы классифицировать изображения с помощью dlnetwork объект, используйте predict функция и найти классы с самыми высокими счетами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9896

Функция градиентов модели

The modelGradients Функция helper принимает dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,Y);
    
    gradients = dlgradient(loss,dlnet.Learnables);
    
end

Входные параметры

свернуть все

Сеть, заданная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:

  • Layer - Имя слоя, заданное как строковый скаляр.

  • Parameter - Имя параметра, заданное как строковый скаляр.

  • Value - Значение параметра, заданное как массив ячеек, содержащий dlarray.

Входной параметр grad должна быть таблицей той же формы, что и dlnet.Learnables.

Сетевые настраиваемые параметры, заданная как dlarray, числовой массив, массив ячеек, структура или таблица.

Если вы задаете params в качестве таблицы она должна содержать следующие три переменные:

  • Layer - Имя слоя, заданное как строковый скаляр.

  • Parameter - Имя параметра, заданное как строковый скаляр.

  • Value - Значение параметра, заданное как массив ячеек, содержащий dlarray.

Можно задать params как контейнер настраиваемых параметров для сети с помощью массива ячеек, структуры или таблицы или вложенных массивов ячеек или структур. Настраиваемые параметры в массиве ячеек, структуре или таблице должны быть dlarray или числовые значения типа данных double или single.

Входной параметр grad должны быть снабжены точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменными (для таблиц), как params.

Типы данных: single | double | struct | table | cell

Градиенты потерь, заданные как dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма grad зависит от входа сети или настраиваемых параметров. В следующей таблице показан необходимый формат для grad для возможных входов в adamupdate.

ВходНастраиваемые параметрыГрадиенты
dlnetТабличные dlnet.Learnables содержащие Layer, Parameter, и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray. Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого настраиваемого параметра.
paramsdlarraydlarray с совпадающим типом данных и порядком, что и params
Числовой массивЧисловой массив с совпадающим типом данных и порядком, что и params
Массив ячеекМассив ячеек с совпадающими типами данных, структурой и порядком, как params
СтруктураСтруктура с совпадающими типами данных, полями и порядками, что и params
Таблица с Layer, Parameter, и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray.Таблица с совпадающими типами данных, переменными и порядками, что и params. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого настраиваемого параметра.

Можно получить grad из вызова в dlfeval которая оценивает функцию, содержащую вызов на dlgradient. Для получения дополнительной информации смотрите Использование автоматической дифференциации в Deep Learning Toolbox.

Скользящее среднее значение градиентов параметра, заданное как пустой массив, a dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма averageGrad зависит от входа сети или настраиваемых параметров. В следующей таблице показан необходимый формат для averageGrad для возможных входов в adamupdate.

ВходНастраиваемые параметрыСредние градиенты
dlnetТабличные dlnet.Learnables содержащие Layer, Parameter, и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray. Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables. averageGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний градиент каждого настраиваемого параметра.
paramsdlarraydlarray с совпадающим типом данных и порядком, что и params
Числовой массивЧисловой массив с совпадающим типом данных и порядком, что и params
Массив ячеекМассив ячеек с совпадающими типами данных, структурой и порядком, как params
СтруктураСтруктура с совпадающими типами данных, полями и порядками, что и params
Таблица с Layer, Parameter, и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray.Таблица с совпадающими типами данных, переменными и порядками, что и params. averageGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний градиент каждого настраиваемого параметра.

Если вы задаете averageGrad и averageSqGrad как пустые массивы, функция не принимает предыдущих градиентов и запускается так же, как и для первого обновления в серии итераций. Чтобы итерационно обновить настраиваемые параметры, используйте averageGrad выход предыдущего вызова на adamupdate как averageGrad вход.

Скользящее среднее значение градиентов квадратов параметра, заданное как пустой массив, a dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма averageSqGrad зависит от входа сети или настраиваемых параметров. В следующей таблице показан необходимый формат для averageSqGrad для возможных входов в adamupdate.

ВходНастраиваемые параметрыСредние квадратные градиенты
dlnetТабличные dlnet.Learnables содержащие Layer, Parameter, и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray. Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний квадратный градиент каждого настраиваемого параметра.
paramsdlarraydlarray с совпадающим типом данных и порядком, что и params
Числовой массивЧисловой массив с совпадающим типом данных и порядком, что и params
Массив ячеекМассив ячеек с совпадающими типами данных, структурой и порядком, как params
СтруктураСтруктура с совпадающими типами данных, полями и порядками, что и params
Таблица с Layer, Parameter, и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray.Таблица с совпадающими типами данных, переменными и порядком, что и params. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний квадратный градиент каждого настраиваемого параметра.

Если вы задаете averageGrad и averageSqGrad как пустые массивы, функция не принимает предыдущих градиентов и запускается так же, как и для первого обновления в серии итераций. Чтобы итерационно обновить настраиваемые параметры, используйте averageSqGrad выход предыдущего вызова на adamupdate как averageSqGrad вход.

Число итерации, заданное как положительное целое число. Для первого вызова в adamupdate, используйте значение 1. Необходимо увеличить iteration по 1 для каждого последующего вызова в серии вызовов, чтобы adamupdate. Алгоритм Адама использует это значение для исправления смещения в скользящих средних значениях в начале набора итераций.

Глобальная скорость обучения, заданная как положительная скалярная величина. Значение по умолчанию learnRate является 0.001.

Если вы задаете сетевые параметры как dlnetwork, скорость обучения для каждого параметра является глобальной скоростью обучения, умноженной на соответствующее свойство коэффициента скорости обучения, заданное в слоях сети.

Коэффициент распада градиента, заданный как положительная скалярная величина между 0 и 1. Значение по умолчанию gradDecay является 0.9.

Квадратный коэффициент распада градиента, заданный как положительная скалярная величина между 0 и 1. Значение по умолчанию sqGradDecay является 0.999.

Малая константа для предотвращения ошибок деления на нули, заданная как положительная скалярная величина. Значение по умолчанию epsilon является 1e-8.

Выходные аргументы

свернуть все

Сеть, возвращается как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект.

Обновлённые сетевые настраиваемые параметры, возвращенный как dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные настраиваемые параметры сети.

Обновленное скользящее среднее значение градиентов параметра, возвращенное как dlarray, числовой массив, массив ячеек, структура или таблица.

Обновлённое среднее значение квадратов градиентов параметра, возвращенная как dlarray, числовой массив, массив ячеек, структура или таблица.

Подробнее о

свернуть все

Адам

Функция использует алгоритм адаптивной оценки момента (Адам), чтобы обновить настраиваемые параметры. Для получения дополнительной информации смотрите определение алгоритма Адама под Стохастическим градиентным спуском на trainingOptions страница с описанием.

Расширенные возможности

Введенный в R2019b
Для просмотра документации необходимо авторизоваться на сайте