Обновляйте параметры с помощью пользовательской функции
обновляет настраиваемые параметры dlnet
= dlupdate(fun
,dlnet
)dlnetwork
dlnet объекта
путем оценки функции fun
с каждым настраиваемым параметром в качестве входов. fun
- указатель на функцию функции, который принимает один массив параметров в качестве входного параметра и возвращает обновленный массив параметров.
dlupdate
Выполните L1 регуляризацию структуры градиентов параметров.
Создайте образец входных данных.
dlX = dlarray(rand(100,100,3),'SSC');
Инициализируйте настраиваемые параметры для операции свертки.
params.Weights = dlarray(rand(10,10,3,50)); params.Bias = dlarray(rand(50,1));
Вычислите градиенты для операции свертки с помощью функции helper convGradients
, заданный в конце этого примера.
gradients = dlfeval(@convGradients,dlX,params);
Определите коэффициент регуляризации.
L1Factor = 0.001;
Создайте анонимную функцию, которая регулирует градиенты. При помощи анонимной функции, чтобы передать скаляру константу в функцию, можно избежать необходимости расширения постоянного значения до тех же размера и структуры, что и переменная параметра.
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);
Использование dlupdate
применить функцию регуляризации к каждому из градиентов.
gradients = dlupdate(L1Regularizer,gradients,params);
Градиенты в grads
теперь регулируются в соответствии с функцией L1Regularizer
.
convGradients
Функция
The convGradients
Функция helper принимает настраиваемые параметры операции свертки и мини-пакет входных данных dlX
, и возвращает градиенты относительно настраиваемых параметров.
function gradients = convGradients(dlX,params) dlY = dlconv(dlX,params.Weights,params.Bias); dlY = sum(dlY,'all'); gradients = dlgradient(dlY,params); end
dlupdate
Обучение сети с помощью пользовательской функции обновленияИспользование dlupdate
обучить сеть с помощью пользовательской функции обновления, которая реализует алгоритм стохастического градиентного спуска (без импульса).
Загрузка обучающих данных
Загрузите обучающие данные цифр.
[XTrain,YTrain] = digitTrain4DArrayData; classes = categories(YTrain); numClasses = numel(classes);
Определите сеть
Определите сетевую архитектуру и задайте среднее значение изображения с помощью 'Mean'
опция на входном слое изображения.
layers = [ imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4)) convolution2dLayer(5,20,'Name','conv1') reluLayer('Name', 'relu1') convolution2dLayer(3,20,'Padding',1,'Name','conv2') reluLayer('Name','relu2') convolution2dLayer(3,20,'Padding',1,'Name','conv3') reluLayer('Name','relu3') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);
Создайте dlnetwork
объект из графика слоев.
dlnet = dlnetwork(lgraph);
Задайте функцию градиентов модели
Создайте вспомогательную функцию modelGradients
, перечисленный в конце этого примера. Функция принимает dlnetwork
dlnet объекта
и мини-пакет входных данных dlX
с соответствующими метками Y
, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet
.
Задайте функцию стохастического градиентного спуска
Создайте вспомогательную функцию sgdFunction
, перечисленный в конце этого примера. Функция принимает param
и paramGradient
, настраиваемый параметр и градиент потерь относительно этого параметра, соответственно, и возвращает обновленный параметр, используя алгоритм стохастического градиентного спуска, выраженный как
где - число итерации, является скорость обучения, является вектором параметра, и - функция потерь.
Настройка опций обучения
Задайте опции, которые будут использоваться во время обучения.
miniBatchSize = 128; numEpochs = 30; numObservations = numel(YTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Задайте скорость обучения.
learnRate = 0.01;
Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
executionEnvironment = "auto";
Визуализируйте процесс обучения на графике.
plots = "training-progress";
Обучите сеть
Обучите модель с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Обновляйте параметры сети путем вызова dlupdate
с функцией sgdFunction
заданное в конце этого примера. В конце каждой эпохи отобразите процесс обучения.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Обучите сеть.
iteration = 0; start = tic; for epoch = 1:numEpochs % Shuffle data. idx = randperm(numel(YTrain)); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(idx); for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data and convert the labels to dummy % variables. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = zeros(numClasses, miniBatchSize, 'single'); for c = 1:numClasses Y(c,YTrain(idx)==classes(c)) = 1; end % Convert mini-batch of data to dlarray. dlX = dlarray(single(X),'SSCB'); % If training on a GPU, then convert data to a gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients helper function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y); % Update the network parameters using the SGD algorithm defined in % the sgdFunction helper function. updateFcn = @(dlnet,gradients) sgdFunction(dlnet,gradients,learnRate); dlnet = dlupdate(updateFcn,dlnet,gradients); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Тестирование сети
Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками.
[XTest, YTest] = digitTest4DArrayData;
Преобразуйте данные в dlarray
с форматом размерности 'SSCB'
. Для предсказания графический процессор также преобразуйте данные в gpuArray
.
dlXTest = dlarray(XTest,'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlXTest = gpuArray(dlXTest); end
Чтобы классифицировать изображения с помощью dlnetwork
объект, используйте predict
функция и найти классы с самыми высокими счетами.
dlYPred = predict(dlnet,dlXTest); [~,idx] = max(extractdata(dlYPred),[],1); YPred = classes(idx);
Оцените точность классификации.
accuracy = mean(YPred==YTest)
accuracy = 0.9386
Функция градиентов модели
Функция помощника modelGradients
принимает dlnetwork
dlnet объекта
и мини-пакет входных данных dlX
с соответствующими метками Y
, и возвращает потерю и градиенты потери относительно настраиваемых параметров в dlnet
. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,Y) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); end
Функция стохастического градиентного спуска
Функция помощника sgdFunction
принимает настраиваемый параметр parameter
, градиенты этого параметра относительно потерь gradient
, и скорость обучения learnRate
, и возвращает обновленный параметр, используя алгоритм стохастического градиентного спуска, выраженный как
где - число итерации, является скорость обучения, является вектором параметра, и - функция потерь.
function parameter = sgdFunction(parameter,gradient,learnRate) parameter = parameter - learnRate .* gradient; end
fun
- Функция для примененияФункция для применения к настраиваемым параметрам, заданная как указатель на функцию.
dlupate
оценивает fun
с каждым сетевым настраиваемым параметром в качестве входа. fun
оценивается столько раз, сколько существует массивов настраиваемых параметров в dlnet
или params
.
dlnet
- Сетьdlnetwork
объектСеть, заданная как dlnetwork
объект.
Функция обновляет dlnet.Learnables
свойство dlnetwork
объект. dlnet.Learnables
- таблица с тремя переменными:
Layer
- Имя слоя, заданное как строковый скаляр.
Parameter
- Имя параметра, заданное как строковый скаляр.
Value
- Значение параметра, заданное как массив ячеек, содержащий dlarray
.
params
- Параметры, учитываемые в сетиdlarray
| числовой массив | массив ячеек | структуру | таблицуСетевые настраиваемые параметры, заданная как dlarray
, числовой массив, массив ячеек, структура или таблица.
Если вы задаете params
как таблица, она должна содержать следующие три переменные.
Layer
- Имя слоя, заданное как строковый скаляр.
Parameter
- Имя параметра, заданное как строковый скаляр.
Value
- Значение параметра, заданное как массив ячеек, содержащий dlarray
.
Можно задать params
как контейнер настраиваемых параметров для сети с помощью массива ячеек, структуры или таблицы или вложенных массивов ячеек или структур. Настраиваемые параметры в массиве ячеек, структуре или таблице должны быть dlarray
или числовые значения типа данных double
или single
.
Входной параметр grad
должны быть снабжены точно совпадающим типом данных, упорядоченным расположением и полями (для структур) или переменными (для таблиц), как params
.
Типы данных: single
| double
| struct
| table
| cell
A1,...,An
- Дополнительные входные параметрыdlarray
| числовой массив | массив ячеек | структуру | таблицуДополнительные входные параметры для fun
, заданный как dlarray
объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value
переменная.
Точная форма A1,...,An
зависит от входа сети или настраиваемых параметров. В следующей таблице показан необходимый формат для A1,...,An
для возможных входов в dlupdate
.
Вход | Настраиваемые параметры | A1,...,An |
---|---|---|
dlnet | Табличные dlnet.Learnables содержащие Layer , Parameter , и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray . | Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables . A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительные входные параметры для функции fun применить к каждому настраиваемому параметру. |
params | dlarray | dlarray с совпадающим типом данных и порядком, что и params
|
Числовой массив | Числовой массив с совпадающим типом данных и порядком, что и params
| |
Массив ячеек | Массив ячеек с совпадающими типами данных, структурой и порядком, как params | |
Структура | Структура с совпадающими типами данных, полями и порядками, что и params | |
Таблица с Layer , Parameter , и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray . | Таблица с совпадающими типами данных, переменными и порядком, что и params . A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительный входной параметр для функции fun применить к каждому настраиваемому параметру. |
dlnet
- Обновленная сетьdlnetwork
объектСеть, возвращается как dlnetwork
объект.
Функция обновляет dlnet.Learnables
свойство dlnetwork
объект.
params
- Обновленные настраиваемые параметры сетиdlarray
| числовой массив | массив ячеек | структуру | таблицуОбновлённые сетевые настраиваемые параметры, возвращенный как dlarray
, числовой массив, массив ячеек, структура или таблица с Value
переменная, содержащая обновленные настраиваемые параметры сети.
X1,...,Xm
- Дополнительные выходные аргументыdlarray
| числовой массив | массив ячеек | структуру | таблицуДополнительные выходные аргументы от функции fun
, где fun
- указатель на функцию, который возвращает несколько выходы, возвращаемых как dlarray
объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value
переменная.
Точная форма X1,...,Xm
зависит от входа сети или настраиваемых параметров. В следующей таблице показан возвращенный формат X1,...,Xm
для возможных входов в dlupdate
.
Вход | Настраиваемые параметры | X1,...,Xm |
---|---|---|
dlnet | Табличные dlnet.Learnables содержащие Layer , Parameter , и Value переменные. The Value переменная состоит из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray . | Таблица с совпадающим типом данных, переменными и порядком, что и dlnet.Learnables . X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительные выходные аргументы функции fun применяется к каждому настраиваемому параметру. |
params | dlarray | dlarray с совпадающим типом данных и порядком, что и params
|
Числовой массив | Числовой массив с совпадающим типом данных и порядком, что и params
| |
Массив ячеек | Массив ячеек с совпадающими типами данных, структурой и порядком, как params | |
Структура | Структура с совпадающими типами данных, полями и порядками, что и params
| |
Таблица с Layer , Parameter , и Value переменные. The Value переменная должна состоять из массивов ячеек, которые содержат каждый настраиваемый параметр в виде dlarray . | Таблица с совпадающими типами данных, переменными. и упорядоченное расположение как params . X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительный выходной аргумент функции fun применяется к каждому настраиваемому параметру. |
Указания и ограничения по применению:
Когда по крайней мере один из следующих входных параметров является gpuArray
или dlarray
с базовыми данными типа gpuArray
эта функция выполняется на графическом процессоре.
params
A1,...,An
Для получения дополнительной информации смотрите Запуск функций MATLAB на графическом процессоре (Parallel Computing Toolbox).
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| rmspropupdate
| sgdmupdate
У вас есть измененная версия этого примера. Вы хотите открыть этот пример с вашими правками?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.