rmspropupdate

Обновите параметры с помощью корневого среднеквадратического распространения (RMSProp)

Описание

Обновите сетевые learnable параметры в пользовательском учебном цикле с помощью корневого среднеквадратического распространения (RMSProp) алгоритм.

Примечание

Эта функция применяет алгоритм оптимизации RMSProp, чтобы обновить сетевые параметры в пользовательских учебных циклах, которые используют сети, заданные в качестве dlnetwork объекты или функции модели. Если вы хотите обучить сеть, заданную как Layer массив или как LayerGraph, используйте следующие функции:

пример

[dlnet,averageSqGrad] = rmspropupdate(dlnet,grad,averageSqGrad) обновляет learnable параметры сети dlnet использование алгоритма RMSProp. Используйте этот синтаксис в учебном цикле, чтобы итеративно обновить сеть, заданную как dlnetwork объект.

пример

[params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad) обновляет learnable параметры в params использование алгоритма RMSProp. Используйте этот синтаксис в учебном цикле, чтобы итеративно обновить learnable параметры сети, заданной с помощью функций.

пример

[___] = rmspropupdate(___learnRate,sqGradDecay,epsilon) также задает значения, чтобы использовать в глобальном темпе обучения, квадратном затухании градиента и маленьком постоянном эпсилоне, в дополнение к входным параметрам в предыдущих синтаксисах.

Примеры

свернуть все

Выполните один корневой среднеквадратический шаг обновления распространения с глобальным темпом обучения 0.05 и градиент в квадрате затухает фактор 0.95.

Создайте параметры и градиенты параметра как числовые массивы.

params = rand(3,3,4);
grad = ones(3,3,4);

Инициализируйте средний градиент в квадрате для первой итерации.

averageSqGrad = [];

Задайте пользовательские значения для глобального темпа обучения и фактора затухания градиента в квадрате.

learnRate = 0.05;
sqGradDecay = 0.95;

Обновите learnable параметры с помощью rmspropupdate.

[params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad,learnRate,sqGradDecay);

Используйте rmspropupdate обучать сеть с помощью среднеквадратичного распространения (RMSProp) алгоритм.

Загрузите обучающие данные

Загрузите обучающие данные цифр.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Задайте сеть

Задайте сетевую архитектуру и задайте среднее изображение с помощью 'Mean' опция в изображении ввела слой.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоя.

dlnet = dlnetwork(lgraph);

Задайте функцию градиентов модели

Создайте функциональный modelGradients, перечисленный в конце примера, который берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует Y и возвращает потерю и градиенты потери относительно learnable параметров в dlnet.

Задайте опции обучения

Задайте опции, чтобы использовать во время обучения.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

executionEnvironment = "auto";

Инициализируйте средние градиенты в квадрате.

averageSqGrad = [];

Инициализируйте график процесса обучения.

plots = "training-progress";
if plots == "training-progress"
    iteration = 1;
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end

Обучите сеть

Обучите модель с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Обновите сетевые параметры с помощью rmspropupdate функция. В конце каждой эпохи отобразите прогресс обучения.

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the RMSProp optimizer.
        [dlnet,averageSqGrad] = rmspropupdate(dlnet,grad,averageSqGrad);
        
        % Display the training progress.
        if plots == "training-progress"
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Loss During Training: Epoch - " + epoch + "; Iteration - " + i)
            drawnow
            iteration = iteration + 1;
        end
    end
end

Протестируйте сеть

Протестируйте точность классификации модели путем сравнения прогнозов на наборе тестов с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray объект с форматом размерности 'SSCB'. Для прогноза графического процессора также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классифицировать изображения с помощью dlnetwork объект, используйте predict функционируйте и найдите классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9752

Функция градиентов модели

modelGradients функционируйте берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует Y и возвращает потерю и градиенты потери относительно learnable параметров в dlnet. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end

Входные параметры

свернуть все

Сеть, заданная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект. dlnet.Learnables таблица с тремя переменными:

  • Layer — Имя слоя, заданное как скаляр строки.

  • Parameter — Название параметра, заданное как скаляр строки.

  • Value — Значение параметра, заданного как массив ячеек, содержащий dlarray.

Входной параметр grad должна быть таблица той же формы как dlnet.Learnables.

Сетевые learnable параметры, заданные как dlarray, числовой массив, массив ячеек, структура или таблица.

Если вы задаете params как таблица, это должно содержать следующие три переменные.

  • Layer — Имя слоя, заданное как скаляр строки.

  • Parameter — Название параметра, заданное как скаляр строки.

  • Value — Значение параметра, заданного как массив ячеек, содержащий dlarray.

Можно задать params как контейнер learnable параметров для вашей сети с помощью массива ячеек, структуры, или таблицы, или вложенных массивов ячеек или структур. learnable параметрами в массиве ячеек, структуре или таблице должен быть dlarray или числовые значения типа данных double или single.

Входной параметр grad должен быть обеспечен точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменные (для таблиц) как params.

Типы данных: single | double | struct | table | cell

Градиенты потери, заданной как dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма grad зависит от входа сетевые или learnable параметры. Следующая таблица показывает требуемый формат для grad для возможных входных параметров к rmspropupdate.

Входной параметрПараметры LearnableГрадиенты
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый learnable параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого learnable параметра.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый learnable параметр как dlarray.Таблица с совпадающими типами данных, переменными, и заказывающий как params. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого learnable параметра.

Можно получить grad от вызова до dlfeval это выполняет функцию, которая содержит вызов dlgradient. Для получения дополнительной информации смотрите Использование Автоматическое Дифференцирование В Deep Learning Toolbox.

Скользящее среднее значение градиентов параметра в квадрате, заданных как пустой массив, dlarray, числовой массив, массив ячеек, структура или таблица.

Точная форма averageSqGrad зависит от входа сетевые или learnable параметры. Следующая таблица показывает требуемый формат для averageSqGrad для возможных входных параметров к rmspropupdate.

Входной параметрПараметры LearnableСредние градиенты в квадрате
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый learnable параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний градиент в квадрате каждого learnable параметра.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый learnable параметр как dlarray.Таблица с совпадающими типами данных, переменными, и заказывающий как params. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат средний градиент в квадрате каждого learnable параметра.

Если вы задаете averageSqGrad как пустой массив, функция не принимает предыдущих градиентов и запусков таким же образом что касается первого обновления в серии итераций. Чтобы обновить learnable параметры итеративно, используйте averageSqGrad выход предыдущего вызова rmspropupdate как averageSqGrad входной параметр.

Глобальный темп обучения, заданный как положительная скалярная величина. Значение по умолчанию learnRate 0.001.

Если вы задаете сетевые параметры как dlnetwork, темп обучения для каждого параметра является глобальным темпом обучения, умноженным на соответствующее свойство фактора темпа обучения, заданное в сетевых слоях.

Фактор затухания градиента в квадрате, заданный как положительная скалярная величина между 0 и 1. Значение по умолчанию sqGradDecay 0.999.

Маленькая константа для предотвращения делит на нуль ошибки, заданные как положительная скалярная величина. Значение по умолчанию epsilon 1e-8.

Выходные аргументы

свернуть все

Сеть, возвращенная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект.

Обновленные сетевые learnable параметры, возвращенные как dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные learnable параметры сети.

Обновленное скользящее среднее значение градиентов параметра в квадрате, возвращенных как dlarray, числовой массив, массив ячеек, структура или таблица.

Больше о

свернуть все

RMSProp

Функция использует корневой среднеквадратический алгоритм распространения, чтобы обновить learnable параметры. Для получения дополнительной информации см. определение алгоритма RMSProp под Стохастическим Градиентным спуском на trainingOptions страница с описанием.

Введенный в R2019b