exponenta event banner

dlupdate

Обновление параметров с помощью пользовательской функции

Описание

пример

dlnet = dlupdate(fun,dlnet) обновляет обучаемые параметры dlnetwork объект dlnet путем оценки функции fun с каждым обучаемым параметром в качестве входных данных. fun является дескриптором функции, который принимает один массив параметров в качестве входного аргумента и возвращает обновленный массив параметров.

params = dlupdate(fun,params) обновляет обучаемые параметры в params путем оценки функции fun с каждым обучаемым параметром в качестве входных данных.

[___] = dlupdate(fun,___A1,...,An) также указывает дополнительные входные аргументы, в дополнение к входным аргументам в предыдущих синтаксисах, когда fun является дескриптором функции для функции, которая требует n+1 входные значения.

[___,X1,...,Xm] = dlupdate(fun,___) возвращает несколько выходов X1,...,Xm когда fun - дескриптор функции для функции, возвращающей m+1 выходные значения.

Примеры

свернуть все

Выполните регуляризацию L1 для структуры градиентов параметров.

Создайте образец входных данных.

dlX = dlarray(rand(100,100,3),'SSC');

Инициализируйте обучаемые параметры для операции свертки.

params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

Расчет градиентов для операции свертки с помощью вспомогательной функции convGradients, определяется в конце этого примера.

gradients = dlfeval(@convGradients,dlX,params);

Определите коэффициент регуляризации.

L1Factor = 0.001;

Создайте анонимную функцию, регулирующую градиенты. Используя анонимную функцию для передачи скалярной константы функции, можно избежать необходимости расширения значения константы до того же размера и структуры, что и переменная параметра.

L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

Использовать dlupdate для применения функции регуляризации к каждому из градиентов.

gradients = dlupdate(L1Regularizer,gradients,params);

Градиенты в grads теперь регуляризованы в соответствии с функцией L1Regularizer.

convGradients Функция

convGradients вспомогательная функция принимает обучаемые параметры операции свертки и мини-пакет входных данных dlXи возвращает градиенты относительно обучаемых параметров.

function gradients = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
gradients = dlgradient(dlY,params);
end

Использовать dlupdate для обучения сети с помощью пользовательской функции обновления, реализующей алгоритм стохастического градиентного спуска (без импульса).

Загрузка данных обучения

Загрузите данные обучения по цифрам.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Определение сети

Определите сетевую архитектуру и укажите среднее значение изображения с помощью 'Mean' в слое ввода изображения.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создать dlnetwork объект из графа слоев.

dlnet = dlnetwork(lgraph);

Определение функции градиентов модели

Создание вспомогательной функции modelGradients, перечисленных в конце этого примера. Функция принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet.

Определение функции стохастического градиентного спуска

Создание вспомогательной функции sgdFunction, перечисленных в конце этого примера. Функция принимает param и paramGradient, обучаемый параметр и градиент потерь по отношению к этому параметру, соответственно, и возвращает обновленный параметр с использованием алгоритма стохастического градиентного спуска, выраженного как

θl+1=θ-α∇E (startl)

где l - число итераций, α > 0 - скорость обучения ,

Укажите параметры обучения

Укажите параметры для использования во время обучения.

miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Укажите скорость обучения.

learnRate = 0.01;

Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

executionEnvironment = "auto";

Визуализация хода обучения на графике.

plots = "training-progress";

Железнодорожная сеть

Обучение модели с помощью пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Обновление параметров сети путем вызова dlupdate с функцией sgdFunction определено в конце этого примера. В конце каждой эпохи отобразите ход обучения.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучение сети.

iteration = 0;
start = tic;

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to a gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients helper function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function.
        updateFcn = @(dlnet,gradients) sgdFunction(dlnet,gradients,learnRate);
        dlnet = dlupdate(updateFcn,dlnet,gradients);
        
        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end

Тестовая сеть

Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками.

[XTest, YTest] = digitTest4DArrayData;

Преобразование данных в dlarray с форматом размера 'SSCB'. Для прогнозирования GPU также преобразуйте данные в gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

Классификация изображений с помощью dlnetwork объект, используйте predict и найти классы с самыми высокими баллами.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Оцените точность классификации.

accuracy = mean(YPred==YTest)
accuracy = 0.9386

Функция градиентов модели

Вспомогательная функция modelGradients принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet. Для автоматического вычисления градиентов используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)

dlYPred = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);

gradients = dlgradient(loss,dlnet.Learnables);

end

Функция стохастического градиентного спуска

Вспомогательная функция sgdFunction принимает обучаемый параметр parameter, градиенты этого параметра относительно потерь gradientи скорость обучения learnRateи возвращает обновленный параметр с использованием алгоритма стохастического градиентного спуска, выраженного как

θl+1=θ-α∇E (startl)

где l - число итераций, α > 0 - скорость обучения ,

function parameter = sgdFunction(parameter,gradient,learnRate)

parameter = parameter - learnRate .* gradient;

end

Входные аргументы

свернуть все

Функция для применения к обучаемым параметрам, заданным как дескриптор функции.

dlupate оценивает fun с каждым сетевым обучаемым параметром в качестве входных данных. fun оценивается столько раз, сколько существует массивов обучаемых параметров в dlnet или params.

Сеть, указанная как dlnetwork объект.

Функция обновляет dlnet.Learnables имущества dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:

  • Layer - имя слоя, указанное как строковый скаляр.

  • Parameter - имя параметра, указанное как строковый скаляр.

  • Value - значение параметра, указанное как массив ячеек, содержащий dlarray.

Сетевые обучаемые параметры, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.

При указании params в качестве таблицы она должна содержать следующие три переменные.

  • Layer - имя слоя, указанное как строковый скаляр.

  • Parameter - имя параметра, указанное как строковый скаляр.

  • Value - значение параметра, указанное как массив ячеек, содержащий dlarray.

Можно указать params в качестве контейнера обучаемых параметров для сети с использованием массива ячеек, структуры или таблицы или вложенных массивов или структур ячеек. Доступные для изучения параметры в массиве, структуре или таблице ячейки должны быть dlarray или числовые значения типа данных double или single.

Входной аргумент grad должны иметь точно такой же тип данных, порядок и поля (для структур) или переменные (для таблиц), как params.

Типы данных: single | double | struct | table | cell

Дополнительные входные аргументы для fun, указано как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма A1,...,An зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для A1,...,An для возможных входных данных dlupdate.

ВходОбучаемые параметрыA1,...,An
dlnetСтол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительные входные аргументы для функции fun применяется к каждому обучаемому параметру.
paramsdlarraydlarray с тем же типом данных и порядком, что и params
Числовой массивЧисловой массив с тем же типом данных и порядком, что и params
Массив ячеекМассив ячеек с теми же типами данных, структурой и порядком, что и params
СтруктураСтруктура с теми же типами данных, полями и порядком, что и params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray.Таблица с теми же типами данных, переменными и порядком, что и params. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительный входной аргумент для функции fun применяется к каждому обучаемому параметру.

Выходные аргументы

свернуть все

Сеть, возвращенная как dlnetwork объект.

Функция обновляет dlnet.Learnables имущества dlnetwork объект.

Обновленные сетевые обучаемые параметры, возвращенные в виде dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные обучаемые параметры сети.

Дополнительные выходные аргументы функции fun, где fun - дескриптор функции для функции, которая возвращает несколько выходов, возвращаемых как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.

Точная форма X1,...,Xm зависит от входной сети или обучаемых параметров. В следующей таблице показан возвращенный формат X1,...,Xm для возможных входных данных dlupdate.

ВходОбучаемые параметрыX1,...,Xm
dlnetСтол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительные выходные аргументы функции fun применяется к каждому обучаемому параметру.
paramsdlarraydlarray с тем же типом данных и порядком, что и params
Числовой массивЧисловой массив с тем же типом данных и порядком, что и params
Массив ячеекМассив ячеек с теми же типами данных, структурой и порядком, что и params
СтруктураСтруктура с теми же типами данных, полями и порядком, что и params
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray.Таблица с одинаковыми типами данных, переменные. и заказ как params. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительный выходной аргумент функции fun применяется к каждому обучаемому параметру.

Расширенные возможности

Представлен в R2019b