В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с помощью пользовательского графика обучения.
Если trainingOptions не предоставляет необходимых опций (например, пользовательский график обучения), то можно определить собственный пользовательский цикл обучения с помощью автоматического дифференцирования.
Этот пример обучает сеть классифицировать рукописные цифры по расписанию скорости обучения затуханию, основанному на времени: для каждой итерации решатель использует скорость обучения, заданную, где t - итераций, α0 - начальная скорость обучения, а k - затухание.
Загрузите данные цифр как хранилище данных изображения с помощью imageDatastore и укажите папку, содержащую данные изображения.
dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset'); imds = imageDatastore(dataFolder, ... 'IncludeSubfolders',true, .... 'LabelSource','foldernames');
Разбейте данные на обучающие и проверочные наборы. Отложите 10% данных для проверки с помощью splitEachLabel функция.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');Сеть, используемая в этом примере, требует входных изображений размера 28 на 28 на 1. Для автоматического изменения размеров учебных изображений используйте хранилище данных дополненного изображения. Укажите дополнительные операции увеличения, выполняемые на обучающих изображениях: случайное перемещение изображений до 5 пикселей в горизонтальной и вертикальной осях. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.
inputSize = [28 28 1]; pixelRange = [-5 5]; imageAugmenter = imageDataAugmenter( ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Определите количество занятий в данных обучения.
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
Определите сеть для классификации изображений.
layers = [
imageInputLayer(inputSize,'Normalization','none','Name','input')
convolution2dLayer(5,20,'Name','conv1')
batchNormalizationLayer('Name','bn1')
reluLayer('Name','relu1')
convolution2dLayer(3,20,'Padding','same','Name','conv2')
batchNormalizationLayer('Name','bn2')
reluLayer('Name','relu2')
convolution2dLayer(3,20,'Padding','same','Name','conv3')
batchNormalizationLayer('Name','bn3')
reluLayer('Name','relu3')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);Создать dlnetwork объект из графа слоев.
dlnet = dlnetwork(lgraph)
dlnet =
dlnetwork with properties:
Layers: [12×1 nnet.cnn.layer.Layer]
Connections: [11×2 table]
Learnables: [14×3 table]
State: [6×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
Создание функции modelGradients, перечисленных в конце примера, который принимает dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.
Поезд на десять эпох с размером мини-партии 128.
numEpochs = 10; miniBatchSize = 128;
Укажите параметры оптимизации SGDM. Укажите начальную скорость обучения 0,01 с распадом 0,01, а импульс 0,9.
initialLearnRate = 0.01; decay = 0.01; momentum = 0.9;
Создать minibatchqueue объект, обрабатывающий и управляющий мини-партиями изображений во время обучения. Для каждой мини-партии:
Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в конце этого примера) для преобразования меток в однокодированные переменные.
Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам класса.
Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
mbq = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',@preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB',''});
Инициализируйте график хода обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметр скорости для решателя SGDM.
velocity = [];
Обучение сети с использованием пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:
Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients и обновить состояние сети.
Определите скорость обучения для графика скорости обучения с учетом времени.
Обновление параметров сети с помощью sgdmupdate функция.
Просмотрите ход обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); dlnet.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end

Проверьте точность классификации модели путем сравнения прогнозов на наборе проверки с истинными метками.
После обучения составление прогнозов по новым данным не требует меток. Создать minibatchqueue объект, содержащий только предикторы тестовых данных:
Чтобы игнорировать метки для тестирования, установите число выходов мини-пакетной очереди равным 1.
Укажите размер мини-партии, используемый для обучения.
Предварительная обработка предикторов с помощью preprocessMiniBatchPredictors функция, перечисленная в конце примера.
Для одиночного вывода хранилища данных укажите формат мини-пакета. 'SSCB' (пространственный, пространственный, канальный, пакетный).
numOutputs = 1; mbqTest = minibatchqueue(augimdsValidation,numOutputs, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@preprocessMiniBatchPredictors, ... 'MiniBatchFormat','SSCB');
Закольцовывать мини-пакеты и классифицировать изображения с помощью modelPredictions функция, перечисленная в конце примера.
predictions = modelPredictions(dlnet,mbqTest,classes);
Оцените точность классификации.
YTest = imdsValidation.Labels; accuracy = mean(predictions == YTest)
accuracy = 0.9530
modelGradients функция принимает dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствующими метками Y и возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient функция.
function [gradients,state,loss] = modelGradients(dlnet,dlX,Y) [dlYPred,state] = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
modelPredictions функция принимает dlnetwork объект dlnet, a minibatchqueue входных данных mbqи сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска прогнозируемого класса с наивысшим баллом.
function predictions = modelPredictions(dlnet,mbq,classes) predictions = []; while hasdata(mbq) dlXTest = next(mbq); dlYPred = predict(dlnet,dlXTest); YPred = onehotdecode(dlYPred,classes,1)'; predictions = [predictions; YPred]; end end
preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток, используя следующие шаги:
Предварительная обработка изображений с помощью preprocessMiniBatchPredictors функция.
Извлеките данные метки из входящего массива ячеек и объедините их в категориальный массив вдоль второго измерения.
Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.
function [X,Y] = preprocessMiniBatch(XCell,YCell) % Preprocess predictors. X = preprocessMiniBatchPredictors(XCell); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1); end
preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входных ячеек и конкатенации в числовой массив. Для ввода в оттенках серого при конкатенации над четвертым размером к каждому изображению добавляется третий размер для использования в качестве размера одиночного канала.
function X = preprocessMiniBatchPredictors(XCell) % Concatenate. X = cat(4,XCell{1:end}); end
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | forward | minibatchqueue | onehotdecode | onehotencode | predict