dlupdate

Обновляйте параметры с помощью пользовательской функции

Описание

пример

dlnet = dlupdate(fun,dlnet) обновляет настраиваемые параметры dlnetwork dlnet объекта путем оценки функции fun с каждым настраиваемым параметром в качестве входов. fun - указатель на функцию функции, который принимает один массив параметров в качестве входного параметра и возвращает обновленный массив параметров.

params = dlupdate(fun,params) обновляет настраиваемые параметры в params путем оценки функции fun с каждым настраиваемым параметром в качестве входов.

[___] = dlupdate(fun,___A1,...,An) также задает дополнительные входные параметры, в дополнение к входным параметрам в предыдущих синтаксисах, когда fun является указателем на функцию, которая требует n+1 входные значения.

[___,X1,...,Xm] = dlupdate(fun,___) возвращает несколько выходов X1,...,Xm когда fun - указатель на функцию, который возвращает m+1 выходные значения.

Примеры

свернуть все

Выполните L1 регуляризацию структуры градиентов параметров.

Создайте образец входных данных.

dlX = dlarray(rand(100,100,3),'SSC');

Инициализируйте настраиваемые параметры для операции свертки.

params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

Вычислите градиенты для операции свертки с помощью функции helper convGradients, заданный в конце этого примера.

gradients = dlfeval(@convGradients,dlX,params);

Определите коэффициент регуляризации.

L1Factor = 0.001;

Создайте анонимную функцию, которая регулирует градиенты. При помощи анонимной функции, чтобы передать скаляру константу в функцию, можно избежать необходимости расширения постоянного значения до тех же размера и структуры, что и переменная параметра.

L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

Использование dlupdate применить функцию регуляризации к каждому из градиентов.

gradients = dlupdate(L1Regularizer,gradients,params);

Градиенты в grads теперь регулируются в соответствии с функцией L1Regularizer.

convGradients Функция

The convGradients Функция helper принимает настраиваемые параметры операции свертки и мини-пакет входных данных dlX, и возвращает градиенты относительно настраиваемых параметров.

function gradients = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
gradients = dlgradient(dlY,params);
end

Использование dlupdate обучить сеть с помощью пользовательской функции обновления, которая реализует алгоритм стохастического градиентного спуска (без импульса).

Загрузка обучающих данных

Загрузите обучающие данные цифр.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Определите сеть

Определите сетевую архитектуру и задайте среднее значение изображения с помощью 'Mean' опция на входном слое изображения.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоев.

dlnet = dlnetwork(lgraph);

Задайте функцию градиентов модели

Создайте вспомогательную функцию modelGradients, перечисленный в конце этого примера. Функция принимает dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet.

Задайте функцию стохастического градиентного спуска

Создайте вспомогательную функцию sgdFunction, перечисленный в конце этого примера. Функция принимает param и paramGradient, настраиваемый параметр и градиент потерь относительно этого параметра, соответственно, и возвращает обновленный параметр, используя алгоритм стохастического градиентного спуска, выраженный как

θl+1=θ-αE(θl)

где l - число итерации, α>0 является скорость обучения, θ является вектором параметра, и E(θ) - функция потерь.

Настройка опций обучения

Задайте опции, которые будут использоваться во время обучения.

miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Задайте скорость обучения.

learnRate = 0.01;

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализируйте процесс обучения на графике.

plots = "training-progress";

Обучите сеть

Обучите модель с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Обновляйте параметры сети путем вызова dlupdate с функцией sgdFunction заданное в конце этого примера. В конце каждой эпохи отобразите процесс обучения.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучите сеть.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function.
        updateFcn = @(dlnet,gradients) sgdFunction(dlnet,gradients,learnRate);
        dlnet = dlupdate(updateFcn,dlnet,gradients);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестирование сети

Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразуйте данные в dlarray с форматом размерности 'SSCB'. Для предсказания графический процессор также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Чтобы классифицировать изображения с помощью dlnetwork объект, используйте predict функция и найти классы с самыми высокими счетами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9386

Функция градиентов модели

Функция помощника modelGradients принимает dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками Y, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

dlYPred = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);

gradients = dlgradient(loss,dlnet.Learnables);

end

Функция стохастического градиентного спуска

Функция помощника sgdFunction принимает настраиваемый параметр parameter, градиенты этого параметра относительно потерь gradient, и скорость обучения learnRate, и возвращает обновленный параметр, используя алгоритм стохастического градиентного спуска, выраженный как

θl+1=θ-αE(θl)

где l - число итерации, α>0 является скорость обучения, θ является вектором параметра, и E(θ) - функция потерь.

function parameter = sgdFunction(parameter,gradient,learnRate)

parameter = parameter - learnRate .* gradient;

end

Входные параметры

свернуть все

Функция для применения к настраиваемым параметрам, заданная как указатель на функцию.

dlupate оценивает fun с каждым сетевым настраиваемым параметром в качестве входа. fun оценивается столько раз, сколько существует массивов настраиваемых параметров в dlnet или params.

Сеть, заданная как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:

  • Layer - Имя слоя, заданное как строковый скаляр.

  • Parameter - Имя параметра, заданное как строковый скаляр.

  • Value - Значение параметра, заданное как массив ячеек, содержащий dlarray.

Сетевые настраиваемые параметры, заданная как dlarray, числовой массив, массив ячеек, структура или таблица.

Если вы задаете params как таблица, она должна содержать следующие три переменные.

  • Layer - Имя слоя, заданное как строковый скаляр.

  • Parameter - Имя параметра, заданное как строковый скаляр.

  • Value - Значение параметра, заданное как массив ячеек, содержащий dlarray.

Можно задать params как контейнер настраиваемых параметров для сети с помощью массива ячеек, структуры или таблицы или вложенных массивов ячеек или структур. Настраиваемые параметры в массиве ячеек, структуре или таблице должны быть dlarray или числовые значения типа данных double или single.

Входной параметр grad должны быть снабжены точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменными (для таблиц), как params.

Типы данных: single | double | struct | table | cell

Дополнительные входные параметры для fun, заданный как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма A1,...,An зависит от входа сети или настраиваемых параметров. В следующей таблице показан необходимый формат для A1,...,An для возможных входов в dlupdate.

ВходНастраиваемые параметрыA1,...,An
dlnetТабличные dlnet.Learnables содержащие Layer, Parameter, и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray. Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительные входные параметры для функции fun применить к каждому настраиваемому параметру.
paramsdlarraydlarray с совпадающим типом данных и порядком, что и params
Числовой массивЧисловой массив с совпадающим типом данных и порядком, что и params
Массив ячеекМассив ячеек с совпадающими типами данных, структурой и порядком, как params
СтруктураСтруктура с совпадающими типами данных, полями и порядками, что и params
Таблица с Layer, Parameter, и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray.Таблица с совпадающими типами данных, переменными и порядком, что и params. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительный входной параметр для функции fun применить к каждому настраиваемому параметру.

Выходные аргументы

свернуть все

Сеть, возвращается как dlnetwork объект.

Функция обновляет dlnet.Learnables свойство dlnetwork объект.

Обновлённые сетевые настраиваемые параметры, возвращенный как dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные настраиваемые параметры сети.

Дополнительные выходные аргументы от функции fun, где fun - указатель на функцию, который возвращает несколько выходы, возвращаемых как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма X1,...,Xm зависит от входа сети или настраиваемых параметров. В следующей таблице показан возвращенный формат X1,...,Xm для возможных входов в dlupdate.

ВходНастраиваемые параметрыX1,...,Xm
dlnetТабличные dlnet.Learnables содержащие Layer, Parameter, и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray. Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительные выходные аргументы функции fun применяется к каждому настраиваемому параметру.
paramsdlarraydlarray с совпадающим типом данных и порядком, что и params
Числовой массивЧисловой массив с совпадающим типом данных и порядком, что и params
Массив ячеекМассив ячеек с совпадающими типами данных, структурой и порядком, как params
СтруктураСтруктура с совпадающими типами данных, полями и порядками, что и params
Таблица с Layer, Parameter, и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray.Таблица с совпадающими типами данных, переменными. и упорядоченное расположение как params. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительный выходной аргумент функции fun применяется к каждому настраиваемому параметру.

Расширенные возможности

Введенный в R2019b
Для просмотра документации необходимо авторизоваться на сайте