В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с помощью входных данных изображений и функций.
Загрузите изображения цифр XTrain
, метки YTrain
, и углы поворота по часовой стрелке anglesTrain
. Создайте arrayDatastore
объект для изображений, меток и углов, а затем используйте combine
функция, чтобы создать один datastore, который содержит все обучающие данные. Извлеките имена классов и высоту, ширину, количество каналов и количество недискретных откликов.
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsAnglesTrain = arrayDatastore(anglesTrain);
dsYTrain = arrayDatastore(YTrain);
dsTrain = combine(dsXTrain,dsAnglesTrain,dsYTrain);
classes = categories(YTrain);
[h,w,c,numObservations] = size(XTrain);
Отобразите 20 случайных обучающих изображений.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) title("Angle: " + anglesTrain(idx(i))) end
Задайте размер входа изображения, количество функций каждого наблюдения, количество классов, а также размер и количество фильтров слоя свертки.
imageInputSize = [h w c]; numFeatures = 1; numClasses = numel(classes); filterSize = 5; numFilters = 16;
Чтобы создать сеть с двумя входными слоями, необходимо задать сеть в двух частях и присоединиться к ним, например, с помощью слоя конкатенации.
Определите первую часть сети. Определите слои классификации изображений и включите слой конкатенации перед последним полносвязным слоем.
layers = [ imageInputLayer(imageInputSize,'Normalization','none','Name','images') convolution2dLayer(filterSize,numFilters,'Name','conv') reluLayer('Name','relu') fullyConnectedLayer(50,'Name','fc1') concatenationLayer(1,2,'Name','concat') fullyConnectedLayer(numClasses,'Name','fc2') softmaxLayer('Name','softmax')];
Преобразуйте слои в график слоев.
lgraph = layerGraph(layers);
Для второй части сети добавьте входной слой функции и соедините его со вторым входом слоя конкатенации.
featInput = featureInputLayer(numFeatures,'Name','features'); lgraph = addLayers(lgraph, featInput); lgraph = connectLayers(lgraph, 'features', 'concat/in2');
Визуализация сети.
figure plot(lgraph)
Создайте dlnetwork
объект.
dlnet = dlnetwork(lgraph);
Когда вы используете функции predict
и forward
на dlnetwork
объект, входные параметры должны совпадать с порядком, заданным InputNames
свойство dlnetwork
объект. Проверьте имя и порядок входных слоев.
dlnet.InputNames
ans = 1×2 cell
{'images'} {'features'}
The modelGradients
функция, перечисленная в разделе Model Gradients Function примера, принимает в качестве входных данных dlnetwork
dlnet объекта
мини-пакет входных данных изображения dlX1
мини-пакет входных данных функций dlX2
и соответствующие метки dlY
, и возвращает градиенты потерь относительно настраиваемых параметров в dlnet
, состояние сети и потери.
Train с мини-партией размером 128 на 15 эпох.
numEpochs = 15; miniBatchSize = 128;
Задайте опции для оптимизации SGDM. Задайте начальную скорость обучения 0,01 с распадом 0,01 и импульсом 0,9.
learnRate = 0.01; decay = 0.01; momentum = 0.9;
Чтобы контролировать процесс обучения, можно построить график потерь обучения после каждой итерации. Создайте переменную plots
который содержит "training-progress"
. Если вы не хотите строить график процесса обучения, задайте это значение "none"
.
plots = "training-progress";
Обучите модель с помощью пользовательского цикла обучения. Инициализируйте параметр скорости для решателя SGDM.
velocity = [];
Использование minibatchqueue
обрабатывать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessData
(определено в конце этого примера), чтобы закодировать метки классов с одним «горячим» кодом.
По умолчанию в minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Форматируйте изображения с метками размерностей 'SSCB'
(пространственный, пространственный, канал, пакет) и углы с метками размерностей 'CB'
(канал, пакет). Не добавляйте формат к меткам классов.
Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue
объект преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
mbq = minibatchqueue(dsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB','CB',''});
Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. В конце каждой эпохи отобразите процесс обучения. Для каждого мини-пакета:
Оцените градиенты модели, состояние и потери с помощью dlfeval
и modelGradients
функционировать и обновлять состояние сети.
Обновляйте параметры сети с помощью sgdmupdate
функция.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Обучите модель.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq) % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX1,dlX2,dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX1,dlX2,dlY); dlnet.State = state; % Update the network parameters using the SGDM optimizer. [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum); if plots == "training-progress" % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); %completionPercentage = round(iteration/numIterations*100,0); title("Epoch: " + epoch + ", Elapsed: " + string(D)); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) drawnow end end end
Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками. Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками и углами. Управление набором тестовых данных с помощью minibatchqueue
объект с той же настройкой, что и обучающие данные.
[XTest,YTest,anglesTest] = digitTest4DArrayData; dsXTest = arrayDatastore(XTest,'IterationDimension',4); dsAnglesTest = arrayDatastore(anglesTest); dsYTest = arrayDatastore(YTest); dsTest = combine(dsXTest,dsAnglesTest,dsYTest); mbqTest = minibatchqueue(dsTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB','CB',''});
Закольцовывайте мини-пакеты и классифицируйте изображения с помощью modelPredictions
функции, перечисленной в конце примера.
[predictions,predCorr] = modelPredictions(dlnet,mbqTest,classes);
Оцените точность классификации.
accuracy = mean(predCorr)
accuracy = 0.9818
Просмотрите некоторые изображения с их предсказаниями.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) label = string(predictions(idx(i))); title("Predicted Label: " + label) end
The modelGradients
функция принимает как вход dlnetwork
dlnet объекта
мини-пакет входных данных изображения dlX1
мини-пакет входных данных функций dlX2
, и соответствующие метки Y
, и возвращает градиенты потерь относительно настраиваемых параметров в dlnet
, состояние сети и потери. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
Когда вы используете forward
функция на dlnetwork
объект, входные параметры должны совпадать с порядком, заданным InputNames
свойство dlnetwork
объект.
function [gradients,state,loss] = modelGradients(dlnet,dlX1,dlX2,Y) [dlYPred,state] = forward(dlnet,dlX1,dlX2); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); end
The modelPredictions
функция принимает как вход dlnetwork
dlnet объекта
, а minibatchqueue
входных данных mbq
, и сетевых классов, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue
объект. Функция использует onehotdecode
функция для поиска предсказанного класса с самым высоким счетом, а затем сравнивает предсказание с истинной меткой. Функция возвращает предсказания и вектор таковых и нулей, который представляет правильные и неправильные предсказания.
function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes) classesPredictions = []; classCorr = []; while hasdata(mbq) [dlX1,dlX2,dlY] = next(mbq); % Make predictions using the model function. dlYPred = predict(dlnet,dlX1,dlX2); % Determine predicted classes. YPredBatch = onehotdecode(dlYPred,classes,1); classesPredictions = [classesPredictions YPredBatch]; % Compare predicted and true classes. Y = onehotdecode(dlY,classes,1); classCorr = [classCorr YPredBatch == Y]; end end
The preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входящего массива ячеек и соедините в числовой массив. Конкатенация данных изображения по четвертому измерению добавляет третье измерение к каждому изображению, которое используется в качестве размерности одинарного канала.
Извлеките данные о метках и углах из входящих массивов ячеек и соедините вдоль второго измерения в категориальный массив и числовой массив, соответственно.
Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.
function [X,angle,Y] = preprocessMiniBatch(XCell,angleCell,YCell) % Extract image data from cell and concatenate. X = cat(4,XCell{:}); % Extract angle data from cell and concatenate. angle = cat(2,angleCell{:}); % Extract label data from cell and concatenate. Y = cat(2,YCell{:}); % One-hot encode labels. Y = onehotencode(Y,1); end
Deep Network Designer | dlarray
| dlfeval
| dlnetwork
| featureInputLayer
| fullyConnectedLayer
| minibatchqueue
| onehotdecode
| onehotencode