В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с помощью пользовательского графика обучения.
Если trainingOptions
не предоставляет необходимых опций (например, пользовательский график обучения), то можно определить собственный пользовательский цикл обучения с помощью автоматического дифференцирования.
Этот пример обучает сеть классифицировать рукописные цифры по расписанию скорости обучения затуханию, основанному на времени: для каждой итерации решатель использует скорость обучения, заданную, где t - α0 - начальная скорость обучения, а k - затухание.
Загрузите данные цифр как хранилище данных изображения с помощью imageDatastore
и укажите папку, содержащую данные изображения.
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset'); imds = imageDatastore(dataFolder, ... 'IncludeSubfolders',true, .... 'LabelSource','foldernames');
Разбейте данные на обучающие и проверочные наборы. Отложите 10% данных для проверки с помощью splitEachLabel
функция.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');
Сеть, используемая в этом примере, требует входных изображений размера 28 на 28 на 1. Для автоматического изменения размеров учебных изображений используйте хранилище данных дополненного изображения. Укажите дополнительные операции увеличения, выполняемые на обучающих изображениях: случайное перемещение изображений до 5 пикселей в горизонтальной и вертикальной осях. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.
inputSize = [28 28 1]; pixelRange = [-5 5]; imageAugmenter = imageDataAugmenter( ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Определите количество занятий в данных обучения.
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
Определите сеть для классификации изображений.
layers = [ imageInputLayer(inputSize,'Normalization','none','Name','input') convolution2dLayer(5,20,'Name','conv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') convolution2dLayer(3,20,'Padding','same','Name','conv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') convolution2dLayer(3,20,'Padding','same','Name','conv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu3') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);
Создать dlnetwork
объект из графа слоев.
dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'input'} OutputNames: {'softmax'}
Создание функции modelGradients
, перечисленных в конце примера, который принимает dlnetwork
объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.
Поезд на десять эпох с размером мини-партии 128.
numEpochs = 10; miniBatchSize = 128;
Укажите параметры оптимизации SGDM. Укажите начальную скорость обучения 0,01 с распадом 0,01, а импульс 0,9.
initialLearnRate = 0.01; decay = 0.01; momentum = 0.9;
Создать minibatchqueue
объект, обрабатывающий и управляющий мини-партиями изображений во время обучения. Для каждой мини-партии:
Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch
(определяется в конце этого примера) для преобразования меток в однокодированные переменные.
Форматирование данных изображения с метками размеров 'SSCB'
(пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат к меткам класса.
Обучение на GPU, если он доступен. По умолчанию minibatchqueue
объект преобразует каждый вывод в gpuArray
если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
mbq = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',@preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB',''});
Инициализируйте график хода обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметр скорости для решателя SGDM.
velocity = [];
Обучение сети с использованием пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:
Оцените градиенты модели, состояние и потери с помощью dlfeval
и modelGradients
и обновить состояние сети.
Определите скорость обучения для графика скорости обучения с учетом времени.
Обновление параметров сети с помощью sgdmupdate
функция.
Просмотрите ход обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); dlnet.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end
Проверьте точность классификации модели путем сравнения прогнозов на наборе проверки с истинными метками.
После обучения составление прогнозов по новым данным не требует меток. Создать minibatchqueue
объект, содержащий только предикторы тестовых данных:
Чтобы игнорировать метки для тестирования, установите число выходов мини-пакетной очереди равным 1.
Укажите размер мини-партии, используемый для обучения.
Предварительная обработка предикторов с помощью preprocessMiniBatchPredictors
функция, перечисленная в конце примера.
Для одиночного вывода хранилища данных укажите формат мини-пакета. 'SSCB'
(пространственный, пространственный, канальный, пакетный).
numOutputs = 1; mbqTest = minibatchqueue(augimdsValidation,numOutputs, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@preprocessMiniBatchPredictors, ... 'MiniBatchFormat','SSCB');
Закольцовывать мини-пакеты и классифицировать изображения с помощью modelPredictions
функция, перечисленная в конце примера.
predictions = modelPredictions(dlnet,mbqTest,classes);
Оцените точность классификации.
YTest = imdsValidation.Labels; accuracy = mean(predictions == YTest)
accuracy = 0.9530
modelGradients
функция принимает dlnetwork
объект dlnet
, мини-пакет входных данных dlX
с соответствующими метками Y
и возвращает градиенты потерь относительно обучаемых параметров в dlnet
, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient
функция.
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y) [dlYPred,state] = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
modelPredictions
функция принимает dlnetwork
объект dlnet
, a minibatchqueue
входных данных mbq
и сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue
объект. Функция использует onehotdecode
функция для поиска прогнозируемого класса с наивысшим баллом.
function predictions = modelPredictions(dlnet,mbq,classes) predictions = []; while hasdata(mbq) dlXTest = next(mbq); dlYPred = predict(dlnet,dlXTest); YPred = onehotdecode(dlYPred,classes,1)'; predictions = [predictions; YPred]; end end
preprocessMiniBatch
функция предварительно обрабатывает мини-пакет предикторов и меток, используя следующие шаги:
Предварительная обработка изображений с помощью preprocessMiniBatchPredictors
функция.
Извлеките данные метки из входящего массива ячеек и объедините их в категориальный массив вдоль второго измерения.
Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.
function [X,Y] = preprocessMiniBatch(XCell,YCell) % Preprocess predictors. X = preprocessMiniBatchPredictors(XCell); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1); end
preprocessMiniBatchPredictors
функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входных ячеек и конкатенации в числовой массив. Для ввода в оттенках серого при конкатенации над четвертым размером к каждому изображению добавляется третий размер для использования в качестве размера одиночного канала.
function X = preprocessMiniBatchPredictors(XCell) % Concatenate. X = cat(4,XCell{1:end}); end
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| minibatchqueue
| onehotdecode
| onehotencode
| predict