Задайте функцию градиентов модели для пользовательского цикла обучения

Когда вы обучаете модель глубокого обучения с помощью пользовательского цикла обучения, программное обеспечение минимизирует потери относительно настраиваемых параметров. Чтобы минимизировать потери, программное обеспечение использует градиенты потерь относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматической дифференциации, необходимо задать функцию градиентов модели.

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork объект, см. Train сети с использованием пользовательского цикла обучения. Для примера, показывающего, как обучить модель глубокого обучения, заданную как функция, смотрите Обучите сеть Используя Функцию Модели.

Создайте функцию градиентов модели для модели, заданной как dlnetwork Объект

Если у вас есть модель глубокого обучения, заданная как dlnetwork объект, затем создайте функцию градиентов модели, которая принимает dlnetwork объект как вход.

Для модели, заданной как dlnetwork создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet является сетью, dlX - входной вход сети, T содержит цели и gradients содержит возвращенные градиенты. Опционально можно передать дополнительные аргументы в функцию gradients (для примера, если функция loss требует дополнительной информации), или вернуть дополнительные аргументы (для примера, метрики для графического изображения процесса обучения).

Для примера эта функция возвращает градиенты и потери перекрестной энтропии для заданного dlnetwork dlnet объекта, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Создайте функцию градиентов модели для модели, заданной как функция

Если у вас есть модель глубокого обучения, заданная как функция, то создайте функцию градиентов модели, которая принимает обучаемые модели параметры как вход.

Для модели, заданной как функция, создайте функцию вида gradients = modelGradients(parameters,dlX,T), где parameters содержит настраиваемые параметры, dlX является входом модели, T содержит цели и gradients содержит возвращенные градиенты. Опционально можно передать дополнительные аргументы в функцию gradients (для примера, если функция loss требует дополнительной информации), или вернуть дополнительные аргументы (для примера, метрики для графического изображения процесса обучения).

Для примера эта функция возвращает градиенты и потери перекрестной энтропии для функции модели глубокого обучения model с заданными настраиваемыми параметрами parameters, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

Вычислите функцию градиентов модели

Чтобы вычислить градиенты модели с помощью автоматической дифференциации, используйте dlfeval функция, которая оценивает функцию с включенной автоматической дифференциацией. Для первого входа dlfeval, передайте функцию градиентов модели, заданную как указатель на функцию. Для следующих входов передайте необходимые переменные для функции градиентов модели. Для выходов dlfeval function, задайте те же выходы, что и функция градиентов модели.

Для примера вычислите функцию градиентов модели modelGradients с dlnetwork dlnet объекта, входные данные dlX, и цели T, и возвращает градиенты модели и потери.

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

Точно так же вычислите функцию градиентов модели modelGradients использование функции модели с настраиваемыми параметрами, заданными структурой parameters, входные данные dlX, и цели T, и возвращает градиенты модели и потери.

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);

Обновление настраиваемых параметров с использованием градиентов

Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции.

ФункцияОписание
adamupdateОбновите параметры с помощью адаптивной оценки момента (Адам)
rmspropupdateОбновляйте параметры с помощью корневого среднего квадратного распространения (RMSProp)
sgdmupdateОбновляйте параметры с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновляйте параметры с помощью пользовательской функции

Например, обновляйте настраиваемые параметры dlnetwork dlnet объекта использование adamupdate функция.

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients - выход функции градиентов модели, и trailingAvg, trailingAvgSq, и iteration являются гиперпараметрами, требуемыми adamupdate функция.

Точно так же обновляйте настраиваемые параметры для функции модели parameters использование adamupdate функция.

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients - выход функции градиентов модели, и trailingAvg, trailingAvgSq, и iteration являются гиперпараметрами, требуемыми adamupdate функция.

Используйте функцию градиентов модели в пользовательском цикле обучения

При обучении модели глубокого обучения с помощью пользовательского цикла обучения оценивайте градиенты модели и обновляйте настраиваемые параметры для каждого мини-пакета.

Этот фрагмент кода показывает пример использования dlfeval и adamupdate функций в пользовательском цикле обучения.

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

Для примера, показывающего, как обучить модель глубокого обучения с dlnetwork объект, см. Train сети с использованием пользовательского цикла обучения. Для примера, показывающего, как обучить модель глубокого обучения, заданную как функция, смотрите Обучите сеть Используя Функцию Модели.

Отладка функций градиентов модели

Если реализация функции градиентов модели имеет проблему, то вызов к dlfeval может выдать ошибку. Иногда, когда вы используете dlfeval функция, не ясно, какая строка кода выдает ошибку. Чтобы помочь найти ошибку, можно попробовать следующее.

Вызов функции градиентов модели непосредственно

Попробуйте вызвать функцию градиентов модели непосредственно (то есть не используя dlfeval функция) с сгенерированными входами ожидаемых размеров. Если какая-либо из строк кода выдает ошибку, то сообщение об ошибке предоставляет дополнительную подробную информацию. Обратите внимание, что, когда вы не используете dlfeval function, любые вызовы dlgradient функция выдает ошибку.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);

Запустите код градиентов модели вручную

Запустите код внутри функции градиентов модели вручную с сгенерированными входами ожидаемых размеров и проверьте выход и все выданные сообщения об ошибке.

Для примера рассмотрим следующую функцию градиентов модели.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Проверьте функцию градиентов модели, запустив следующий код.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlX,T)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте