В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с пользовательским расписанием скорости обучения.
Если trainingOptions
не предоставляет необходимые опции (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью автоматической дифференциации.
Этот пример обучает сеть классифицировать рукописные цифры с основанным на времени расписанием скорости обучения с распадом: для каждой итерации решатель использует скорость обучения, заданную как , где t - число итерации, является начальной скоростью обучения, и k является распадом.
Загрузите данные цифр в виде datastore изображений с помощью imageDatastore
и укажите папку, содержащую данные изображения.
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset'); imds = imageDatastore(dataFolder, ... 'IncludeSubfolders',true, .... 'LabelSource','foldernames');
Разделите данные на наборы для обучения и валидации. Отложите 10% данных для валидации с помощью splitEachLabel
функция.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');
Сеть, используемая в этом примере, требует изображений входа размера 28 на 28 на 1. Чтобы автоматически изменить размер обучающих изображений, используйте дополненный image datastore. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом переведите изображения до 5 пикселей в горизонтальной и вертикальной осях. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.
inputSize = [28 28 1]; pixelRange = [-5 5]; imageAugmenter = imageDataAugmenter( ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Определите количество классов в обучающих данных.
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
Определите сеть для классификации изображений.
layers = [ imageInputLayer(inputSize,'Normalization','none','Name','input') convolution2dLayer(5,20,'Name','conv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') convolution2dLayer(3,20,'Padding','same','Name','conv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') convolution2dLayer(3,20,'Padding','same','Name','conv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu3') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);
Создайте dlnetwork
объект из графика слоев.
dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'input'} OutputNames: {'softmax'}
Создайте функцию modelGradients
, перечисленный в конце примера, который принимает dlnetwork
объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.
Обучайте на десять эпох с мини-партией размером 128.
numEpochs = 10; miniBatchSize = 128;
Задайте опции для оптимизации SGDM. Задайте начальную скорость обучения 0,01 с распадом 0,01 и импульсом 0,9.
initialLearnRate = 0.01; decay = 0.01; momentum = 0.9;
Создайте minibatchqueue
объект, который обрабатывает и управляет мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы преобразовать метки в переменные с кодировкой с одним горячим контактом.
Форматируйте данные изображения с помощью меток размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат к меткам классов.
Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue
объект преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
mbq = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',@preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB',''});
Инициализируйте график процесса обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметр скорости для решателя SGDM.
velocity = [];
Обучите сеть с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Для каждого мини-пакета:
Оцените градиенты модели, состояние и потери с помощью dlfeval
и modelGradients
функционирует и обновляет состояние сети.
Определите скорость обучения для основанного на времени расписания скорости обучения с распадом.
Обновляйте параметры сети с помощью sgdmupdate
функция.
Отображение процесса обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); dlnet.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end
Протестируйте классификационную точность модели путем сравнения предсказаний на наборе валидации с истинными метками.
После обучения создание предсказаний по новым данным не требует меток. Создание minibatchqueue
объект, содержащий только предикторы тестовых данных:
Чтобы игнорировать метки для проверки, установите количество выходов мини-очереди пакетов равным 1.
Укажите тот же размер мини-пакета, что и для обучения.
Предварительно обработайте предикторы, используя preprocessMiniBatchPredictors
функции, перечисленной в конце примера.
Для одинарного выхода datastore задайте формат пакета 'SSCB'
(пространственный, пространственный, канальный, пакетный).
numOutputs = 1; mbqTest = minibatchqueue(augimdsValidation,numOutputs, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@preprocessMiniBatchPredictors, ... 'MiniBatchFormat','SSCB');
Закольцовывайте мини-пакеты и классифицируйте изображения с помощью modelPredictions
функции, перечисленной в конце примера.
predictions = modelPredictions(dlnet,mbqTest,classes);
Оцените точность классификации.
YTest = imdsValidation.Labels; accuracy = mean(predictions == YTest)
accuracy = 0.9530
The modelGradients
функция принимает dlnetwork
dlnet объекта
мини-пакет входных данных dlX
с соответствующими метками Y
и возвращает градиенты потерь относительно настраиваемых параметров в dlnet
, состояние сети и потери. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y) [dlYPred,state] = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
The modelPredictions
функция принимает dlnetwork
dlnet объекта
, а minibatchqueue
входных данных mbq
, и сетевых классов, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue
объект. Функция использует onehotdecode
функция для поиска предсказанного класса с самым высоким счетом.
function predictions = modelPredictions(dlnet,mbq,classes) predictions = []; while hasdata(mbq) dlXTest = next(mbq); dlYPred = predict(dlnet,dlXTest); YPred = onehotdecode(dlYPred,classes,1)'; predictions = [predictions; YPred]; end end
The preprocessMiniBatch
функция предварительно обрабатывает мини-пакет предикторов и меток с помощью следующих шагов:
Предварительно обработайте изображения с помощью preprocessMiniBatchPredictors
функция.
Извлеките данные метки из входящего массива ячеек и сгруппируйте в категориальный массив по второму измерению.
Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.
function [X,Y] = preprocessMiniBatch(XCell,YCell) % Preprocess predictors. X = preprocessMiniBatchPredictors(XCell); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1); end
The preprocessMiniBatchPredictors
функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входа ячеек и конкатенации в числовой массив. Для входов полутонового цвета, конкатенация по четвертому измерению добавляет третье измерение каждому изображению, чтобы использовать в качестве размерности синглтонного канала.
function X = preprocessMiniBatchPredictors(XCell) % Concatenate. X = cat(4,XCell{1:end}); end
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| minibatchqueue
| onehotdecode
| onehotencode
| predict