exponenta event banner

Укажите параметры обучения в индивидуальном цикле обучения

Для большинства задач можно управлять подробностями алгоритма обучения с помощью trainingOptions и trainNetwork функции. Если trainingOptions функция не предоставляет опции, необходимые для выполнения задачи (например, индивидуальный график обучения), после чего можно определить собственный индивидуальный цикл обучения с помощью dlnetwork объект. A dlnetwork объект позволяет обучить сеть, заданную как график слоев, с помощью автоматического дифференцирования.

Чтобы задать те же параметры, что и trainingOptionsиспользуйте следующие примеры в качестве руководства:

Вариант обученияtrainingOptions АргументПример
Решатель Адама Оценка адаптивного момента (ADAM)
Решатель RMSProp Среднеквадратичное распространение (RMSProp)
Решатель SGDM Стохастический градиентный спуск с импульсом (SGDM)
Скорость обучения'InitialLearnRate'Скорость обучения
Узнать расписание тарифов График скорости обучения по кусочкам
Ход обучения'Plots'Сюжеты
Подробный вывод Подробный вывод
Размер мини-партии'MiniBatchSize'Размер мини-пакета
Число эпох'MaxEpochs'Число эпох
Проверка Проверка
L2 регуляризация'L2Regularization'L2 Регуляризация
Отсечение градиента Отсечение градиента
Обучение одному процессору или графическому процессору'ExecutionEnvironment'Обучение одному процессору или графическому процессору
Контрольно-пропускные пункты'CheckpointPath'Контрольно-пропускные пункты

Параметры решателя

Чтобы указать решатель, используйте adamupdate, rmspropupdate, и sgdmupdate функции для шага обновления в учебном цикле. Для реализации собственного пользовательского решателя обновите обучаемые параметры с помощью dlupdate функция.

Оценка адаптивного момента (ADAM)

Чтобы обновить параметры сети с помощью Adam, используйте adamupdate функция. Задайте коэффициент градиентного затухания и коэффициент градиентного затухания в квадрате, используя соответствующие входные аргументы.

Среднеквадратичное распространение (RMSProp)

Для обновления параметров сети с помощью RMSProp используйте rmspropupdate функция. Задайте значение смещения знаменателя (epsilon) с помощью соответствующего входного аргумента.

Стохастический градиентный спуск с импульсом (SGDM)

Для обновления параметров сети с помощью SGDM используйте sgdmupdate функция. Задайте импульс с помощью соответствующего входного аргумента.

Скорость обучения

Чтобы указать скорость обучения, используйте входные аргументы скорости обучения adamupdate, rmspropupdate, и sgdmupdate функции.

Чтобы легко настроить скорость обучения или использовать ее для пользовательских графиков скорости обучения, задайте начальную скорость обучения перед пользовательским циклом обучения.

learnRate = 0.01;

График скорости обучения по кусочкам

Чтобы автоматически сбросить скорость обучения во время обучения, используя кусочный график скорости обучения, умножьте скорость обучения на заданный коэффициент после указанного интервала.

Чтобы легко задать кусочно-обучающий график, создайте переменные learnRate, learnRateSchedule, learnRateDropFactor, и learnRateDropPeriod, где learnRate - начальная скорость обучения, learnRateScedule содержит либо "piecewise" или "none", learnRateDropFactor - скаляр в диапазоне [0, 1], который определяет коэффициент для снижения скорости обучения, и learnRateDropPeriod является положительным целым числом, которое указывает, сколько эпох между отбрасыванием скорости обучения.

learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

Внутри учебного цикла, в конце каждой эпохи, сбрасывайте скорость обучения, когда learnRateSchedule опция - "piecewise" и текущее число эпох кратно learnRateDropPeriod. Установите новую скорость обучения на произведение скорости обучения и коэффициента падения скорости обучения.

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end

Сюжеты

Для построения графика потерь и точности во время обучения вычислите потери мини-партии и либо точность, либо среднеквадратичную ошибку (RMSE) в функции градиентов модели и постройте график с использованием анимированной линии.

Чтобы легко указать, что график должен быть включен или выключен, создайте переменную plots который содержит либо "training-progress" или "none". Для печати метрик проверки используйте те же параметры validationData и validationFrequency описано в разделе Проверка.

plots = "training-progress";

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед началом обучения инициализируйте анимированные строки с помощью animatedline функция. Для задач классификации создайте график точности обучения и потерь обучения. Также инициализируйте анимированные строки для метрик проверки, когда указаны данные проверки.

if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")
	
    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")

    if ~isempty(validationData)
        subplot(2,1,1)
        lineAccuracyValidation = animatedline;

        subplot(2,1,2)
        lineLossValidation = animatedline;
    end
end

Для задач регрессии скорректируйте код, изменив имена переменных и метки так, чтобы он инициализировал графики для RMSE обучения и проверки вместо точности обучения и проверки.

Внутри учебного цикла в конце итерации обновите график так, чтобы он включал соответствующие метрики для сети. Для задач классификации добавьте точки, соответствующие точности мини-партии и потере мини-партии. Если данные проверки не являются пустыми, а текущая итерация является либо 1, либо кратной опции частоты проверки, то также добавьте точки для данных проверки.

if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
где accuracyTrain и lossTrain соответствуют точности и потерям мини-партии, рассчитанным в функции градиентов модели. Для задач регрессии используйте потери RMSE мини-партии вместо точности мини-партии.

Совет

addpoints функция требует, чтобы точки данных имели тип double. Извлечение числовых данных из dlarray объекты, используйте extractdata функция. Для сбора данных из графического процессора используйте gather функция.

Сведения о вычислении метрик проверки см. в разделе Проверка.

Подробный вывод

Чтобы отобразить потери и точность обучения во время обучения в подробной таблице, вычислите потери мини-партии и точность (для задач классификации) или RMSE (для задач регрессии) в функции градиентов модели и просмотрите их с помощью disp функция.

Чтобы легко указать, что подробная таблица должна быть включена или выключена, создайте переменные verbose и verboseFrequency, где verbose является true или false и verbosefrequency указывает количество итераций между печатью подробных выходных данных. Для отображения метрик проверки используйте те же параметры validationData и validationFrequency описано в разделе Проверка.

verbose = true
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед обучением просмотрите подробные заголовки выходных таблиц и инициализируйте таймер с помощью tic функция.

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

Для задач регрессии настройте код так, чтобы он отображал RMSE обучения и проверки вместо точности обучения и проверки.

Внутри учебного цикла в конце итерации распечатайте подробные выходные данные, когда verbose опция - true и это либо первая итерация, либо номер итерации кратен verboseFrequency.

if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 
        accuracyValidation = "";
        lossValidation = "";
    end

    disp("| " + ...
        pad(epoch,7,'left') + " | " + ...
        pad(iteration,11,'left') + " | " + ...
        pad(D,14,'left') + " | " + ...
        pad(accuracyTrain,12,'left') + " | " + ...
        pad(accuracyValidation,12,'left') + " | " + ...
        pad(lossTrain,12,'left') + " | " + ...
        pad(lossValidation,12,'left') + " | " + ...
        pad(learnRate,15,'left') + " |")
end

Для задач регрессии настройте код так, чтобы он отображал RMSE обучения и проверки вместо точности обучения и проверки.

По окончании обучения напечатайте последнюю границу подробной таблицы.

disp("|======================================================================================================================|")

Сведения о вычислении метрик проверки см. в разделе Проверка.

Размер мини-пакета

Установка размера мини-пакета зависит от формата данных или типа используемого хранилища данных.

Для упрощения определения размера мини-партии создайте переменную. miniBatchSize.

miniBatchSize = 128;

Для данных в хранилище данных образа перед обучением установите ReadSize свойство хранилища данных в соответствии с размером мини-пакета.

imds.ReadSize = miniBatchSize;

Для данных в хранилище данных дополненного изображения перед обучением установите MiniBatchSize свойство хранилища данных в соответствии с размером мини-пакета.

augimds.MiniBatchSize = miniBatchSize;

Для данных в памяти во время обучения в начале каждой итерации считывайте наблюдения непосредственно из массива.

idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);

Число эпох

Укажите максимальное количество эпох для обучения во внешнем for цикл учебного цикла.

Чтобы легко задать максимальное количество эпох, создайте переменную maxEpochs который содержит максимальное количество эпох.

maxEpochs = 30;

Во внешней for цикл учебного цикла, указать цикл в диапазоне 1, 2 ,...,maxEpochs.

for epoch = 1:maxEpochs
    ...
end

Проверка

Чтобы проверить сеть во время обучения, отложите набор проверки и оцените, насколько хорошо сеть работает с этими данными.

Чтобы легко задать параметры проверки, создайте переменные validationData и validationFrequency, где validationData содержит данные проверки или пуст и validationFrequency указывает количество итераций между проверкой сети.

validationData = {XValidation,YValidation};
validationFrequency = 50;

Во время цикла обучения, после обновления параметров сети, проверьте, насколько хорошо сеть работает с набором проверки с удержанием с помощью predict функция. Проверка сети выполняется только в том случае, если данные проверки указаны и являются либо первой итерацией, либо текущая итерация кратна validationFrequency вариант.

if iteration == 1 || mod(iteration,validationFrequency) == 0
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation), YValidation);

    [~,idx] = max(dlYPredValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
Здесь, YValidation является фиктивной переменной, соответствующей меткам в classNames. Чтобы рассчитать точность, преобразуйте YValidation в массив меток.

Для задач регрессии настройте код так, чтобы он рассчитывал RMSE проверки вместо точности проверки.

Ранняя остановка

Чтобы остановить тренировку раньше, когда потеря при проверке с удержанием перестает уменьшаться, используйте флаг, чтобы вырваться из учебных циклов.

Чтобы легко определить терпение проверки (количество раз, когда потеря проверки может быть больше или равна ранее наименьшей потере перед остановкой обучения сети), создайте переменную validationPatience.

validationPatience = 5;

Перед обучением инициализируйте переменные earlyStop и validationLosses, где earlyStop является флагом, чтобы остановить тренировку рано и validationLosses содержит сравниваемые потери. Инициализировать флаг ранней остановки с помощью false и массив потерь проверки с inf.

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

Внутри учебного цикла, в цикле над мини-партиями, добавить earlyStop установите флажок для условия цикла.

while hasdata(ds) && ~earlyStop
    ...
end

На этапе проверки добавьте новую потерю проверки в массив validationLosses. Если первый элемент массива является наименьшим, установите значение earlyStop флаг для true. В противном случае удалите первый элемент.

if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end

L2 Регуляризация

Чтобы применить регуляризацию L2 к весам, используйте dlupdate функция.

Чтобы легко определить коэффициент регуляризации L2, создайте переменную. l2Regularization который содержит коэффициент регуляризации L2.

l2Regularization = 0.0001;

Во время обучения после вычисления градиентов модели для каждого из весовых параметров добавьте произведение коэффициента регуляризации L2 и весов к вычисленным градиентам с помощью dlupdate функция. Чтобы обновить только весовые параметры, извлеките параметры с именем "Weights".

idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

После добавления параметра регуляризации L2 к градиентам обновите параметры сети.

Отсечение градиента

Чтобы обрезать градиенты, используйте dlupdate функция.

Чтобы легко задать параметры подрезки градиента, создайте переменные gradientThresholdMethod и gradientThreshold, где gradientThresholdMethod содержит "global-l2norm", "l2norm", или "absolute-value", и gradientThreshold является положительным скаляром, содержащим порог или inf.

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

Создание функций с именем thresholdGlobalL2Norm, thresholdL2Norm, и thresholdAbsoluteValue которые применяют "global-l2norm", "l2norm", и "absolute-value" пороговые методы соответственно.

Для "global-l2norm" функция работает со всеми градиентами модели.

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

Для "l2norm" и "absolute-value" опции, функции работают для каждого градиента независимо.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

Во время обучения после вычисления градиентов модели примените соответствующий метод отсечения градиента к градиентам с помощью dlupdate функция. Потому что "global-l2norm" для опции требуются все градиенты модели, примените thresholdGlobalL2Norm непосредственно к градиентам. Для "l2norm" и "absolute-value" , обновить градиенты независимо с помощью dlupdate функция.

switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end

После применения порогового значения градиента обновите параметры сети.

Обучение одному процессору или графическому процессору

По умолчанию программа выполняет вычисления только с помощью ЦП. Для обучения на одном GPU преобразуйте данные в gpuArray объекты. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

Чтобы легко определить среду выполнения, создайте переменную executionEnvironment который содержит либо "cpu", "gpu", или "auto".

executionEnvironment = "auto"

Во время обучения после чтения мини-пакета проверьте параметр среды выполнения и преобразуйте данные в gpuArray при необходимости. canUseGPU проверка функций на наличие используемых графических процессоров.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Контрольно-пропускные пункты

Для сохранения сетей контрольных точек во время обучения сохраните сеть с помощью save функция.

Чтобы легко указать, следует ли включать контрольные точки, создайте переменную checkpointPath содержит папку для сетей контрольных точек или является пустой.

checkpointPath = fullfile(tempdir,"checkpoints");

Если папка контрольных точек не существует, то перед началом обучения создайте папку контрольных точек.

if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

Во время обучения, в конце эпохи, сохраните сеть в MAT-файле. Укажите имя файла, содержащего текущий номер итерации, дату и время.

if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    save(filename,"dlnet")
end
где dlnet является dlnetwork объект для сохранения.

См. также

| | | | | | |

Связанные темы