Для большинства задач можно управлять учебными деталями алгоритма с помощью trainingOptions
и trainNetwork
функции. Если trainingOptions
функция не предоставляет возможности, в которых вы нуждаетесь для своей задачи (например, пользовательское изучают расписание уровня), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования.
Задавать те же опции как trainingOptions
, используйте эти примеры в качестве руководства:
Опция обучения | trainingOptions Аргумент | Пример |
---|---|---|
Решатель Адама | Адаптивная оценка момента (ADAM) | |
Решатель RMSProp | Среднеквадратичное распространение (RMSProp) | |
Решатель SGDM | Стохастический градиентный спуск с импульсом (SGDM) | |
Изучите уровень | 'InitialLearnRate' | Изучите уровень |
Изучите расписание уровня | Кусочный изучают расписание уровня | |
Процесс обучения | 'Plots' | Графики |
Многословный выход | Многословный Выход | |
Мини-пакетный размер | 'MiniBatchSize' | Мини-пакетный размер |
Номер эпох | 'MaxEpochs' | Номер эпох |
Валидация | Валидация | |
Регуляризация L2 | 'L2Regularization' | Регуляризация L2 |
Усечение градиента | Усечение градиента | |
Одно обучение центрального процессора или графического процессора | 'ExecutionEnvironment' | Одно обучение центрального процессора или графического процессора |
Контрольные точки | 'CheckpointPath' | Контрольные точки |
Чтобы задать решатель, используйте adamupdate
, rmspropupdate
, и sgdmupdate
функции для обновления продвигаются в ваш учебный цикл. Чтобы реализовать ваш собственный решатель, обновите learnable параметры с помощью dlupdate
функция.
Чтобы обновить ваши сетевые параметры с помощью Адама, используйте adamupdate
функция. Задайте затухание градиента и факторы затухания градиента в квадрате с помощью соответствующих входных параметров.
Чтобы обновить ваши сетевые параметры с помощью RMSProp, используйте rmspropupdate
функция. Задайте смещение знаменателя (эпсилон) значение с помощью соответствующего входного параметра.
Чтобы обновить ваши сетевые параметры с помощью SGDM, используйте sgdmupdate
функция. Задайте импульс с помощью соответствующего входного параметра.
Чтобы задать изучить уровень, используйте изучить входные параметры уровня adamupdate
, rmspropupdate
, и sgdmupdate
функции.
Чтобы легко настроить изучить уровень или использовать, это в пользовательском изучает расписания уровня, установило начальную букву, изучают уровень перед пользовательским учебным циклом.
learnRate = 0.01;
Чтобы автоматически пропустить изучить уровень во время обучения с помощью кусочного изучают расписание уровня, умножают изучить уровень на данный фактор отбрасывания после заданного интервала.
Чтобы легко задать кусочное изучают расписание уровня, создают переменные learnRate
, learnRateSchedule
, learnRateDropFactor
, и learnRateDropPeriod
, где learnRate
начальная буква, изучают уровень, learnRateScedule
содержит любой "piecewise"
или "none"
, learnRateDropFactor
скаляр в области значений [0, 1], который задает фактор для отбрасывания темпа обучения и learnRateDropPeriod
положительное целое число, которое задает сколько эпох между отбрасыванием изучить уровня.
learnRate = 0.01;
learnRateSchedule = "piecewise"
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
В учебном цикле, в конце каждой эпохи, пропускают изучить уровень когда learnRateSchedule
опцией является "piecewise"
и текущий номер эпохи является кратным learnRateDropPeriod
. Установите новое, изучают уровень продукту изучить уровня и изучить фактора отбрасывания уровня.
if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0 learnRate = learnRate * learnRateDropFactor; end
Чтобы построить учебную потерю и точность во время обучения, вычислите мини-пакетную потерю, и или точность или среднеквадратическая ошибка (RMSE) в градиентах модели функционируют и строят их использующий анимированную линию.
Чтобы легко указать, что график должен идти или прочь, создайте переменную plots
это содержит любой "training-progress"
или "none"
. Чтобы также построить метрики валидации, используйте те же опции validationData
и validationFrequency
описанный в Валидации.
plots = "training-progress";
validationData = {XValidation, YValidation};
validationFrequency = 50;
Перед обучением инициализируйте анимированные линии с помощью animatedline
функция. Для классификации задачи создают график для учебной точности и учебной потери. Также инициализируйте анимированные линии для метрик валидации, когда данные о валидации будут заданы.
if plots == "training-progress" figure subplot(2,1,1) lineAccuracyTrain = animatedline; ylabel("Accuracy") subplot(2,1,2) lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") if ~isempty(validationData) subplot(1,2,1) lineAccuracyValidation = animatedline; subplot(1,2,2) lineLossValidation = animatedline; end end
Для задач регрессии настройте код путем изменения имен переменных и меток так, чтобы он инициализировал графики для обучения и валидации RMSE вместо точности обучения и валидации.
В учебном цикле, в конце итерации, обновляют график так, чтобы это включало соответствующие метрики для сети. Для задач классификации добавьте точки, соответствующие мини-пакетной точности и мини-пакетной потере. Если данные о валидации непусты, и текущая итерация или 1 или кратное опции частоты валидации, то также добавляют точки для данных о валидации.
if plots == "training-progress" addpoints(lineAccuracyTrain,iteration,accuracyTrain) addpoints(lineLossTrain,iteration,lossTrain) if ~isempty(validationData) && (iteration == 1 || mod(iteration,validationFrequency) == 0) addpoints(lineAccuracyValidation,iteration,accuracyValidation) addpoints(lineLossValidation,iteration,lossValidation) end end
accuracyTrain
и lossTrain
соответствуйте мини-пакетной точности и потере, вычисленной в функции градиентов модели. Для задач регрессии используйте мини-пакетные потери RMSE вместо мини-пакетной точности.addpoints
функция требует, чтобы точки данных, чтобы иметь ввели double
. Извлекать числовые данные из dlarray
объекты, используйте extractdata
функция. Чтобы собрать данные от графического процессора, используйте gather
функция.
Чтобы изучить, как вычислить метрики валидации, смотрите Валидацию.
Чтобы отобразить учебную потерю и точность во время обучения в многословной таблице, вычислите мини-пакетную потерю и любого, точность (для задач классификации) или RMSE (для задач регрессии) в градиентах модели функционирует и отображает их использующий disp
функция.
Чтобы легко указать, что многословная таблица должна идти или прочь, создайте переменные verbose
и verboseFrequency
, где verbose
true
или false
и verbosefrequency
задает сколько итераций между печатью многословного выхода. Чтобы отобразить метрики валидации, используйте те же опции validationData
и validationFrequency
описанный в Валидации.
verbose = true verboseFrequency = 50; validationData = {XValidation, YValidation}; validationFrequency = 50;
Перед обучением, отображение многословные выходные табличные заголовки и инициализируют таймер с помощью tic
функция.
disp("|======================================================================================================================|") disp("| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |") disp("| | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate |") disp("|======================================================================================================================|") start = tic;
Для задач регрессии настройте код так, чтобы он отобразил обучение и валидацию RMSE вместо точности обучения и валидации.
В учебном цикле, в конце итерации, распечатывают многословный выход когда verbose
опцией является true
и это - или первая итерация или номер итерации, является кратным verboseFrequency
.
if verbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0 D = duration(0,0,toc(start),'Format','hh:mm:ss'); if isempty(validationData) || mod(iteration,validationFrequency) ~= 0 accuracyValidation = ""; lossValidation = ""; end disp("| " + ... pad(epoch,7,'left') + " | " + ... pad(iteration,11,'left') + " | " + ... pad(D,14,'left') + " | " + ... pad(accuracyTrain,12,'left') + " | " + ... pad(accuracyValidation,12,'left') + " | " + ... pad(lossTrain,12,'left') + " | " + ... pad(lossValidation,12,'left') + " | " + ... pad(learnRate,15,'left') + " |") end
Для задач регрессии настройте код так, чтобы он отобразил обучение и валидацию RMSE вместо точности обучения и валидации.
Когда обучение будет закончено, распечатайте последнюю границу многословной таблицы.
disp("|======================================================================================================================|")
Чтобы изучить, как вычислить метрики валидации, смотрите Валидацию.
Установка мини-пакетного размера зависит от формата данных или типа используемого datastore.
Чтобы легко задать мини-пакетный размер, создайте переменную miniBatchSize
.
miniBatchSize = 128;
Для данных в datastore изображений, перед обучением, устанавливает ReadSize
свойство datastore к мини-пакетному размеру.
imds.ReadSize = miniBatchSize;
Для данных в увеличенном datastore изображений, перед обучением, устанавливает MiniBatchSize
свойство datastore к мини-пакетному размеру.
augimds.MiniBatchSize = miniBatchSize;
Для данных в оперативной памяти, во время обучения в начале каждой итерации, читают наблюдения непосредственно из массива.
idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize); X = XTrain(:,:,:,idx);
Задайте максимальное количество эпох для обучения во внешнем for
цикл учебного цикла.
Чтобы легко задать максимальное количество эпох, создайте переменную maxEpochs
это содержит максимальное количество эпох.
maxEpochs = 30;
Во внешнем for
цикл учебного цикла, задайте, чтобы циклично выполниться в области значений 1, 2, …, maxEpochs
.
for epoch = 1:maxEpochs ... end
Чтобы подтвердить вашу сеть во время обучения, отложите протянутый набор валидации и оцените, как хорошо сеть выполняет на тех данных.
Чтобы легко задать опции валидации, создайте переменные validationData
и validationFrequency
, где validationData
содержит данные о валидации или пуст и validationFrequency
задает сколько итераций между проверкой сети.
validationData = {XValidation,YValidation}; validationFrequency = 50;
Во время учебного цикла, после обновления сетевых параметров, тест, как хорошо сеть выполняет на протянутом наборе валидации с помощью predict
функция. Подтвердите сеть только, когда данные о валидации заданы, и это - или первая итерация или текущая итерация, является кратным verboseFrequency
опция.
if iteration == 1 || mod(iteration,verboseFrequency) == 0 dlYPredValidation = predict(dlnet,dlXValidation); lossValidation = crossentropy(softmax(dlYPredValidation), YValidation); [~,idx] = max(dlYPredValidation); labelsPredValidation = classNames(idx); accuracyValidation = mean(labelsPredValidation == labelsValidation); end
YValidation
фиктивная переменная, соответствующая меткам в classNames
. Чтобы вычислить точность, преобразуйте YValidation
к массиву меток.Для задач регрессии настройте код так, чтобы он вычислил валидацию RMSE вместо точности валидации.
Чтобы остановить обучение рано, когда потеря на протянутой валидации прекращает уменьшаться, используйте, устанавливает флаг убегать из учебных циклов.
Чтобы легко задать терпение валидации (число раз, что потеря валидации может быть больше, чем или равняться ранее самой маленькой потере перед сетевыми учебными остановками), создайте переменную validationPatience
.
validationPatience = 5;
Перед обучением инициализируйте переменные earlyStop
и validationLosses
, где earlyStop
флаг должен остановить обучение рано и validationLosses
содержит потери, чтобы выдержать сравнение. Инициализируйте ранний флаг остановки с false
и массив потерь валидации с inf
.
earlyStop = false; if isfinite(validationPatience) validationLosses = inf(1,validationPatience); end
В учебном цикле, в цикле по мини-пакетам, добавляет earlyStop
отметьте к условию цикла.
while hasdata(ds) && ~earlyStop ... end
Во время шага валидации добавьте новую потерю валидации для массива validationLosses
. Если первый элемент массива является самым маленьким, то установленный earlyStop
отметьте к true
. В противном случае удалите первый элемент.
if isfinite(validationPatience) validationLosses = [validationLosses validationLoss]; if min(validationLosses) == validationLosses(1) earlyStop = true; else validationLosses(1) = []; end end
Чтобы применить регуляризацию L2 к весам, используйте dlupdate
функция.
Чтобы легко задать фактор регуляризации L2, создайте переменную l2Regularization
это содержит фактор регуляризации L2.
l2Regularization = 0.0001;
Во время обучения, после вычисления градиентов модели, для каждого из параметров веса, добавляет продукт фактора регуляризации L2 и весов к вычисленным градиентам с помощью dlupdate
функция. Чтобы обновить только параметры веса, извлеките параметры с именем "Weights"
.
idx = dlnet.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));
После добавления параметра регуляризации L2 к градиентам обновите сетевые параметры.
Чтобы отсечь градиенты, используйте dlupdate
функция.
Чтобы легко задать опции усечения градиента, создайте переменные gradientThresholdMethod
и gradientThreshold
, где gradientThresholdMethod
содержит "global-l2norm"
, "l2norm"
, или "absolute-value"
, и gradientThreshold
положительная скалярная величина, содержащая порог или inf
.
gradientThresholdMethod = "global-l2norm";
gradientThreshold = 2;
Создайте функции с именем thresholdGlobalL2Norm
, thresholdL2Norm
, и thresholdAbsoluteValue
это применяет "global-l2norm"
, "l2norm"
, и "absolute-value"
пороговые методы, соответственно.
Для "global-l2norm"
опция, функция работает со всеми градиентами модели.
function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold) globalL2Norm = 0; for i = 1:numel(gradients) globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2); end globalL2Norm = sqrt(globalL2Norm); if globalL2Norm > gradientThreshold normScale = gradientThreshold / globalL2Norm; for i = 1:numel(gradients) gradients{i} = gradients{i} * normScale; end end end
Для "l2norm"
и "absolute-value"
опции, функции работают с каждым градиентом независимо.
function gradients = thresholdL2Norm(gradients,gradientThreshold) gradientNorm = sqrt(sum(gradients(:).^2)); if gradientNorm > gradientThreshold gradients = gradients * (gradientThreshold / gradientNorm); end end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold) gradients(gradients > gradientThreshold) = gradientThreshold; gradients(gradients < -gradientThreshold) = -gradientThreshold; end
Во время обучения, после вычисления градиентов модели, применяют соответствующий метод усечения градиента к градиентам с помощью dlupdate
функция. Поскольку "global-l2norm"
опция требует всех градиентов модели, примените thresholdGlobalL2Norm
функционируйте непосредственно к градиентам. Для "l2norm"
и "absolute-value"
опции, обновите градиенты независимо с помощью dlupdate
функция.
switch gradientThresholdMethod case "global-l2norm" gradients = thresholdGlobalL2Norm(gradients, gradientThreshold); case "l2norm" gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients); case "absolute-value" gradients = dlupdate(@(g) thresholdAbsoluteValue(g, gradientThreshold),gradients); end
После применения пороговой операции градиента обновите сетевые параметры.
Программное обеспечение, по умолчанию, выполняет вычисления с помощью одного центральный процессор. Обучайтесь на одном графическом процессоре, преобразуйте в данные к gpuArray
объекты. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.
Чтобы легко задать среду выполнения, создайте переменную executionEnvironment
это содержит любой "cpu"
, "gpu"
, или "auto"
.
executionEnvironment = "auto"
Во время обучения, после чтения мини-пакета, проверяют опцию среды выполнения и преобразуют данные в gpuArray
при необходимости. canUseGPU
функционируйте проверки на применимые графические процессоры.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end
Сохранить сети контрольной точки во время учебного сохранения сеть с помощью save
функция.
Чтобы легко задать, должны ли контрольные точки быть включены, создайте переменную checkpointPath
содержит папку для сетей контрольной точки или пуст.
checkpointPath = fullfile(tempdir,"checkpoints");
Если папка контрольной точки не существует, то перед обучением, создайте папку контрольной точки.
if ~exist(checkpointPath,"dir") mkdir(checkpointPath) end
Во время обучения, в конце эпохи, сохраняют сеть в файле MAT. Задайте имя файла, содержащее текущий номер итерации, дату, и время.
if ~isempty(checkpointPath) D = datestr(now,'yyyy_mm_dd__HH_MM_SS'); filename = "dlnet_checkpoint__" + iteration + "__" + D + ".mat"; save(filename,"dlnet") end
dlnet
dlnetwork
объект быть сохраненным.adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| dlupdate
| rmspropupdate
| sgdmupdate