Функция градиентов модели Define для пользовательского учебного цикла

Когда обучение модель глубокого обучения с пользовательским учебным циклом, программное обеспечение минимизирует потерю относительно настраиваемых параметров. Чтобы минимизировать потерю, программное обеспечение использует градиенты потери относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматического дифференцирования, необходимо задать функцию градиентов модели.

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork возразите, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл. Для примера, показывающего, как к обучению модель глубокого обучения, заданная как функция, смотрите, Обучат сеть Используя Функцию Модели.

Создайте функцию градиентов модели для моделей, заданных как dlnetwork Объект

Если вам задали модель глубокого обучения как dlnetwork объект, затем создайте функцию градиентов модели, которая берет dlnetwork возразите, как введено.

Для моделей, заданных как dlnetwork возразите, создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet сеть, dlX содержит входные предикторы, T содержит цели и gradients содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).

Например, эта функция возвращает градиенты и потерю перекрестной энтропии для заданного dlnetwork объект dlnet, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Создайте функцию градиентов модели для моделей, заданных как функция

Если вам задали модель глубокого обучения как функцию формы dlY = model(parameters,dlX), затем создайте функцию формы gradients = modelGradients(parameters,dlX,T), где parameters struct, содержащий настраиваемые параметры, dlX входные предикторы, T цели и gradients возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения). Для моделей, заданных как функция, вы не должны передавать сеть как входной параметр.

Например, эта функция возвращает градиенты и потерю перекрестной энтропии для функции модели глубокого обучения model с заданными настраиваемыми параметрами parameters, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

Выполните функцию градиентов модели

Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval функция, которая выполняет функцию с автоматическим включенным дифференцированием. Для первого входа dlfeval, передайте определенный функцией указатель функции градиентов модели и для следующих входных параметров, передайте необходимые переменные для функции градиентов модели. Для выходных параметров dlfeval функционируйте, задайте те же выходные параметры как функция градиентов модели.

Например, чтобы оценить градиенты модели функционируют modelGradients с dlnetwork объект dlnet, входные данные dlX и T, и возвратите градиенты модели и потерю, используйте команду:

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

Точно так же, чтобы оценить градиенты модели функционируют modelGradients использование функции модели с настраиваемыми параметрами, заданными struct parameters, входные данные dlX и T, и возвратите градиенты модели и потерю, используйте команду:

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);

Обновите настраиваемые параметры Используя градиенты

Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции:

ФункцияОписание
adamupdateОбновите параметры с помощью адаптивной оценки момента (Адам)
rmspropupdateОбновите параметры с помощью корневого среднеквадратического распространения (RMSProp)
sgdmupdateОбновите параметры с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновите параметры с помощью пользовательской функции

Например, чтобы обновить настраиваемые параметры dlnetwork объект dlnet использование adamupdate функция, используйте команду:

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
где gradients выход функции градиентов модели и trailingAvg, trailingAvgSq, и iteration гиперпараметры, требуемые adamupdate функция.

Точно так же обновить настраиваемые параметры для функционального parameters модели использование adamupdate функция, используйте команду:

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
где gradients выход функции градиентов модели и trailingAvg, trailingAvgSq, и iteration гиперпараметры, требуемые adamupdate функция.

Используйте функцию градиентов модели в пользовательском учебном цикле

Когда обучение модель глубокого обучения использование пользовательского учебного цикла, оцените градиенты модели и обновите настраиваемые параметры для каждого мини-пакета.

Этот фрагмент кода показывает пример использования dlfeval и adamupdate функции в пользовательском учебном цикле.

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork возразите, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл. Для примера, показывающего, как к обучению модель глубокого обучения, заданная как функция, смотрите, Обучат сеть Используя Функцию Модели.

Отладка функций градиентов модели

Если существует проблема в реализации функции градиентов модели, вызова dlfeval может выдать ошибку. Иногда, при использовании dlfeval функция, это не ясно, какая строка кода выдает ошибку. Чтобы помочь определить местоположение ошибки, можно попробовать следующее:

Вызовите функцию градиентов модели непосредственно

Попытайтесь вызвать функцию градиентов модели непосредственно (то есть, не используя dlfeval функция) со сгенерированными входными параметрами ожидаемых размеров. Если какая-либо из линий кода выдает ошибку, то это должно быть ясно, который сделал. Обратите внимание на то, что если не использование dlfeval функция, любые вызовы dlgradient функция ожидается к ошибке.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);

Запустите код градиентов модели вручную

Запуститесь код в градиентах модели функционируют вручную со сгенерированными входными параметрами ожидаемых размеров и смотрят выход и любые выданные сообщения об ошибке.

Например, чтобы проверить функцию, определяемую градиентов модели:

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlX,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

запустите код:

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlX,T)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте