exponenta event banner

dlfeval

Оценка модели глубокого обучения для пользовательских циклов обучения

Описание

Использовать dlfeval оценка пользовательских моделей глубокого обучения для пользовательских циклов обучения.

Совет

Для выполнения большинства задач глубокого обучения можно использовать предварительно подготовленную сеть и адаптировать ее к собственным данным. Пример, показывающий, как использовать transfer learning для переподготовки сверточной нейронной сети для классификации нового набора изображений, см. в разделе Train Deep Learning Network to Classify New Images. Кроме того, можно создавать и обучать сети с нуля с помощью layerGraph объекты с trainNetwork и trainingOptions функции.

Если trainingOptions функция не предоставляет возможности обучения, необходимые для выполнения задачи, после чего можно создать индивидуальный цикл обучения с помощью автоматического дифференцирования. Дополнительные сведения см. в разделе Определение сети глубокого обучения для пользовательских циклов обучения.

пример

[y1,...,yk] = dlfeval(fun,x1,...,xn) оценивает функцию массива глубокого обучения fun на входных аргументах x1,...,xn. Функции, переданные dlfeval может содержать вызовы dlgradient, которые вычисляют градиенты из входных данных x1,...,xn с помощью автоматического дифференцирования.

Примеры

свернуть все

Функция Розенброка является стандартной тестовой функцией для оптимизации. rosenbrock.m вспомогательная функция вычисляет значение функции и использует автоматическое дифференцирование для вычисления ее градиента.

type rosenbrock.m
function [y,dydx] = rosenbrock(x)

y = 100*(x(2) - x(1).^2).^2 + (1 - x(1)).^2;
dydx = dlgradient(y,x);

end

Чтобы оценить функцию Розенброка и его градиент в точке [–1,2], создайте dlarray точки, а затем вызова dlfeval на дескрипторе функции @rosenbrock.

x0 = dlarray([-1,2]);
[fval,gradval] = dlfeval(@rosenbrock,x0)
fval = 
  1x1 dlarray

   104

gradval = 
  1x2 dlarray

   396   200

Либо определите функцию Розенброка как функцию двух входов, x1 и x2.

type rosenbrock2.m
function [y,dydx1,dydx2] = rosenbrock2(x1,x2)

y = 100*(x2 - x1.^2).^2 + (1 - x1).^2;
[dydx1,dydx2] = dlgradient(y,x1,x2);

end

Звонить dlfeval оценить rosenbrock2 на двух dlarray аргументы, представляющие входные данные –1 и 2.

x1 = dlarray(-1);
x2 = dlarray(2);
[fval,dydx1,dydx2] = dlfeval(@rosenbrock2,x1,x2)
fval = 
  1x1 dlarray

   104

dydx1 = 
  1x1 dlarray

   396

dydx2 = 
  1x1 dlarray

   200

Постройте график градиента функции Розенброка для нескольких точек единичного квадрата. Сначала инициализируйте массивы, представляющие точки оценки и выходные данные функции.

[X1 X2] = meshgrid(linspace(0,1,10));
X1 = dlarray(X1(:));
X2 = dlarray(X2(:));
Y = dlarray(zeros(size(X1)));
DYDX1 = Y;
DYDX2 = Y;

Вычислите функцию в цикле. Постройте график результата с помощью quiver.

for i = 1:length(X1)
    [Y(i),DYDX1(i),DYDX2(i)] = dlfeval(@rosenbrock2,X1(i),X2(i));
end
quiver(extractdata(X1),extractdata(X2),extractdata(DYDX1),extractdata(DYDX2))
xlabel('x1')
ylabel('x2')

Figure contains an axes. The axes contains an object of type quiver.

Входные аргументы

свернуть все

Вычисляемая функция, заданная как дескриптор функции. Если fun включает dlgradient позвоните, затем dlfeval оценивает градиент с помощью автоматического дифференцирования. В этой оценке градиента каждый аргумент dlgradient вызов должен быть dlarray или массив ячеек, структуру или таблицу, содержащую dlarray. Число входных аргументов для dlfeval должно совпадать с количеством входных аргументов для fun.

Пример: @rosenbrock

Типы данных: function_handle

Аргументы функции, указанные как любой тип данных MATLAB или dlnetwork объект.

Входной аргумент xj это переменная дифференциации в dlgradient вызов должен быть отслеживаемым dlarray или массив ячеек, структуру или таблицу, содержащую трассированный dlarray. Дополнительная переменная, такая как гиперпараметр или массив данных констант, не обязательно должна быть dlarray.

Чтобы оценить градиенты для глубокого обучения, вы можете предоставить dlnetwork объект в качестве аргумента функции и оценка прямого прохода сети внутри fun.

Пример: dlarray([1 2;3 4])

Типы данных: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string | struct | table | cell | function_handle | categorical | datetime | duration | calendarDuration | fi

Выходные аргументы

свернуть все

Вывод функции, возвращаемый как любой тип данных. Если выходные данные являются результатом dlgradient вызов, выходные данные являются dlarray.

Совет

Представлен в R2019b