Задайте опции обучения в пользовательском учебном цикле

Для большинства задач можно управлять учебными деталями алгоритма с помощью trainingOptions и trainNetwork функции. Если trainingOptions функция не предоставляет возможности, в которых вы нуждаетесь для своей задачи (например, пользовательское изучают расписание уровня), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования.

Задавать те же опции как trainingOptions, используйте эти примеры в качестве руководства:

Опция обученияtrainingOptions АргументПример
Решатель Адама Адаптивная оценка момента (ADAM)
Решатель RMSProp Среднеквадратичное распространение (RMSProp)
Решатель SGDM Стохастический градиентный спуск с импульсом (SGDM)
Изучите уровень'InitialLearnRate'Изучите уровень
Изучите расписание уровня Кусочный изучают расписание уровня
Процесс обучения'Plots'Графики
Многословный выход Многословный Выход
Мини-пакетный размер'MiniBatchSize'Мини-пакетный размер
Номер эпох'MaxEpochs'Номер эпох
Валидация Валидация
Регуляризация L2'L2Regularization'Регуляризация L2
Усечение градиента Усечение градиента
Одно обучение центрального процессора или графического процессора'ExecutionEnvironment'Одно обучение центрального процессора или графического процессора
Контрольные точки'CheckpointPath'Контрольные точки

Опции решателя

Чтобы задать решатель, используйте adamupdate, rmspropupdate, и sgdmupdate функции для обновления продвигаются в ваш учебный цикл. Чтобы реализовать ваш собственный решатель, обновите learnable параметры с помощью dlupdate функция.

Адаптивная оценка момента (ADAM)

Чтобы обновить ваши сетевые параметры с помощью Адама, используйте adamupdate функция. Задайте затухание градиента и факторы затухания градиента в квадрате с помощью соответствующих входных параметров.

Среднеквадратичное распространение (RMSProp)

Чтобы обновить ваши сетевые параметры с помощью RMSProp, используйте rmspropupdate функция. Задайте смещение знаменателя (эпсилон) значение с помощью соответствующего входного параметра.

Стохастический градиентный спуск с импульсом (SGDM)

Чтобы обновить ваши сетевые параметры с помощью SGDM, используйте sgdmupdate функция. Задайте импульс с помощью соответствующего входного параметра.

Изучите уровень

Чтобы задать изучить уровень, используйте изучить входные параметры уровня adamupdate, rmspropupdate, и sgdmupdate функции.

Чтобы легко настроить изучить уровень или использовать, это в пользовательском изучает расписания уровня, установило начальную букву, изучают уровень перед пользовательским учебным циклом.

learnRate = 0.01;

Кусочный изучают расписание уровня

Чтобы автоматически пропустить изучить уровень во время обучения с помощью кусочного изучают расписание уровня, умножают изучить уровень на данный фактор отбрасывания после заданного интервала.

Чтобы легко задать кусочное изучают расписание уровня, создают переменные learnRate, learnRateSchedule, learnRateDropFactor, и learnRateDropPeriod, где learnRate начальная буква, изучают уровень, learnRateScedule содержит любой "piecewise" или "none", learnRateDropFactor скаляр в области значений [0, 1], который задает фактор для отбрасывания темпа обучения и learnRateDropPeriod положительное целое число, которое задает сколько эпох между отбрасыванием изучить уровня.

learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;

В учебном цикле, в конце каждой эпохи, пропускают изучить уровень когда learnRateSchedule опцией является "piecewise" и текущий номер эпохи является кратным learnRateDropPeriod. Установите новое, изучают уровень продукту изучить уровня и изучить фактора отбрасывания уровня.

if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end

Графики

Чтобы построить учебную потерю и точность во время обучения, вычислите мини-пакетную потерю, и или точность или среднеквадратическая ошибка (RMSE) в градиентах модели функционируют и строят их использующий анимированную линию.

Чтобы легко указать, что график должен идти или прочь, создайте переменную plots это содержит любой "training-progress" или "none". Чтобы также построить метрики валидации, используйте те же опции validationData и validationFrequency описанный в Валидации.

plots = "training-progress";

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед обучением инициализируйте анимированные линии с помощью animatedline функция. Для классификации задачи создают график для учебной точности и учебной потери. Также инициализируйте анимированные линии для метрик валидации, когда данные о валидации будут заданы.

if plots == "training-progress"
    figure
    subplot(2,1,1)
    lineAccuracyTrain = animatedline;
    ylabel("Accuracy")
	
    subplot(2,1,2)
    lineLossTrain = animatedline;
    xlabel("Iteration")
    ylabel("Loss")

    if ~isempty(validationData)
        subplot(1,2,1)
        lineAccuracyValidation = animatedline;

        subplot(1,2,2)
        lineLossValidation = animatedline;
    end
end

Для задач регрессии настройте код путем изменения имен переменных и меток так, чтобы он инициализировал графики для обучения и валидации RMSE вместо точности обучения и валидации.

В учебном цикле, в конце итерации, обновляют график так, чтобы это включало соответствующие метрики для сети. Для задач классификации добавьте точки, соответствующие мини-пакетной точности и мини-пакетной потере. Если данные о валидации непусты, и текущая итерация или 1 или кратное опции частоты валидации, то также добавляют точки для данных о валидации.

if plots == "training-progress"
    addpoints(lineAccuracyTrain,iteration,accuracyTrain)
    addpoints(lineLossTrain,iteration,lossTrain)

    if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0)
        addpoints(lineAccuracyValidation,iteration,accuracyValidation)
        addpoints(lineLossValidation,iteration,lossValidation)
    end
end
где accuracyTrain и lossTrain соответствуйте мини-пакетной точности и потере, вычисленной в функции градиентов модели. Для задач регрессии используйте мини-пакетные потери RMSE вместо мини-пакетной точности.

Совет

addpoints функция требует, чтобы точки данных, чтобы иметь ввели double. Извлекать числовые данные из dlarray объекты, используйте extractdata функция. Чтобы собрать данные от графического процессора, используйте gather функция.

Чтобы изучить, как вычислить метрики валидации, смотрите Валидацию.

Многословный Выход

Чтобы отобразить учебную потерю и точность во время обучения в многословной таблице, вычислите мини-пакетную потерю и любого, точность (для задач классификации) или RMSE (для задач регрессии) в градиентах модели функционирует и отображает их использующий disp функция.

Чтобы легко указать, что многословная таблица должна идти или прочь, создайте переменные verbose и verboseFrequency, где verbose true или false и verbosefrequency задает сколько итераций между печатью многословного выхода. Чтобы отобразить метрики валидации, используйте те же опции validationData и validationFrequency описанный в Валидации.

verbose = true
verboseFrequency = 50;

validationData = {XValidation, YValidation};
validationFrequency = 50;

Перед обучением, отображение многословные выходные табличные заголовки и инициализируют таймер с помощью tic функция.

disp("|======================================================================================================================|")
disp("|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |")
disp("|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |")
disp("|======================================================================================================================|")

start = tic;

Для задач регрессии настройте код так, чтобы он отобразил обучение и валидацию RMSE вместо точности обучения и валидации.

В учебном цикле, в конце итерации, распечатывают многословный выход когда verbose опцией является true и это - или первая итерация или номер итерации, является кратным verboseFrequency.

if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0
    D = duration(0,0,toc(start),'Format','hh:mm:ss');

    if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 
        accuracyValidation = "";
        lossValidation = "";
    end

    disp("| " + ...
        pad(epoch,7,'left') + " | " + ...
        pad(iteration,11,'left') + " | " + ...
        pad(D,14,'left') + " | " + ...
        pad(accuracyTrain,12,'left') + " | " + ...
        pad(accuracyValidation,12,'left') + " | " + ...
        pad(lossTrain,12,'left') + " | " + ...
        pad(lossValidation,12,'left') + " | " + ...
        pad(learnRate,15,'left') + " |")
end

Для задач регрессии настройте код так, чтобы он отобразил обучение и валидацию RMSE вместо точности обучения и валидации.

Когда обучение будет закончено, распечатайте последнюю границу многословной таблицы.

disp("|======================================================================================================================|")

Чтобы изучить, как вычислить метрики валидации, смотрите Валидацию.

Мини-пакетный размер

Установка мини-пакетного размера зависит от формата данных или типа используемого datastore.

Чтобы легко задать мини-пакетный размер, создайте переменную miniBatchSize.

miniBatchSize = 128;

Для данных в datastore изображений, перед обучением, устанавливает ReadSize свойство datastore к мини-пакетному размеру.

imds.ReadSize = miniBatchSize;

Для данных в увеличенном datastore изображений, перед обучением, устанавливает MiniBatchSize свойство datastore к мини-пакетному размеру.

augimds.MiniBatchSize = miniBatchSize;

Для данных в оперативной памяти, во время обучения в начале каждой итерации, читают наблюдения непосредственно из массива.

idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);

Номер эпох

Задайте максимальное количество эпох для обучения во внешнем for цикл учебного цикла.

Чтобы легко задать максимальное количество эпох, создайте переменную maxEpochs это содержит максимальное количество эпох.

maxEpochs = 30;

Во внешнем for цикл учебного цикла, задайте, чтобы циклично выполниться в области значений 1, 2, …, maxEpochs.

for epoch = 1:maxEpochs
    ...
end

Валидация

Чтобы подтвердить вашу сеть во время обучения, отложите протянутый набор валидации и оцените, как хорошо сеть выполняет на тех данных.

Чтобы легко задать опции валидации, создайте переменные validationData и validationFrequency, где validationData содержит данные о валидации или пуст и validationFrequency задает сколько итераций между проверкой сети.

validationData = {XValidation,YValidation};
validationFrequency = 50;

Во время учебного цикла, после обновления сетевых параметров, тест, как хорошо сеть выполняет на протянутом наборе валидации с помощью predict функция. Подтвердите сеть только, когда данные о валидации заданы, и это - или первая итерация или текущая итерация, является кратным verboseFrequency опция.

if iteration == 1 || mod(iteration,verboseFrequency) == 0
    dlYPredValidation = predict(dlnet,dlXValidation);
    lossValidation = crossentropy(softmax(dlYPredValidation), YValidation);

    [~,idx] = max(dlYPredValidation);
    labelsPredValidation = classNames(idx);

    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
Здесь, YValidation фиктивная переменная, соответствующая меткам в classNames. Чтобы вычислить точность, преобразуйте YValidation к массиву меток.

Для задач регрессии настройте код так, чтобы он вычислил валидацию RMSE вместо точности валидации.

Рано остановка

Чтобы остановить обучение рано, когда потеря на протянутой валидации прекращает уменьшаться, используйте, устанавливает флаг убегать из учебных циклов.

Чтобы легко задать терпение валидации (число раз, что потеря валидации может быть больше, чем или равняться ранее самой маленькой потере перед сетевыми учебными остановками), создайте переменную validationPatience.

validationPatience = 5;

Перед обучением инициализируйте переменные earlyStop и validationLosses, где earlyStop флаг должен остановить обучение рано и validationLosses содержит потери, чтобы выдержать сравнение. Инициализируйте ранний флаг остановки с false и массив потерь валидации с inf.

earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end

В учебном цикле, в цикле по мини-пакетам, добавляет earlyStop отметьте к условию цикла.

while hasdata(ds) && ~earlyStop
    ...
end

Во время шага валидации добавьте новую потерю валидации для массива validationLosses. Если первый элемент массива является самым маленьким, то установленный earlyStop отметьте к true. В противном случае удалите первый элемент.

if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];
    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end

Регуляризация L2

Чтобы применить регуляризацию L2 к весам, используйте dlupdate функция.

Чтобы легко задать фактор регуляризации L2, создайте переменную l2Regularization это содержит фактор регуляризации L2.

l2Regularization = 0.0001;

Во время обучения, после вычисления градиентов модели, для каждого из параметров веса, добавляет продукт фактора регуляризации L2 и весов к вычисленным градиентам с помощью dlupdate функция. Чтобы обновить только параметры веса, извлеките параметры с именем "Weights".

idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

После добавления параметра регуляризации L2 к градиентам обновите сетевые параметры.

Усечение градиента

Чтобы отсечь градиенты, используйте dlupdate функция.

Чтобы легко задать опции усечения градиента, создайте переменные gradientThresholdMethod и gradientThreshold, где gradientThresholdMethod содержит "global-l2norm", "l2norm", или "absolute-value", и gradientThreshold положительная скалярная величина, содержащая порог или inf.

gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;

Создайте функции с именем thresholdGlobalL2Norm, thresholdL2Norm, и thresholdAbsoluteValue это применяет "global-l2norm", "l2norm", и "absolute-value" пороговые методы, соответственно.

Для "global-l2norm" опция, функция работает со всеми градиентами модели.

function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)

globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end

end

Для "l2norm" и "absolute-value" опции, функции работают с каждым градиентом независимо.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)

gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;

end

Во время обучения, после вычисления градиентов модели, применяют соответствующий метод усечения градиента к градиентам с помощью dlupdate функция. Поскольку "global-l2norm" опция требует всех градиентов модели, примените thresholdGlobalL2Norm функционируйте непосредственно к градиентам. Для "l2norm" и "absolute-value" опции, обновите градиенты независимо с помощью dlupdate функция.

switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients, gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients);
end

После применения пороговой операции градиента обновите сетевые параметры.

Одно обучение центрального процессора или графического процессора

Программное обеспечение, по умолчанию, выполняет вычисления с помощью одного центральный процессор. Обучайтесь на одном графическом процессоре, преобразуйте в данные к gpuArray объекты. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

Чтобы легко задать среду выполнения, создайте переменную executionEnvironment это содержит любой "cpu", "gpu", или "auto".

executionEnvironment = "auto"

Во время обучения, после чтения мини-пакета, проверяют опцию среды выполнения и преобразуют данные в gpuArray при необходимости. canUseGPU функционируйте проверки на применимые графические процессоры.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Контрольные точки

Сохранить сети контрольной точки во время учебного сохранения сеть с помощью save функция.

Чтобы легко задать, должны ли контрольные точки быть включены, создайте переменную checkpointPath содержит папку для сетей контрольной точки или пуст.

checkpointPath = fullfile(tempdir,"checkpoints");

Если папка контрольной точки не существует, то перед обучением, создайте папку контрольной точки.

if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end

Во время обучения, в конце эпохи, сохраняют сеть в файле MAT. Задайте имя файла, содержащее текущий номер итерации, дату, и время.

if ~isempty(checkpointPath)
    D = datestr(now,'yyyy_mm_dd__HH_MM_SS');
    filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat";
    save(filename,"dlnet")
end
где dlnet dlnetwork объект быть сохраненным.

Смотрите также

| | | | | | |

Похожие темы