dlgradient

Вычислите градиенты для пользовательских учебных циклов с помощью автоматического дифференцирования

Описание

Используйте dlgradient вычислить производные с помощью автоматического дифференцирования в пользовательских учебных циклах.

Совет

Для большинства задач глубокого обучения можно использовать предварительно обученную сеть и адаптировать ее к собственным данным. Для примера, показывающего, как использовать передачу, учащуюся переобучать сверточную нейронную сеть, чтобы классифицировать новый набор изображений, смотрите, Обучают Нейронную сеть для глубокого обучения Классифицировать Новые Изображения. В качестве альтернативы можно создать и обучить нейронные сети с нуля с помощью layerGraph объекты с trainNetwork и trainingOptions функции.

Если trainingOptions функция не обеспечивает опции обучения, в которых вы нуждаетесь для своей задачи, затем можно создать пользовательский учебный цикл с помощью автоматического дифференцирования. Чтобы узнать больше, смотрите, Задают Пользовательские Учебные Циклы.

пример

[dydx1,...,dydxk] = dlgradient(y,x1,...,xk) возвращает градиенты y относительно переменных x1 через xk.

Вызовите dlgradient из функции, переданной dlfeval. Смотрите вычисляют градиент Используя автоматическое дифференцирование и используют автоматическое дифференцирование в Deep Learning Toolbox.

[dydx1,...,dydxk] = dlgradient(y,x1,...,xk,'RetainData',true) заставляет градиент сохранять промежуточные значения для повторного использования в последующем dlgradient вызовы. Этот синтаксис может сэкономить время, но использует больше памяти. Смотрите Советы.

Примеры

свернуть все

Функция Розенброка является стандартной тестовой функцией для оптимизации. rosenbrock.m функция помощника вычисляет значение функции и использует автоматическое дифференцирование, чтобы вычислить его градиент.

type rosenbrock.m
function [y,dydx] = rosenbrock(x)

y = 100*(x(2) - x(1).^2).^2 + (1 - x(1)).^2;
dydx = dlgradient(y,x);

end

Выполнять функцию Розенброка и ее градиент в точке [–1,2], создайте dlarray из точки и затем вызывают dlfeval на указателе на функцию @rosenbrock.

x0 = dlarray([-1,2]);
[fval,gradval] = dlfeval(@rosenbrock,x0)
fval = 
  1×1 dlarray

   104

gradval = 
  1×2 dlarray

   396   200

В качестве альтернативы задайте функцию Розенброка как функцию двух входных параметров, x1 и x2.

type rosenbrock2.m
function [y,dydx1,dydx2] = rosenbrock2(x1,x2)

y = 100*(x2 - x1.^2).^2 + (1 - x1).^2;
[dydx1,dydx2] = dlgradient(y,x1,x2);

end

Вызовите dlfeval оценивать rosenbrock2 на двух dlarray аргументы, представляющие входные параметры –1 и 2.

x1 = dlarray(-1);
x2 = dlarray(2);
[fval,dydx1,dydx2] = dlfeval(@rosenbrock2,x1,x2)
fval = 
  1×1 dlarray

   104

dydx1 = 
  1×1 dlarray

   396

dydx2 = 
  1×1 dlarray

   200

Постройте градиент функции Розенброка для нескольких точек в модульном квадрате. Во-первых, инициализируйте массивы, представляющие точки оценки и выход функции.

[X1 X2] = meshgrid(linspace(0,1,10));
X1 = dlarray(X1(:));
X2 = dlarray(X2(:));
Y = dlarray(zeros(size(X1)));
DYDX1 = Y;
DYDX2 = Y;

Выполните функцию в цикле. Постройте результат с помощью quiver.

for i = 1:length(X1)
    [Y(i),DYDX1(i),DYDX2(i)] = dlfeval(@rosenbrock2,X1(i),X2(i));
end
quiver(extractdata(X1),extractdata(X2),extractdata(DYDX1),extractdata(DYDX2))
xlabel('x1')
ylabel('x2')

Входные параметры

свернуть все

Переменная, чтобы дифференцироваться, заданный как скалярный dlarray объект. Для дифференцирования, y должна быть прослеженная функция dlarray входные параметры (см. Прослеженный dlarray) и должны состоять из поддерживаемых функций для dlarray (ee Список Функций с Поддержкой dlarray).

Пример: 100*(x(2) - x(1).^2).^2 + (1 - x(1)).^2

Пример: relu(X)

Переменная в функции, заданной как dlarray объект, массив ячеек, структура или таблица, содержащая dlarray объекты или любая комбинация таких аргументов рекурсивно. Например, аргумент может быть массивом ячеек, содержащим массив ячеек, который содержит структуру, содержащую dlarray объекты.

Пример: dlarray([1 2;3 4])

Типы данных: single | double | logical | struct | cell

Индикатор для сохранения данных о трассировке во время вызова функции, заданного как false или true. Когда этим аргументом является false, dlarray сразу отбрасывает производную трассировку после вычисления производной. Когда этим аргументом является true, dlarray сохраняет производную трассировку до конца dlfeval вызов функции, который оценивает dlgradient. true установка полезна только когда dlfeval вызов содержит больше чем один dlgradient вызвать. true установка заставляет программное обеспечение использовать больше памяти, но может сэкономить время когда несколько dlgradient вызовы используют, по крайней мере, часть той же трассировки.

Пример: dydx = dlgradient(y,x,'RetainData',true)

Типы данных: логический

Выходные аргументы

свернуть все

Градиент, возвращенный как dlarray объект, или массив ячеек, структура или таблица, содержащая dlarray объекты или любая комбинация таких аргументов рекурсивно. Размер и тип данных dydx совпадают с теми из связанной входной переменной x.

Больше о

свернуть все

Прослеженный dlarray

Во время расчета функции, dlarray внутренне записывает шаги, сделанные в trace, включая реверсному режиму автоматическое дифференцирование. Трассировка происходит в dlfeval вызвать. Смотрите Автоматический Фон Дифференцирования.

Советы

  • dlgradient не поддерживает производные высшего порядка. Другими словами, вы не можете передать выход dlgradient вызовите в другой dlgradient вызвать.

  • dlgradient вызов должен быть в функции. Чтобы получить числовое значение градиента, необходимо выполнить функцию с помощью dlfeval, и аргументом к функции должен быть dlarray. Смотрите использование автоматическое дифференцирование в Deep Learning Toolbox.

  • Включить правильную оценку градиентов, y аргумент должен использовать только поддерживаемые функции в dlarray. См. Список Функций с Поддержкой dlarray.

  • Если вы устанавливаете 'RetainData' аргумент пары "имя-значение" true, трассировка консервов программного обеспечения на время dlfeval вызов функции вместо того, чтобы сразу стереть трассировку после производного расчета. Это сохранение может вызвать последующий dlgradient вызовите в том же dlfeval вызовите, чтобы быть выполненными быстрее, но использует больше памяти. Например, в обучении соперничающей сети, 'RetainData' установка полезна, потому что эти две сети осуществляют обмен данными и функции во время обучения. Смотрите Обучают Порождающую соперничающую сеть (GAN).

Введенный в R2019b