Обновление параметров с использованием среднеквадратичного распространения корня (RMSProp)
Обновление сетевых обучаемых параметров в пользовательском учебном цикле с использованием алгоритма среднеквадратичного распространения (RMSProp).
Примечание
Эта функция применяет алгоритм оптимизации RMSProp для обновления параметров сети в пользовательских контурах обучения, в которых используются сети, определенные как dlnetwork объекты или функции модели. При необходимости обучения сети, определенной как Layer массив или как LayerGraph, используйте следующие функции:
Создать TrainingOptionsRMSProp с использованием trainingOptions функция.
Используйте TrainingOptionsRMSProp объект с trainNetwork функция.
[ обновляет обучаемые параметры сети dlnet,averageSqGrad] = rmspropupdate(dlnet,grad,averageSqGrad)dlnet с использованием алгоритма RMSProp. Используйте этот синтаксис в учебном цикле для итеративного обновления сети, определенной как dlnetwork объект.
[ обновляет обучаемые параметры в params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad)params с использованием алгоритма RMSProp. Используйте этот синтаксис в учебном цикле для итеративного обновления обучаемых параметров сети, определенной с помощью функций.
[___] = rmspropupdate(___ также указывает значения для глобальной скорости обучения, спада квадратного градиента и малого постоянного эпсилона в дополнение к входным аргументам в предыдущих синтаксисах. learnRate,sqGradDecay,epsilon)
rmspropupdateВыполнение одного шага обновления среднеквадратичного распространения с глобальной скоростью обучения 0.05 и квадрат градиентного коэффициента распада 0.95.
Создайте параметры и градиенты параметров в виде числовых массивов.
params = rand(3,3,4); grad = ones(3,3,4);
Инициализируйте среднеквадратичный градиент для первой итерации.
averageSqGrad = [];
Задайте пользовательские значения для глобальной скорости обучения и коэффициента градиентного спада в квадрате.
learnRate = 0.05; sqGradDecay = 0.95;
Обновление обучаемых параметров с помощью rmspropupdate.
[params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad,learnRate,sqGradDecay);
rmspropupdateИспользовать rmspropupdate для обучения сети с использованием алгоритма среднеквадратичного распространения (RMSProp).
Загрузка данных обучения
Загрузите данные обучения по цифрам.
[XTrain,YTrain] = digitTrain4DArrayData; classes = categories(YTrain); numClasses = numel(classes);
Определение сети
Определите сетевую архитектуру и укажите среднее значение изображения с помощью 'Mean' в слое ввода изображения.
layers = [
imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
convolution2dLayer(5,20,'Name','conv1')
reluLayer('Name', 'relu1')
convolution2dLayer(3,20,'Padding',1,'Name','conv2')
reluLayer('Name','relu2')
convolution2dLayer(3,20,'Padding',1,'Name','conv3')
reluLayer('Name','relu3')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);Создать dlnetwork объект из графа слоев.
dlnet = dlnetwork(lgraph);
Определение функции градиентов модели
Создание вспомогательной функции modelGradients, перечисленных в конце примера. Функция принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Yи возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet.
Укажите параметры обучения
Укажите параметры для использования во время обучения.
miniBatchSize = 128; numEpochs = 20; numObservations = numel(YTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
executionEnvironment = "auto";Визуализация хода обучения на графике.
plots = "training-progress";Железнодорожная сеть
Обучение модели с помощью пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Обновление параметров сети с помощью rmspropupdate функция. В конце каждой эпохи отобразите ход обучения.
Инициализируйте график хода обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Инициализируйте средние градиенты в квадрате.
averageSqGrad = [];
Обучение сети.
iteration = 0; start = tic; for epoch = 1:numEpochs % Shuffle data. idx = randperm(numel(YTrain)); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(idx); for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data and convert the labels to dummy % variables. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = zeros(numClasses, miniBatchSize, 'single'); for c = 1:numClasses Y(c,YTrain(idx)==classes(c)) = 1; end % Convert mini-batch of data to a dlarray. dlX = dlarray(single(X),'SSCB'); % If training on a GPU, then convert data to a gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients helper function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,Y); % Update the network parameters using the RMSProp optimizer. [dlnet,averageSqGrad] = rmspropupdate(dlnet,gradients,averageSqGrad); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end

Тестирование сети
Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками.
[XTest, YTest] = digitTest4DArrayData;
Преобразование данных в dlarray с форматом измерения 'SSCB'. Для прогнозирования GPU также преобразуйте данные в gpuArray.
dlXTest = dlarray(XTest,'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlXTest = gpuArray(dlXTest); end
Классификация изображений с помощью dlnetwork объект, используйте predict и найти классы с самыми высокими баллами.
dlYPred = predict(dlnet,dlXTest); [~,idx] = max(extractdata(dlYPred),[],1); YPred = classes(idx);
Оцените точность классификации.
accuracy = mean(YPred==YTest)
accuracy = 0.9860
Функция градиентов модели
Вспомогательная функция modelGradients принимает dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Y, и возвращает потери и градиенты потерь относительно обучаемых параметров в dlnet. Для автоматического вычисления градиентов используйте dlgradient функция.
function [gradients,loss] = modelGradients(dlnet,dlX,Y) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); end
dlnet - Сетьdlnetwork объектСеть, указанная как dlnetwork объект.
Функция обновляет dlnet.Learnables имущества dlnetwork объект. dlnet.Learnables - таблица с тремя переменными:
Layer - имя слоя, указанное как строковый скаляр.
Parameter - имя параметра, указанное как строковый скаляр.
Value - значение параметра, указанное как массив ячеек, содержащий dlarray.
Входной аргумент grad должна быть таблица той же формы, что и dlnet.Learnables.
params - Сетевые обучаемые параметрыdlarray | числовой массив | массив ячеек | структура | таблицаСетевые обучаемые параметры, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.
При указании params в качестве таблицы она должна содержать следующие три переменные.
Layer - имя слоя, указанное как строковый скаляр.
Parameter - имя параметра, указанное как строковый скаляр.
Value - значение параметра, указанное как массив ячеек, содержащий dlarray.
Можно указать params в качестве контейнера обучаемых параметров для сети с использованием массива ячеек, структуры или таблицы или вложенных массивов или структур ячеек. Доступные для изучения параметры в массиве, структуре или таблице ячейки должны быть dlarray или числовые значения типа данных double или single.
Входной аргумент grad должны иметь точно такой же тип данных, порядок и поля (для структур) или переменные (для таблиц), как params.
Типы данных: single | double | struct | table | cell
grad - Градиенты потерьdlarray | числовой массив | массив ячеек | структура | таблицаГрадиенты потерь, указанные как dlarray, числовой массив, массив ячеек, структуру или таблицу.
Точная форма grad зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для grad для возможных входных данных rmspropupdate.
| Вход | Обучаемые параметры | Градиенты |
|---|---|---|
dlnet | Стол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого обучаемого параметра. |
params | dlarray | dlarray с тем же типом данных и порядком, что и params
|
| Числовой массив | Числовой массив с тем же типом данных и порядком, что и params
| |
| Массив ячеек | Массив ячеек с теми же типами данных, структурой и порядком, что и params | |
| Структура | Структура с теми же типами данных, полями и порядком, что и params | |
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с теми же типами данных, переменными и порядком, что и params. grad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат градиент каждого обучаемого параметра. |
Вы можете получить grad от вызова до dlfeval вычисляет функцию, содержащую вызов dlgradient. Дополнительные сведения см. в разделе Использование автоматической дифференциации в инструменте глубокого обучения.
averageSqGrad - Скользящее среднее квадратичных градиентов параметров[] | dlarray | числовой массив | массив ячеек | структура | таблицаСкользящее среднее квадратичных градиентов параметров, указанных как пустой массив, a dlarray, числовой массив, массив ячеек, структуру или таблицу.
Точная форма averageSqGrad зависит от входной сети или обучаемых параметров. В следующей таблице показан требуемый формат для averageSqGrad для возможных входных данных rmspropupdate.
| Вход | Обучаемые параметры | Средние квадратичные градиенты |
|---|---|---|
dlnet | Стол dlnet.Learnables содержа Layer, Parameter, и Value переменные. Value переменная состоит из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с тем же типом данных, переменными и порядком, что и dlnet.Learnables. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат среднеквадратичный градиент каждого обучаемого параметра. |
params | dlarray | dlarray с тем же типом данных и порядком, что и params
|
| Числовой массив | Числовой массив с тем же типом данных и порядком, что и params
| |
| Массив ячеек | Массив ячеек с теми же типами данных, структурой и порядком, что и params | |
| Структура | Структура с теми же типами данных, полями и порядком, что и params | |
Таблица с Layer, Parameter, и Value переменные. Value переменная должна состоять из массивов ячеек, которые содержат каждый обучаемый параметр в качестве dlarray. | Таблица с теми же типами данных, переменными и порядком, что и params. averageSqGrad должен иметь Value переменная, состоящая из массивов ячеек, которые содержат среднеквадратичный градиент каждого обучаемого параметра. |
При указании averageSqGrad как пустой массив, функция не принимает предыдущих градиентов и выполняется так же, как для первого обновления в серии итераций. Для итеративного обновления обучаемых параметров используйте averageSqGrad вывод предыдущего вызова rmspropupdate в качестве averageSqGrad вход.
learnRate - Общемировой уровень обучения0.001 (по умолчанию) | положительный скалярГлобальная скорость обучения, заданная как положительный скаляр. Значение по умолчанию learnRate является 0.001.
Если параметры сети указаны как dlnetworkскорость обучения для каждого параметра является глобальной скоростью обучения, умноженной на соответствующее свойство коэффициента скорости обучения, определенное в сетевых уровнях.
sqGradDecay - Коэффициент спада градиента в квадрате0.9 (по умолчанию) | положительный скаляр между 0 и 1.Коэффициент распада градиента в квадрате, заданный как положительный скаляр между 0 и 1. Значение по умолчанию sqGradDecay является 0.9.
epsilon - Малая константа1e-8 (по умолчанию) | положительный скалярМалая константа для предотвращения ошибок деления на ноль, заданная как положительный скаляр. Значение по умолчанию epsilon является 1e-8.
dlnet - Обновленная сетьdlnetwork объектСеть, возвращенная как dlnetwork объект.
Функция обновляет dlnet.Learnables имущества dlnetwork объект.
params - Обновленные сетевые обучаемые параметрыdlarray | числовой массив | массив ячеек | структура | таблицаОбновленные сетевые обучаемые параметры, возвращенные в виде dlarray, числовой массив, массив ячеек, структура или таблица с Value переменная, содержащая обновленные обучаемые параметры сети.
averageSqGrad - Обновлено скользящее среднее квадратичных градиентов параметровdlarray | числовой массив | массив ячеек | структура | таблицаОбновлено скользящее среднее квадратичных градиентов параметров, возвращаемых в виде dlarray, числовой массив, массив ячеек, структуру или таблицу.
Функция использует алгоритм среднеквадратичного распространения корня для обновления обучаемых параметров. Дополнительные сведения см. в определении алгоритма RMSProp в разделе Стохастический градиентный спуск на trainingOptions справочная страница.
rmspropupdate квадрат градиентного коэффициента затухания по умолчанию 0.9В R2020a изменилось поведение
Начиная с R2020a, значение по умолчанию для квадрата градиентного коэффициента затухания в rmspropupdate является 0.9. В предыдущих версиях значением по умолчанию было 0.999. Чтобы воспроизвести предыдущее поведение по умолчанию, используйте один из следующих синтаксисов:
[dlnet,averageSqGrad] = rmspropupdate(dlnet,grad,averageSqGrad,0.001,0.999) [params,averageSqGrad] = rmspropupdate(params,grad,averageSqGrad,0.001,0.999)
Примечания и ограничения по использованию:
Если хотя бы один из следующих входных аргументов является gpuArray или dlarray с базовыми данными типа gpuArray, эта функция выполняется на GPU.
grad
averageSqGrad
params
Дополнительные сведения см. в разделе Запуск функций MATLAB на графическом процессоре (панель инструментов параллельных вычислений).
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | dlupdate | forward | sgdmupdate
Имеется измененная версия этого примера. Открыть этот пример с помощью изменений?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.