Обновление параметров с помощью пользовательской функции
обновляет обучаемые параметры dlnet = dlupdate(fun,dlnet)dlnetwork объект dlnet путем оценки функции fun с каждым обучаемым параметром в качестве входных данных. fun является дескриптором функции, который принимает один массив параметров в качестве входного аргумента и возвращает обновленный массив параметров.
dlupdateВыполните регуляризацию L1 для структуры градиентов параметров.
Создайте образец входных данных.
dlX = dlarray(rand(100,100,3),'SSC');Инициализируйте обучаемые параметры для операции свертки.
params.Weights = dlarray(rand(10,10,3,50)); params.Bias = dlarray(rand(50,1));
Расчет градиентов для операции свертки с помощью вспомогательной функции convGradients, определяется в конце этого примера.
gradients = dlfeval(@convGradients,dlX,params);
Определите коэффициент регуляризации.
L1Factor = 0.001;
Создайте анонимную функцию, регулирующую градиенты. Используя анонимную функцию для передачи скалярной константы функции, можно избежать необходимости расширения значения константы до того же размера и структуры, что и переменная параметра.
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);
Использовать dlupdate для применения функции регуляризации к каждому из градиентов.
gradients = dlupdate(L1Regularizer,gradients,params);
Градиенты в grads теперь регуляризованы в соответствии с функцией L1Regularizer.
convGradients Функция
convGradients вспомогательная функция принимает обучаемые параметры операции свертки и мини-пакет входных данных dlXи возвращает градиенты относительно обучаемых параметров.
function gradients = convGradients(dlX,params) dlY = dlconv(dlX,params.Weights,params.Bias); dlY = sum(dlY,'all'); gradients = dlgradient(dlY,params); end
dlupdate Обучение сети с помощью пользовательской функции обновленияИспользовать dlupdate для обучения сети с помощью пользовательской функции обновления, реализующей алгоритм стохастического градиентного спуска (без импульса).
Загрузка данных обучения
Загрузите данные обучения по цифрам.
[XTrain,YTrain] = digitTrain4DArrayData; classes = categories(YTrain); numClasses = numel(classes);
Определение сети
Определите сетевую архитектуру и укажите среднее значение изображения с помощью 'Mean' в слое ввода изображения.
layers = [
imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
convolution2dLayer(5,20,'Name','conv1')
reluLayer('Name', 'relu1')
convolution2dLayer(3,20,'Padding',1,'Name','conv2')
reluLayer('Name','relu2')
convolution2dLayer(3,20,'Padding',1,'Name','conv3')
reluLayer('Name','relu3')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);Создать dlnetwork объект из графа слоев.
dlnet = dlnetwork(lgraph);
Определение функции градиентов модели
Создание вспомогательной функции modelGradients, перечисленных в конце этого примера. Функция принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet.
Определение функции стохастического градиентного спуска
Создание вспомогательной функции sgdFunction, перечисленных в конце этого примера. Функция принимает param и paramGradient, обучаемый параметр и градиент потерь по отношению к этому параметру, соответственно, и возвращает обновленный параметр с использованием алгоритма стохастического градиентного спуска, выраженного как
)
где - число итераций, 0 - скорость обучения ,
Укажите параметры обучения
Укажите параметры для использования во время обучения.
miniBatchSize = 128; numEpochs = 30; numObservations = numel(YTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Укажите скорость обучения.
learnRate = 0.01;
Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
executionEnvironment = "auto";Визуализация хода обучения на графике.
plots = "training-progress";Железнодорожная сеть
Обучение модели с помощью пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Обновление параметров сети путем вызова dlupdate с функцией sgdFunction определено в конце этого примера. В конце каждой эпохи отобразите ход обучения.
Инициализируйте график хода обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Обучение сети.
iteration = 0; start = tic; for epoch = 1:numEpochs % Shuffle data. idx = randperm(numel(YTrain)); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(idx); for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data and convert the labels to dummy % variables. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = zeros(numClasses, miniBatchSize, 'single'); for c = 1:numClasses Y(c,YTrain(idx)==classes(c)) = 1; end % Convert mini-batch of data to dlarray. dlX = dlarray(single(X),'SSCB'); % If training on a GPU, then convert data to a gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients helper function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y); % Update the network parameters using the SGD algorithm defined in % the sgdFunction helper function. updateFcn = @(dlnet,gradients) sgdFunction(dlnet,gradients,learnRate); dlnet = dlupdate(updateFcn,dlnet,gradients); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end

Тестовая сеть
Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками.
[XTest, YTest] = digitTest4DArrayData;
Преобразование данных в dlarray с форматом размера 'SSCB'. Для прогнозирования GPU также преобразуйте данные в gpuArray.
dlXTest = dlarray(XTest,'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlXTest = gpuArray(dlXTest); end
Классификация изображений с помощью dlnetwork объект, используйте predict и найти классы с самыми высокими баллами.
dlYPred = predict(dlnet,dlXTest); [~,idx] = max(extractdata(dlYPred),[],1); YPred = classes(idx);
Оцените точность классификации.
accuracy = mean(YPred==YTest)
accuracy = 0.9386
Функция градиентов модели
Вспомогательная функция modelGradients принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet. Для автоматического вычисления градиентов используйте dlgradient функция.
function [gradients,loss] = modelGradients(dlnet,dlX,Y) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); end
Функция стохастического градиентного спуска
Вспомогательная функция sgdFunction принимает обучаемый параметр parameter, градиенты этого параметра относительно потерь gradientи скорость обучения learnRateи возвращает обновленный параметр с использованием алгоритма стохастического градиентного спуска, выраженного как
)
где - число итераций, 0 - скорость обучения ,
function parameter = sgdFunction(parameter,gradient,learnRate) parameter = parameter - learnRate .* gradient; end
fun - Функция для примененияФункция для применения к обучаемым параметрам, заданным как дескриптор функции.
dlupate оценивает fun с каждым сетевым обучаемым параметром в качестве входных данных. fun оценивается столько раз, сколько существует массивов обучаемых параметров в dlnet или params.
dlnet - Сетьdlnetwork объектСеть, указанная как dlnetwork объект.
Функция обновляет dlnet.Learnables имущества dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:
Layer - имя слоя, указанное как строковый скаляр.
Parameter - имя параметра, указанное как строковый скаляр.
Value - значение параметра, указанное как массив ячеек, содержащий dlarray.
params - Сетевые обучаемые параметрыdlarray | числовой массив | массив ячеек | структура | таблицаСетевые обучаемые параметры, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.
При указании params в качестве таблицы она должна содержать следующие три переменные.
Layer - имя слоя, указанное как строковый скаляр.
Parameter - имя параметра, указанное как строковый скаляр.
Value - значение параметра, указанное как массив ячеек, содержащий dlarray.
Можно указать params в качестве контейнера обучаемых параметров для сети с использованием массива ячеек, структуры или таблицы или вложенных массивов или структур ячеек. Доступные для изучения параметры в массиве, структуре или таблице ячейки должны быть dlarray или числовые значения типа данных double или single.
Входной аргумент grad должны иметь точно такой же тип данных, порядок и поля (для структур) или переменные (для таблиц), как params.
Типы данных: single | double | struct | table | cell
A1,...,An - Дополнительные входные аргументыdlarray | числовой массив | массив ячеек | структура | таблицаДополнительные входные аргументы для fun, указано как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.
Точная форма A1,...,An зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для A1,...,An для возможных входных данных dlupdate.
| Вход | Обучаемые параметры | A1,...,An |
|---|---|---|
dlnet | Стол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительные входные аргументы для функции fun применяется к каждому обучаемому параметру. |
params | dlarray | dlarray с тем же типом данных и порядком, что и params
|
| Числовой массив | Числовой массив с тем же типом данных и порядком, что и params
| |
| Массив ячеек | Массив ячеек с теми же типами данных, структурой и порядком, что и params | |
| Структура | Структура с теми же типами данных, полями и порядком, что и params | |
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с теми же типами данных, переменными и порядком, что и params. A1,...,An должен иметь Value переменная, состоящая из массивов ячеек, которые содержат дополнительный входной аргумент для функции fun применяется к каждому обучаемому параметру. |
dlnet - Обновленная сетьdlnetwork объектСеть, возвращенная как dlnetwork объект.
Функция обновляет dlnet.Learnables имущества dlnetwork объект.
params - Обновленные сетевые обучаемые параметрыdlarray | числовой массив | массив ячеек | структура | таблицаОбновленные сетевые обучаемые параметры, возвращенные в виде dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные обучаемые параметры сети.
X1,...,Xm - Дополнительные выходные аргументыdlarray | числовой массив | массив ячеек | структура | таблицаДополнительные выходные аргументы функции fun, где fun - дескриптор функции для функции, которая возвращает несколько выходов, возвращаемых как dlarray объекты, числовые массивы, массивы ячеек, структуры или таблицы с Value переменная.
Точная форма X1,...,Xm зависит от входной сети или обучаемых параметров. В следующей таблице показан возвращенный формат X1,...,Xm для возможных входных данных dlupdate.
| Вход | Обучаемые параметры | X1,...,Xm |
|---|---|---|
dlnet | Стол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительные выходные аргументы функции fun применяется к каждому обучаемому параметру. |
params | dlarray | dlarray с тем же типом данных и порядком, что и params
|
| Числовой массив | Числовой массив с тем же типом данных и порядком, что и params
| |
| Массив ячеек | Массив ячеек с теми же типами данных, структурой и порядком, что и params | |
| Структура | Структура с теми же типами данных, полями и порядком, что и params
| |
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с одинаковыми типами данных, переменные. и заказ как params. X1,...,Xm имеет Value переменная, состоящая из массивов ячеек, которые содержат дополнительный выходной аргумент функции fun применяется к каждому обучаемому параметру. |
Примечания и ограничения по использованию:
Если хотя бы один из следующих входных аргументов является gpuArray или dlarray с базовыми данными типа gpuArray, эта функция выполняется на GPU.
params
A1,...,An
Дополнительные сведения см. в разделе Запуск функций MATLAB на графическом процессоре (панель инструментов параллельных вычислений).
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | rmspropupdate | sgdmupdate
Имеется измененная версия этого примера. Открыть этот пример с помощью изменений?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.