Для большинства задач можно управлять деталями алгоритма настройки, используя trainingOptions
и trainNetwork
функций. Если trainingOptions
функция не предоставляет опции, необходимые для вашей задачи (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью dlnetwork
объект. A dlnetwork
объект позволяет вам обучать сеть, заданную как график слоев, используя автоматическую дифференциацию.
Чтобы задать те же опции, что и trainingOptions
, используйте эти примеры в качестве руководства:
Опция обучения | trainingOptions Аргумент | Пример |
---|---|---|
Решатель Адама | Адаптивная оценка момента (ADAM) | |
Решатель RMSProp | Среднеквадратичное распространение корня (RMSProp) | |
Решатель SGDM | Стохастический градиентный спуск с моментом (SGDM) | |
Темп обучения | 'InitialLearnRate' | Темп обучения |
Изучение расписания ставок | Кусочное расписание скорости обучения | |
Процесс обучения | 'Plots' | Графики |
Подробный выход | Подробный выход | |
Размер мини-пакета | 'MiniBatchSize' | Размер мини-пакета |
Количество эпох | 'MaxEpochs' | Количество эпох |
Валидация | Валидация | |
L2 регуляризация | 'L2Regularization' | L2 регуляризация |
Градиентные усечения | Градиентные Усечения | |
Один центральный процессор или GPU | 'ExecutionEnvironment' | Обучение на одном центральном процессоре или графическом процессоре |
Контрольные точки | 'CheckpointPath' | Контрольные точки |
Чтобы задать решатель, используйте adamupdate
, rmspropupdate
, и sgdmupdate
функции для шага обновления в цикле обучения. Чтобы реализовать свой собственный пользовательский решатель, обновите настраиваемые параметры с помощью dlupdate
функция.
Чтобы обновить параметры сети с помощью Adam, используйте adamupdate
функция. Задайте градиентный распад и квадратные коэффициенты градиента распада с помощью соответствующих входных параметров.
Чтобы обновить параметры сети с помощью RMSProp, используйте rmspropupdate
функция. Задайте значение смещения знаменателя (эпсилон) с помощью соответствующего входного параметра.
Чтобы обновить параметры сети с помощью SGDM, используйте sgdmupdate
функция. Задайте импульс с помощью соответствующего входного параметра.
Чтобы задать скорость обучения, используйте входные параметры скорости обучения adamupdate
, rmspropupdate
, и sgdmupdate
функций.
Чтобы легко настроить скорость обучения или использовать ее для пользовательских расписаний скорости обучения, установите начальную скорость обучения перед пользовательским циклом обучения.
learnRate = 0.01;
Чтобы автоматически сбросить скорость обучения во время обучения с помощью кусочно-графика скорости обучения, умножьте скорость обучения на заданный коэффициент падения после заданного интервала.
Чтобы легко задать кусочное расписание скорости обучения, создайте переменные learnRate
, learnRateSchedule
, learnRateDropFactor
, и learnRateDropPeriod
, где learnRate
- начальная скорость обучения, learnRateScedule
содержит либо "piecewise"
или "none"
, learnRateDropFactor
является скаляром в области значений [0, 1], который задает коэффициент для падения скорости обучения и learnRateDropPeriod
является положительным целым числом, которое задает, сколько эпох между падением скорости обучения.
learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
Внутри цикла обучения в конце каждой эпохи сбрасывайте скорость обучения, когда learnRateSchedule
опция "piecewise"
и текущее число эпохи кратно learnRateDropPeriod
. Установите новую скорость обучения в продукт скорости обучения и коэффициента падения скорости обучения.
if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0 learnRate = learnRate * learnRateDropFactor; end
Чтобы построить график потерь и точности обучения во время обучения, вычислите мини-потери пакета и точность или квадратичную невязку (RMSE) в функции градиентов модели и постройте график с помощью анимированной линии.
Чтобы легко указать, что график должен быть включен или выключен, создайте переменную plots
который содержит либо "training-progress"
или "none"
. Чтобы также построить графики метрик валидации, используйте те же опции validationData
и validationFrequency
описан в разделе Валидация.
plots = "training-progress";
validationData = {XValidation, YValidation};
validationFrequency = 50;
Перед обучением инициализируйте анимированные линии с помощью animatedline
функция. Для задач классификации создайте график для точности обучения и потерь обучения. Также инициализируйте анимированные линии для метрик валидации, когда заданы данные валидации.
if plots == "training-progress" figure subplot(2,1,1) lineAccuracyTrain = animatedline; ylabel("Accuracy") subplot(2,1,2) lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") if ~isempty(validationData) subplot(2,1,1) lineAccuracyValidation = animatedline; subplot(2,1,2) lineLossValidation = animatedline; end end
Для задач регрессии скорректируйте код путем изменения имен переменных и меток так, чтобы он инициализировал графики для RMSE обучения и валидации вместо точности обучения и валидации.
Внутри цикла обучения в конце итерации обновляйте график так, чтобы он включал соответствующие метрики для сети. Для задач классификации добавьте точки, соответствующие точности мини-пакета и потере мини-пакета. Если данные валидации непусты, а текущая итерация является либо 1, либо кратной опции частоты валидации, то также добавьте точки для данных валидации.
if plots == "training-progress" addpoints(lineAccuracyTrain,iteration,accuracyTrain) addpoints(lineLossTrain,iteration,lossTrain) if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0) addpoints(lineAccuracyValidation,iteration,accuracyValidation) addpoints(lineLossValidation,iteration,lossValidation) end end
accuracyTrain
и lossTrain
соответствуют мини-пакетной точности и потерям, вычисленным в функции градиентов модели. Для задач регрессии используйте мини-пакетные потери RMSE вместо мини-пакетных точностей.Совет
The addpoints
функция требует, чтобы точки данных имели тип double
. Чтобы извлечь числовые данные из dlarray
объекты, использовать extractdata
функция. Чтобы собрать данные из графический процессор, используйте gather
функция.
Чтобы узнать, как вычислить метрики валидации, смотрите Валидацию.
Чтобы отобразить потери и точность обучения во время обучения в подробной таблице, вычислите мини-потери пакета и точность (для задач классификации) или RMSE (для задач регрессии) в функции градиентов модели и отобразите их с помощью disp
функция.
Чтобы легко указать, что подробная таблица должна быть включена или отключена, создайте переменные verbose
и verboseFrequency
, где verbose
является true
или false
и verbosefrequency
задает, сколько итераций между печатью подробного выхода. Чтобы отобразить метрики валидации, используйте те же опции validationData
и validationFrequency
описан в разделе Валидация.
verbose = true verboseFrequency = 50; validationData = {XValidation, YValidation}; validationFrequency = 50;
Перед обучением отобразите подробные заголовки выходной таблицы и инициализируйте таймер с помощью tic
функция.
disp("|======================================================================================================================|") disp("| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |") disp("| | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate |") disp("|======================================================================================================================|") start = tic;
Для задач регрессии настройте код так, чтобы он отображал обучение и валидацию RMSE вместо точности обучения и валидации.
Внутри цикла обучения, в конце итерации, печатайте подробный выход, когда verbose
опция true
и это либо первая итерация, либо число итерации кратно verboseFrequency
.
if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0 D = duration(0,0,toc(start),'Format','hh:mm:ss'); if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 accuracyValidation = ""; lossValidation = ""; end disp("| " + ... pad(epoch,7,'left') + " | " + ... pad(iteration,11,'left') + " | " + ... pad(D,14,'left') + " | " + ... pad(accuracyTrain,12,'left') + " | " + ... pad(accuracyValidation,12,'left') + " | " + ... pad(lossTrain,12,'left') + " | " + ... pad(lossValidation,12,'left') + " | " + ... pad(learnRate,15,'left') + " |") end
Для задач регрессии настройте код так, чтобы он отображал обучение и валидацию RMSE вместо точности обучения и валидации.
Когда обучение закончено, распечатайте последнюю границу подробной таблицы.
disp("|======================================================================================================================|")
Чтобы узнать, как вычислить метрики валидации, смотрите Валидацию.
Установка размера мини-пакета зависит от формата данных или типа используемого datastore.
Чтобы легко задать размер мини-пакета, создайте переменную miniBatchSize
.
miniBatchSize = 128;
Для данных в datastore перед обучением установите ReadSize
свойство datastore для размера мини-пакета.
imds.ReadSize = miniBatchSize;
Для данных в дополненном datastore перед обучением установите MiniBatchSize
свойство datastore для размера мини-пакета.
augimds.MiniBatchSize = miniBatchSize;
Для данных в памяти во время обучения в начале каждой итерации считывайте наблюдения непосредственно из массива.
idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize); X = XTrain(:,:,:,idx);
Задайте максимальное количество эпох для обучения во внешней for
цикл цикла обучения.
Чтобы легко задать максимальное количество эпох, создайте переменную maxEpochs
который содержит максимальное количество эпох.
maxEpochs = 30;
Во внешней for
цикл цикла обучения, задайте цикл в области значений 1, 2,..., maxEpochs
.
for epoch = 1:maxEpochs ... end
Чтобы подтвердить свою сеть во время обучения, отложите набор валидации и оцените, насколько хорошо сеть работает с этими данными.
Чтобы легко задать опции валидации, создайте переменные validationData
и validationFrequency
, где validationData
содержит данные валидации или пуст и validationFrequency
задает, сколько итераций между проверкой сети.
validationData = {XValidation,YValidation}; validationFrequency = 50;
Во время цикла обучения, после обновления параметров сети, проверяйте, насколько хорошо сеть работает на задержанном наборе валидации, используя predict
функция. Проверьте сеть, только когда данные валидации заданы, и это либо первая итерация, либо текущая итерация является произведением validationFrequency
опция.
if iteration == 1 || mod(iteration,validationFrequency) == 0 dlYPredValidation = predict(dlnet,dlXValidation); lossValidation = crossentropy(softmax(dlYPredValidation), YValidation); [~,idx] = max(dlYPredValidation); labelsPredValidation = classNames(idx); accuracyValidation = mean(labelsPredValidation == labelsValidation); end
YValidation
является фиктивной переменной, соответствующей меткам в classNames
. Чтобы вычислить точность, преобразуйте YValidation
в массив меток.Для задач регрессии настройте код так, чтобы он вычислял RMSE валидации вместо точности валидации.
Чтобы остановить обучение раньше, когда потеря при задержке валидации перестает уменьшаться, используйте флаг, чтобы вырваться из циклов обучения.
Чтобы легко указать терпение валидации (количество раз, когда потеря валидации может быть больше или равной ранее наименьшей потере до остановки сетевого обучения), создайте переменную validationPatience
.
validationPatience = 5;
Перед обучением инициализируйте переменные earlyStop
и validationLosses
, где earlyStop
является флагом, чтобы остановить обучение раньше и validationLosses
содержит потери, которые нужно сравнить. Инициализируйте флаг ранней остановки с false
и массив потерь для валидации с inf
.
earlyStop = false; if isfinite(validationPatience) validationLosses = inf(1,validationPatience); end
Внутри цикла обучения, в цикле по мини-пакетам, добавьте earlyStop
флаг для условия цикла.
while hasdata(ds) && ~earlyStop ... end
Во время шага валидации добавьте новые потери валидации к массиву validationLosses
. Если первый элемент массива является наименьшим, то установите earlyStop
флаг в true
. В противном случае удалите первый элемент.
if isfinite(validationPatience) validationLosses = [validationLosses validationLoss]; if min(validationLosses) == validationLosses(1) earlyStop = true; else validationLosses(1) = []; end end
Чтобы применить L2 регуляризацию к весам, используйте dlupdate
функция.
Чтобы легко задать коэффициент регуляризации L2, создайте переменную l2Regularization
который содержит L2 коэффициент регуляризации.
l2Regularization = 0.0001;
Во время обучения, после вычисления градиентов модели, для каждого из весовых параметров добавьте продукт L2 коэффициента регуляризации и веса к вычисленным градиентам, используя dlupdate
функция. Чтобы обновить только параметры веса, извлеките параметры с именем "Weights"
.
idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));
После добавления параметра регуляризации L2 к градиентам обновите параметры сети.
Чтобы отсечь градиенты, используйте dlupdate
функция.
Чтобы легко задать опции усечения градиента, создайте переменные gradientThresholdMethod
и gradientThreshold
, где gradientThresholdMethod
содержит "global-l2norm"
, "l2norm"
, или "absolute-value"
, и gradientThreshold
- положительная скалярная величина, содержащая порог или inf
.
gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;
Создайте функции с именем thresholdGlobalL2Norm
, thresholdL2Norm
, и thresholdAbsoluteValue
которые применяют "global-l2norm"
, "l2norm"
, и "absolute-value"
пороговые методы, соответственно.
Для "global-l2norm"
опция, функция работает со всеми градиентами модели.
function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold) globalL2Norm = 0; for i = 1:numel(gradients) globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2); end globalL2Norm = sqrt(globalL2Norm); if globalL2Norm > gradientThreshold normScale = gradientThreshold / globalL2Norm; for i = 1:numel(gradients) gradients{i} = gradients{i} * normScale; end end end
Для "l2norm"
и "absolute-value"
опции, функции работают с каждым градиентом независимо.
function gradients = thresholdL2Norm(gradients,gradientThreshold) gradientNorm = sqrt(sum(gradients(:).^2)); if gradientNorm > gradientThreshold gradients = gradients * (gradientThreshold / gradientNorm); end end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold) gradients(gradients > gradientThreshold) = gradientThreshold; gradients(gradients < -gradientThreshold) = -gradientThreshold; end
Во время обучения, после вычисления градиентов модели, примените соответствующий метод усечения градиентов к градиентам, используя dlupdate
функция. Потому что "global-l2norm"
опция требует всех градиентов модели, применить thresholdGlobalL2Norm
функция непосредственно к градиентам. Для "l2norm"
и "absolute-value"
опции, обновляйте градиенты независимо, используя dlupdate
функция.
switch gradientThresholdMethod case "global-l2norm" gradients = thresholdGlobalL2Norm(gradients, gradientThreshold); case "l2norm" gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients); case "absolute-value" gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients); end
После применения операции градиентного порога обновите параметры сети.
По умолчанию программа выполняет вычисления, используя только центральный процессор. Чтобы обучиться на одном графическом процессоре, преобразуйте данные в gpuArray
объекты. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
Чтобы легко задать окружение выполнения, создайте переменную executionEnvironment
который содержит либо "cpu"
, "gpu"
, или "auto"
.
executionEnvironment = "auto"
Во время обучения, после чтения мини-пакета, проверьте опцию окружения выполнения и преобразуйте данные в gpuArray
при необходимости. canUseGPU
проверка функций на пригодные для использования графические процессоры.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end
Чтобы сохранить сети контрольных точек во время обучения, сохраните сеть с помощью save
функция.
Чтобы легко определить, следует ли включать контрольные точки, создайте переменную checkpointPath
содержит папку для сетей контрольных точек или пуста.
checkpointPath = fullfile(tempdir,"checkpoints");
Если папка контрольных точек не существует, перед обучением создайте папку контрольных точек.
if ~exist(checkpointPath,"dir") mkdir(checkpointPath) end
Во время обучения, в конце эпохи, сохраните сеть в MAT файла. Укажите имя файла, содержащее текущий номер итерации, дату и время.
if ~isempty(checkpointPath) D = datestr(now,'yyyy_mm_dd__HH_MM_SS'); filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat"; save(filename,"dlnet") end
dlnet
является dlnetwork
объект, который будет сохранен.adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| dlupdate
| rmspropupdate
| sgdmupdate