Задайте опции обучения в пользовательском цикле обучения

Для большинства задач можно управлять деталями алгоритма настройки, используя trainingOptions и trainNetwork функций. Если trainingOptions функция не предоставляет опции, необходимые для вашей задачи (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью dlnetwork объект. A dlnetwork объект позволяет вам обучать сеть, заданную как график слоев, используя автоматическую дифференциацию.

Чтобы задать те же опции, что и trainingOptions, используйте эти примеры в качестве руководства:

Опция обученияtrainingOptions АргументПример
Решатель Адама Адаптивная оценка момента (ADAM)
Решатель RMSProp Среднеквадратичное распространение корня (RMSProp)
Решатель SGDM Стохастический градиентный спуск с моментом (SGDM)
Темп обучения'InitialLearnRate'Темп обучения
Изучение расписания ставок Кусочное расписание скорости обучения
Процесс обучения'Plots'Графики
Подробный выход Подробный выход
Размер мини-пакета'MiniBatchSize'Размер мини-пакета
Количество эпох'MaxEpochs'Количество эпох
Валидация Валидация
L2 регуляризация'L2Regularization'L2 регуляризация
Градиентные усечения Градиентные Усечения
Один центральный процессор или GPU'ExecutionEnvironment'Обучение на одном центральном процессоре или графическом процессоре
Контрольные точки'CheckpointPath'Контрольные точки

Опции решателя

Чтобы задать решатель, используйте adamupdate, rmspropupdate, и sgdmupdate функции для шага обновления в цикле обучения. Чтобы реализовать свой собственный пользовательский решатель, обновите настраиваемые параметры с помощью dlupdate функция.

Адаптивная оценка момента (ADAM)

Чтобы обновить параметры сети с помощью Adam, используйте adamupdate функция. Задайте градиентный распад и квадратные коэффициенты градиента распада с помощью соответствующих входных параметров.

Среднеквадратичное распространение корня (RMSProp)

Чтобы обновить параметры сети с помощью RMSProp, используйте rmspropupdate функция. Задайте значение смещения знаменателя (эпсилон) с помощью соответствующего входного параметра.

Стохастический градиентный спуск с моментом (SGDM)

Чтобы обновить параметры сети с помощью SGDM, используйте sgdmupdate функция. Задайте импульс с помощью соответствующего входного параметра.

Темп обучения

Чтобы задать скорость обучения, используйте входные параметры скорости обучения adamupdate, rmspropupdate, и sgdmupdate функций.

Чтобы легко настроить скорость обучения или использовать ее для пользовательских расписаний скорости обучения, установите начальную скорость обучения перед пользовательским циклом обучения.

learnRate = 0.01;

Кусочное расписание скорости обучения

Чтобы автоматически сбросить скорость обучения во время обучения с помощью кусочно-графика скорости обучения, умножьте скорость обучения на заданный коэффициент падения после заданного интервала.

Чтобы легко задать кусочное расписание скорости обучения, создайте переменные learnRate, learnRateSchedule, learnRateDropFactor, и learnRateDropPeriod, где learnRate - начальная скорость обучения, learnRateScedule содержит либо "piecewise" или "none", learnRateDropFactor является скаляром в области значений [0, 1], который задает коэффициент для падения скорости обучения и learnRateDropPeriod является положительным целым числом, которое задает, сколько эпох между падением скорости обучения.

learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

Внутри цикла обучения в конце каждой эпохи сбрасывайте скорость обучения, когда learnRateSchedule опция "piecewise" и текущее число эпохи кратно learnRateDropPeriod. Установите новую скорость обучения в продукт скорости обучения и коэффициента падения скорости обучения.

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end

Графики

Чтобы построить график потерь и точности обучения во время обучения, вычислите мини-потери пакета и точность или квадратичную невязку (RMSE) в функции градиентов модели и постройте график с помощью анимированной линии.

Чтобы легко указать, что график должен быть включен или выключен, создайте переменную plots который содержит либо "training-progress" или "none". Чтобы также построить графики метрик валидации, используйте те же опции validationData и validationFrequency описан в разделе Валидация.

plots = "training-progress";

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед обучением инициализируйте анимированные линии с помощью animatedline функция. Для задач классификации создайте график для точности обучения и потерь обучения. Также инициализируйте анимированные линии для метрик валидации, когда заданы данные валидации.

if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")
	
    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")

    if ~isempty(validationData)
        subplot(2,1,1)
        lineAccuracyValidation = animatedline;

        subplot(2,1,2)
        lineLossValidation = animatedline;
    end
end

Для задач регрессии скорректируйте код путем изменения имен переменных и меток так, чтобы он инициализировал графики для RMSE обучения и валидации вместо точности обучения и валидации.

Внутри цикла обучения в конце итерации обновляйте график так, чтобы он включал соответствующие метрики для сети. Для задач классификации добавьте точки, соответствующие точности мини-пакета и потере мини-пакета. Если данные валидации непусты, а текущая итерация является либо 1, либо кратной опции частоты валидации, то также добавьте точки для данных валидации.

if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
где accuracyTrain и lossTrain соответствуют мини-пакетной точности и потерям, вычисленным в функции градиентов модели. Для задач регрессии используйте мини-пакетные потери RMSE вместо мини-пакетных точностей.

Совет

The addpoints функция требует, чтобы точки данных имели тип double. Чтобы извлечь числовые данные из dlarray объекты, использовать extractdata функция. Чтобы собрать данные из графический процессор, используйте gather функция.

Чтобы узнать, как вычислить метрики валидации, смотрите Валидацию.

Подробный выход

Чтобы отобразить потери и точность обучения во время обучения в подробной таблице, вычислите мини-потери пакета и точность (для задач классификации) или RMSE (для задач регрессии) в функции градиентов модели и отобразите их с помощью disp функция.

Чтобы легко указать, что подробная таблица должна быть включена или отключена, создайте переменные verbose и verboseFrequency, где verbose является true или false и verbosefrequency задает, сколько итераций между печатью подробного выхода. Чтобы отобразить метрики валидации, используйте те же опции validationData и validationFrequency описан в разделе Валидация.

verbose = true
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед обучением отобразите подробные заголовки выходной таблицы и инициализируйте таймер с помощью tic функция.

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

Для задач регрессии настройте код так, чтобы он отображал обучение и валидацию RMSE вместо точности обучения и валидации.

Внутри цикла обучения, в конце итерации, печатайте подробный выход, когда verbose опция true и это либо первая итерация, либо число итерации кратно verboseFrequency.

if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 
        accuracyValidation = "";
        lossValidation = "";
    end

    disp("| " + ...
        pad(epoch,7,'left') + " | " + ...
        pad(iteration,11,'left') + " | " + ...
        pad(D,14,'left') + " | " + ...
        pad(accuracyTrain,12,'left') + " | " + ...
        pad(accuracyValidation,12,'left') + " | " + ...
        pad(lossTrain,12,'left') + " | " + ...
        pad(lossValidation,12,'left') + " | " + ...
        pad(learnRate,15,'left') + " |")
end

Для задач регрессии настройте код так, чтобы он отображал обучение и валидацию RMSE вместо точности обучения и валидации.

Когда обучение закончено, распечатайте последнюю границу подробной таблицы.

disp("|======================================================================================================================|")

Чтобы узнать, как вычислить метрики валидации, смотрите Валидацию.

Размер мини-пакета

Установка размера мини-пакета зависит от формата данных или типа используемого datastore.

Чтобы легко задать размер мини-пакета, создайте переменную miniBatchSize.

miniBatchSize = 128;

Для данных в datastore перед обучением установите ReadSize свойство datastore для размера мини-пакета.

imds.ReadSize = miniBatchSize;

Для данных в дополненном datastore перед обучением установите MiniBatchSize свойство datastore для размера мини-пакета.

augimds.MiniBatchSize = miniBatchSize;

Для данных в памяти во время обучения в начале каждой итерации считывайте наблюдения непосредственно из массива.

idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);

Количество эпох

Задайте максимальное количество эпох для обучения во внешней for цикл цикла обучения.

Чтобы легко задать максимальное количество эпох, создайте переменную maxEpochs который содержит максимальное количество эпох.

maxEpochs = 30;

Во внешней for цикл цикла обучения, задайте цикл в области значений 1, 2,..., maxEpochs.

for epoch = 1:maxEpochs
    ...
end

Валидация

Чтобы подтвердить свою сеть во время обучения, отложите набор валидации и оцените, насколько хорошо сеть работает с этими данными.

Чтобы легко задать опции валидации, создайте переменные validationData и validationFrequency, где validationData содержит данные валидации или пуст и validationFrequency задает, сколько итераций между проверкой сети.

validationData = {XValidation,YValidation};
validationFrequency = 50;

Во время цикла обучения, после обновления параметров сети, проверяйте, насколько хорошо сеть работает на задержанном наборе валидации, используя predict функция. Проверьте сеть, только когда данные валидации заданы, и это либо первая итерация, либо текущая итерация является произведением validationFrequency опция.

if iteration == 1 || mod(iteration,validationFrequency) == 0
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation), YValidation);

    [~,idx] = max(dlYPredValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
Здесь, YValidation является фиктивной переменной, соответствующей меткам в classNames. Чтобы вычислить точность, преобразуйте YValidation в массив меток.

Для задач регрессии настройте код так, чтобы он вычислял RMSE валидации вместо точности валидации.

Ранняя остановка

Чтобы остановить обучение раньше, когда потеря при задержке валидации перестает уменьшаться, используйте флаг, чтобы вырваться из циклов обучения.

Чтобы легко указать терпение валидации (количество раз, когда потеря валидации может быть больше или равной ранее наименьшей потере до остановки сетевого обучения), создайте переменную validationPatience.

validationPatience = 5;

Перед обучением инициализируйте переменные earlyStop и validationLosses, где earlyStop является флагом, чтобы остановить обучение раньше и validationLosses содержит потери, которые нужно сравнить. Инициализируйте флаг ранней остановки с false и массив потерь для валидации с inf.

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

Внутри цикла обучения, в цикле по мини-пакетам, добавьте earlyStop флаг для условия цикла.

while hasdata(ds) && ~earlyStop
    ...
end

Во время шага валидации добавьте новые потери валидации к массиву validationLosses. Если первый элемент массива является наименьшим, то установите earlyStop флаг в true. В противном случае удалите первый элемент.

if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end

L2 регуляризация

Чтобы применить L2 регуляризацию к весам, используйте dlupdate функция.

Чтобы легко задать коэффициент регуляризации L2, создайте переменную l2Regularization который содержит L2 коэффициент регуляризации.

l2Regularization = 0.0001;

Во время обучения, после вычисления градиентов модели, для каждого из весовых параметров добавьте продукт L2 коэффициента регуляризации и веса к вычисленным градиентам, используя dlupdate функция. Чтобы обновить только параметры веса, извлеките параметры с именем "Weights".

idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

После добавления параметра регуляризации L2 к градиентам обновите параметры сети.

Градиентные Усечения

Чтобы отсечь градиенты, используйте dlupdate функция.

Чтобы легко задать опции усечения градиента, создайте переменные gradientThresholdMethod и gradientThreshold, где gradientThresholdMethod содержит "global-l2norm", "l2norm", или "absolute-value", и gradientThreshold - положительная скалярная величина, содержащая порог или inf.

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

Создайте функции с именем thresholdGlobalL2Norm, thresholdL2Norm, и thresholdAbsoluteValue которые применяют "global-l2norm", "l2norm", и "absolute-value" пороговые методы, соответственно.

Для "global-l2norm" опция, функция работает со всеми градиентами модели.

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

Для "l2norm" и "absolute-value" опции, функции работают с каждым градиентом независимо.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

Во время обучения, после вычисления градиентов модели, примените соответствующий метод усечения градиентов к градиентам, используя dlupdate функция. Потому что "global-l2norm" опция требует всех градиентов модели, применить thresholdGlobalL2Norm функция непосредственно к градиентам. Для "l2norm" и "absolute-value" опции, обновляйте градиенты независимо, используя dlupdate функция.

switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end

После применения операции градиентного порога обновите параметры сети.

Обучение на одном центральном процессоре или графическом процессоре

По умолчанию программа выполняет вычисления, используя только центральный процессор. Чтобы обучиться на одном графическом процессоре, преобразуйте данные в gpuArray объекты. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

Чтобы легко задать окружение выполнения, создайте переменную executionEnvironment который содержит либо "cpu", "gpu", или "auto".

executionEnvironment = "auto"

Во время обучения, после чтения мини-пакета, проверьте опцию окружения выполнения и преобразуйте данные в gpuArray при необходимости. canUseGPU проверка функций на пригодные для использования графические процессоры.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Контрольные точки

Чтобы сохранить сети контрольных точек во время обучения, сохраните сеть с помощью save функция.

Чтобы легко определить, следует ли включать контрольные точки, создайте переменную checkpointPath содержит папку для сетей контрольных точек или пуста.

checkpointPath = fullfile(tempdir,"checkpoints");

Если папка контрольных точек не существует, перед обучением создайте папку контрольных точек.

if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

Во время обучения, в конце эпохи, сохраните сеть в MAT файла. Укажите имя файла, содержащее текущий номер итерации, дату и время.

if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    save(filename,"dlnet")
end
где dlnet является dlnetwork объект, который будет сохранен.

См. также

| | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте