exponenta event banner

Определение функции градиентов модели для пользовательского цикла обучения

При обучении модели глубокого обучения с помощью пользовательского цикла обучения программное обеспечение минимизирует потери по отношению к обучаемым параметрам. Для минимизации потерь программное обеспечение использует градиенты потерь относительно обучаемых параметров. Чтобы рассчитать эти градиенты с помощью автоматического дифференцирования, необходимо определить функцию градиентов модели.

Пример обучения модели глубокого обучения с помощью dlnetwork см. раздел Обучение сети с использованием пользовательского цикла обучения. Пример обучения модели глубокого обучения, определенной как функция, см. в разделе Обучение сети с использованием функции модели.

Создание функции градиентов модели для модели, определенной как dlnetwork Объект

Если у вас есть модель глубокого обучения, определенная как dlnetwork объект, затем создать функцию градиентов модели, которая принимает dlnetwork объект в качестве входных данных.

Для модели, указанной как dlnetwork объект, создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet - сеть, dlX - сетевой вход, T содержит целевые объекты, и gradients содержит возвращенные градиенты. Дополнительно можно передать дополнительные аргументы функции градиентов (например, если функция потерь требует дополнительной информации) или вернуть дополнительные аргументы (например, метрики для построения графика хода обучения).

Например, эта функция возвращает градиенты и потери перекрестной энтропии для указанного dlnetwork объект dlnet, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Создание функции градиентов модели для модели, определенной как функция

Если модель глубокого обучения определена как функция, создайте функцию градиентов модели, которая принимает обучаемые параметры модели в качестве входных данных.

Для модели, указанной как функция, создайте функцию формы gradients = modelGradients(parameters,dlX,T), где parameters содержит обучаемые параметры, dlX - входной сигнал модели, T содержит целевые объекты, и gradients содержит возвращенные градиенты. Дополнительно можно передать дополнительные аргументы функции градиентов (например, если функция потерь требует дополнительной информации) или вернуть дополнительные аргументы (например, метрики для построения графика хода обучения).

Например, эта функция возвращает градиенты и потери перекрестной энтропии для функции модели глубокого обучения model с указанными обучаемыми параметрами parameters, входные данные dlX, и цели T.

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end

Функция вычисления градиентов модели

Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval функция, которая оценивает функцию с включенным автоматическим дифференцированием. Для первого ввода dlfevalпередайте функцию градиентов модели, заданную как дескриптор функции. Для следующих входных данных передайте требуемые переменные для функции градиентов модели. Для выходных данных dlfeval укажите те же выходные данные, что и функция градиентов модели.

Например, вычислить функцию градиентов модели modelGradients с dlnetwork объект dlnet, входные данные dlX, и цели Tи вернуть градиенты модели и потери.

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

Аналогично, оцените функцию градиентов модели modelGradients использование функции модели с обучаемыми параметрами, заданными структурой parameters, входные данные dlX, и цели Tи вернуть градиенты модели и потери.

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);

Обновление обучаемых параметров с помощью градиентов

Для обновления обучаемых параметров с помощью градиентов можно использовать следующие функции.

ФункцияОписание
adamupdateОбновление параметров с использованием адаптивной оценки момента (Адам)
rmspropupdateОбновление параметров с использованием среднеквадратичного распространения корня (RMSProp)
sgdmupdateОбновление параметров с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновление параметров с помощью пользовательской функции

Например, обновите обучаемые параметры dlnetwork объект dlnet с использованием adamupdate функция.

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients является выводом функции градиентов модели, и trailingAvg, trailingAvgSq, и iteration - гиперпараметры, требуемые adamupdate функция.

Аналогично, обновите обучаемые параметры для функции модели parameters с использованием adamupdate функция.

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAverageSq,iteration);
gradients является выводом функции градиентов модели, и trailingAvg, trailingAvgSq, и iteration - гиперпараметры, требуемые adamupdate функция.

Использование функции градиентов модели в пользовательском учебном цикле

При обучении модели глубокого обучения с использованием пользовательского цикла обучения оцените градиенты модели и обновите обучаемые параметры для каждого мини-пакета.

Этот фрагмент кода показывает пример использования dlfeval и adamupdate функции в индивидуальном цикле обучения.

iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAverageSq,iteration);

    end
end

Пример обучения модели глубокого обучения с помощью dlnetwork см. раздел Обучение сети с использованием пользовательского цикла обучения. Пример обучения модели глубокого обучения, определенной как функция, см. в разделе Обучение сети с использованием функции модели.

Функции градиентов отладочной модели

Если реализация функции градиентов модели имеет проблему, то вызов dlfeval может вызвать ошибку. Иногда, когда вы используете dlfeval функция, непонятно, какая строка кода выбрасывает ошибку. Чтобы найти ошибку, попробуйте выполнить следующие действия.

Функция прямого вызова градиентов модели

Попробуйте вызвать функцию градиентов модели напрямую (то есть без использования dlfeval функция) с сгенерированными входами ожидаемых размеров. Если какая-либо из строк кода вызывает ошибку, сообщение об ошибке предоставляет дополнительную информацию. Обратите внимание, что если вы не используете dlfeval функция, любые вызовы dlgradient выдать ошибку.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);

Выполнить код градиентов модели вручную

Запустите код в функции градиентов модели вручную с сгенерированными входами ожидаемых размеров и проверьте выходные данные и любые сообщения об ошибках.

Например, рассмотрим следующую функцию градиентов модели.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet);

end

Проверьте функцию градиентов модели, выполнив следующий код.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(dlX);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlX,T)

Связанные темы