dlupdate

Обновите параметры с помощью пользовательской функции

Описание

пример

dlnet = dlupdate(fun,dlnet) обновляет настраиваемые параметры dlnetwork объект dlnet путем оценки функционального fun с каждым настраиваемым параметром как вход. fun указатель на функцию к функции, которая берет один массив параметров в качестве входного параметра и возвращает обновленный массив параметров.

params = dlupdate(fun,params) обновляет настраиваемые параметры в params путем оценки функционального fun с каждым настраиваемым параметром как вход.

[___] = dlupdate(fun,___A1,...,An) также задает дополнительные входные параметры, в дополнение к входным параметрам в предыдущих синтаксисах, когда fun указатель на функцию к функции, которая требует n+1 входные значения.

[___,X1,...,Xm] = dlupdate(fun,___) возвращает несколько выходных параметров X1,...,Xm когда fun указатель на функцию к функции, которая возвращает m+1 выходные значения.

Примеры

свернуть все

Выполните регуляризацию L1 на структуре градиентов параметра.

Создайте демонстрационные входные данные.

dlX = dlarray(rand(100,100,3),'SSC');

Инициализируйте настраиваемые параметры для операции свертки.

params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

Вычислите градиенты для операции свертки с помощью функции помощника convGradients, заданный в конце этого примера.

gradients = dlfeval(@convGradients,dlX,params);

Задайте фактор регуляризации.

L1Factor = 0.001;

Создайте анонимную функцию, которая упорядочивает градиенты. При помощи анонимной функции, чтобы передать скалярную константу функции, можно избежать необходимости расширять постоянное значение до того же размера и структуры как переменная параметра.

L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

Используйте dlupdate чтобы применить регуляризацию функционируют к каждому из градиентов.

gradients = dlupdate(L1Regularizer,gradients,params);

Градиенты в grads теперь упорядочены согласно функциональному L1Regularizer.

convGradients Функция

convGradients функция помощника берет настраиваемые параметры операции свертки и мини-пакет входных данных dlX, и возвращает градиенты относительно настраиваемых параметров.

function gradients = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
gradients = dlgradient(dlY,params);
end

Используйте dlupdate обучать сеть с помощью пользовательской функции обновления, которая реализует стохастический алгоритм градиентного спуска (без импульса).

Загрузите обучающие данные

Загрузите обучающие данные цифр.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Задайте сеть

Задайте сетевую архитектуру и задайте среднее значение изображений с помощью 'Mean' опция в изображении ввела слой.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект от графика слоев.

dlnet = dlnetwork(lgraph);

Функция градиентов модели Define

Создайте функцию помощника modelGradients, перечисленный в конце этого примера. Функция берет dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet.

Задайте стохастическую функцию градиентного спуска

Создайте функцию помощника sgdFunction, перечисленный в конце этого примера. Функция берет param и paramGradient, настраиваемый параметр и градиент потери относительно того параметра, соответственно, и возвращают обновленный параметр с помощью стохастического алгоритма градиентного спуска, описанного как

θl+1=θ-αE(θl)

где l номер итерации, α>0 скорость обучения, θ вектор параметра, и E(θ) функция потерь.

Задайте опции обучения

Задайте опции, чтобы использовать во время обучения.

miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Задайте скорость обучения.

learnRate = 0.01;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализируйте процесс обучения в графике.

plots = "training-progress";

Обучение сети

Обучите модель с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Обновите сетевые параметры путем вызова dlupdate с функциональным sgdFunction заданный в конце этого примера. В конце каждой эпохи отобразите прогресс обучения.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучите сеть.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function.
        updateFcn = @(dlnet,gradients) sgdFunction(dlnet,gradients,learnRate);
        dlnet = dlupdate(updateFcn,dlnet,gradients);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестирование сети

Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray с форматом размерности 'SSCB'. Для предсказания графического процессора также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классифицировать изображения с помощью dlnetwork объект, используйте predict функционируйте и найдите классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9386

Функция градиентов модели

Функция помощника modelGradients берет dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

dlYPred = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);

gradients = dlgradient(loss,dlnet.Learnables);

end

Стохастическая функция градиентного спуска

Функция помощника sgdFunction берет настраиваемый параметр parameter, градиенты того параметра относительно потери gradient, и скорость обучения learnRate, и возвращает обновленный параметр с помощью стохастического алгоритма градиентного спуска, описанного как

θl+1=θ-αE(θl)

где l номер итерации, α>0 скорость обучения, θ вектор параметра, и E(θ) функция потерь.

function parameter = sgdFunction(parameter,gradient,learnRate)

parameter = parameter - learnRate .* gradient;

end

Входные параметры

свернуть все

Функция, чтобы примениться к настраиваемым параметрам в виде указателя на функцию.

dlupdate оценивает fun с каждым сетевым настраиваемым параметром как вход. fun оценивается так же много раз, как существуют массивы настраиваемых параметров в dlnet или params.

Сеть в виде dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект. dlnet.Learnables таблица с тремя переменными:

  • Layer — Имя слоя в виде строкового скаляра.

  • Parameter — Название параметра в виде строкового скаляра.

  • Value — Значение параметра в виде массива ячеек, содержащего dlarray.

Сетевые настраиваемые параметры в виде dlarray, числовой массив, массив ячеек, структура или таблица.

Если вы задаете params как таблица, это должно содержать следующие три переменные.

  • Layer — Имя слоя в виде строкового скаляра.

  • Parameter — Название параметра в виде строкового скаляра.

  • Value — Значение параметра в виде массива ячеек, содержащего dlarray.

Можно задать params как контейнер настраиваемых параметров для вашей сети с помощью массива ячеек, структуры, или таблицы, или вложенных массивов ячеек или структур. Настраиваемыми параметрами в массиве ячеек, структуре или таблице должен быть dlarray или числовые значения типа данных double или single.

Входной параметр A1,...,An должен быть обеспечен точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменные (для таблиц) как params.

Типы данных: single | double | struct | table | cell

Дополнительные входные параметры к funВ виде dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма A1,...,An зависит от входной сети или настраиваемых параметров. Следующая таблица показывает требуемый формат для A1,...,An для возможных входных параметров к dlupdate.

Входной параметрНастраиваемые параметрыA1,...,An
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительные входные параметры для функционального fun применяться к каждому настраиваемому параметру.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params.
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params.
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params.
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params.
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray.Таблица с совпадающими типами данных, переменными и заказывающий как params. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительный входной параметр для функционального fun применяться к каждому настраиваемому параметру.

Выходные аргументы

свернуть все

Сеть, возвращенная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект.

Обновленные сетевые настраиваемые параметры, возвращенные как dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные настраиваемые параметры сети.

Дополнительные выходные аргументы от функционального fun, где fun указатель на функцию к функции, которая возвращает несколько выходных параметров, возвращенных как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма X1,...,Xm зависит от входной сети или настраиваемых параметров. Следующая таблица показывает возвращенный формат X1,...,Xm для возможных входных параметров к dlupdate.

Входной параметрНастраиваемые параметрыX1,...,Xm
dlnetТаблица dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray. Таблица с совпадающим типом данных, переменными, и заказывающий как dlnet.Learnables. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительные выходные аргументы функционального fun примененный каждый настраиваемый параметр.
paramsdlarraydlarray с совпадающим типом данных и заказывающий как params.
Числовой массивЧисловой массив с совпадающим типом данных и заказывающий как params.
CellArrayМассив ячеек с совпадающими типами данных, структурой, и заказывающий как params.
СтруктураСтруктура с совпадающими типами данных, полями, и заказывающий как params.
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр как dlarray.Таблица с совпадающими типами данных, переменными. и упорядоченное расположение как params. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительный выходной аргумент функционального fun примененный каждый настраиваемый параметр.

Расширенные возможности

Введенный в R2019b