Функция градиентов модели Define для пользовательского учебного цикла

Когда вы обучаете модель глубокого обучения с пользовательским учебным циклом, программное обеспечение минимизирует потерю относительно настраиваемых параметров. Чтобы минимизировать потерю, программное обеспечение использует градиенты потери относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматического дифференцирования, необходимо задать функцию градиентов модели.

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork возразите, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл. Для примера, показывающего, как к обучению модель глубокого обучения, заданная как функция, смотрите, Обучат сеть Используя Функцию Модели.

Создайте функцию градиентов модели для модели, заданной как dlnetwork Объект

Если вам задали модель глубокого обучения как dlnetwork объект, затем создайте функцию градиентов модели, которая берет dlnetwork возразите, как введено.

Для модели, заданной как dlnetwork возразите, создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet сеть, dlX сетевой вход, T содержит цели и gradients содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).

Например, эта функция возвращает градиенты и потерю перекрестной энтропии для заданного dlnetwork объект dlnet, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Создайте функцию градиентов модели для модели, заданной как функция

Если вам задали модель глубокого обучения как функцию, то создаете функцию градиентов модели, которая берет настраиваемые параметры модели, как введено.

Для определенной функцией модели создайте функцию формы gradients = modelGradients(parameters,dlX,T), где parameters содержит настраиваемые параметры, dlX вход модели, T содержит цели и gradients содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).

Например, эта функция возвращает градиенты и потерю перекрестной энтропии для функции модели глубокого обучения model с заданными настраиваемыми параметрами parameters, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

Выполните функцию градиентов модели

Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval функция, которая выполняет функцию с автоматическим включенным дифференцированием. Для первого входа dlfeval, передайте определенный функцией указатель функции градиентов модели. Для следующих входных параметров передайте необходимые переменные для функции градиентов модели. Для выходных параметров dlfeval функционируйте, задайте те же выходные параметры как функция градиентов модели.

Например, выполните функцию градиентов модели modelGradients с dlnetwork объект dlnet, входные данные dlX, и цели T, и возвратите градиенты модели и потерю.

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

Точно так же выполните функцию градиентов модели modelGradients использование функции модели с настраиваемыми параметрами, заданными структурой parameters, входные данные dlX, и цели T, и возвратите градиенты модели и потерю.

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);

Обновите настраиваемые параметры Используя градиенты

Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции.

ФункцияОписание
adamupdateОбновите параметры с помощью адаптивной оценки момента (Адам)
rmspropupdateОбновите параметры с помощью корневого среднеквадратического распространения (RMSProp)
sgdmupdateОбновите параметры с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновите параметры с помощью пользовательской функции

Например, обновите настраиваемые параметры dlnetwork объект dlnet использование adamupdate функция.

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients выход функции градиентов модели и trailingAvg, trailingAvgSq, и iteration гиперпараметры, требуемые adamupdate функция.

Точно так же обновите настраиваемые параметры для функционального parameters модели использование adamupdate функция.

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients выход функции градиентов модели и trailingAvg, trailingAvgSq, и iteration гиперпараметры, требуемые adamupdate функция.

Используйте функцию градиентов модели в пользовательском учебном цикле

Когда обучение модель глубокого обучения использование пользовательского учебного цикла, оцените градиенты модели и обновите настраиваемые параметры для каждого мини-пакета.

Этот фрагмент кода показывает пример использования dlfeval и adamupdate функции в пользовательском учебном цикле.

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork возразите, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл. Для примера, показывающего, как к обучению модель глубокого обучения, заданная как функция, смотрите, Обучат сеть Используя Функцию Модели.

Отладьте функции градиентов модели

Если реализация функции градиентов модели имеет проблему, то вызов dlfeval может выдать ошибку. Иногда, когда вы используете dlfeval функция, это не ясно, какая строка кода выдает ошибку. Чтобы помочь определить местоположение ошибки, можно попробовать следующее.

Вызовите функцию градиентов модели непосредственно

Попытайтесь вызвать функцию градиентов модели непосредственно (то есть, не используя dlfeval функция) со сгенерированными входными параметрами ожидаемых размеров. Если какая-либо из линий кода выдает ошибку, то сообщение об ошибке обеспечивает дополнительную деталь. Обратите внимание на то, что, когда вы не используете dlfeval функция, любые вызовы dlgradient функционируйте выдают ошибку.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);

Запустите код градиентов модели вручную

Запуститесь код в градиентах модели функционируют вручную со сгенерированными входными параметрами ожидаемых размеров и смотрят выход и любые выданные сообщения об ошибке.

Например, рассмотрите следующую функцию градиентов модели.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Проверяйте функцию градиентов модели путем выполнения следующего кода.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlX,T)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте