В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с помощью и изображения и входных данных функции.
Загрузите изображения цифр XTrain
, маркирует YTrain
, и по часовой стрелке углы поворота anglesTrain
. Создайте arrayDatastore
объект для изображений, меток и углов, и затем использует combine
функция, чтобы сделать один datastore, который содержит все обучающие данные. Извлеките имена классов и высоту, ширину, количество каналов и количество недискретных ответов.
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsAnglesTrain = arrayDatastore(anglesTrain);
dsYTrain = arrayDatastore(YTrain);
dsTrain = combine(dsXTrain,dsAnglesTrain,dsYTrain);
classes = categories(YTrain);
[h,w,c,numObservations] = size(XTrain);
Отобразите 20 случайных учебных изображений.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) title("Angle: " + anglesTrain(idx(i))) end
Задайте размер входного изображения, количество функций каждого наблюдения, количество классов, и размер и количество фильтров слоя свертки.
imageInputSize = [h w c]; numFeatures = 1; numClasses = numel(classes); filterSize = 5; numFilters = 16;
Чтобы создать сеть с двумя входными слоями, необходимо задать сеть в двух частях и соединить их, например, при помощи слоя конкатенации.
Задайте первую часть сети. Задайте слои классификации изображений и включайте слой конкатенации перед последним полносвязным слоем.
layers = [ imageInputLayer(imageInputSize,'Normalization','none','Name','images') convolution2dLayer(filterSize,numFilters,'Name','conv') reluLayer('Name','relu') fullyConnectedLayer(50,'Name','fc1') concatenationLayer(1,2,'Name','concat') fullyConnectedLayer(numClasses,'Name','fc2') softmaxLayer('Name','softmax')];
Преобразуйте слои в график слоев.
lgraph = layerGraph(layers);
Для второй части сети добавьте входной слой функции и соедините его со вторым входом слоя конкатенации.
featInput = featureInputLayer(numFeatures,'Name','features'); lgraph = addLayers(lgraph, featInput); lgraph = connectLayers(lgraph, 'features', 'concat/in2');
Визуализируйте сеть.
figure plot(lgraph)
Создайте dlnetwork
объект.
dlnet = dlnetwork(lgraph);
Когда вы используете функции predict
и forward
на dlnetwork
объект, входные параметры должны совпадать с распоряжением, данным InputNames
свойство dlnetwork
объект. Смотрите имя и порядок входных слоев.
dlnet.InputNames
ans = 1×2 cell
{'images'} {'features'}
modelGradients
функция, перечисленная в разделе Model Gradients Function примера, берет в качестве входа dlnetwork
объект dlnet
, мини-пакет входных данных изображения dlX1
, мини-пакет входа показывает данные dlX2
, и соответствие маркирует dlY
, и возвращает градиенты потери относительно настраиваемых параметров в dlnet
, сетевое состояние и потеря.
Обучайтесь с мини-пакетным размером 128 в течение 15 эпох.
numEpochs = 15; miniBatchSize = 128;
Задайте опции для оптимизации SGDM. Задайте начальную скорость обучения 0,01 с затуханием 0,01 и импульсом 0,9.
learnRate = 0.01; decay = 0.01; momentum = 0.9;
Чтобы контролировать процесс обучения, можно построить учебную потерю после каждой итерации. Создайте переменную plots
это содержит "training-progress"
. Если вы не хотите строить процесс обучения, то установленный это значение к "none"
.
plots = "training-progress";
Обучите модель с помощью пользовательского учебного цикла. Инициализируйте скоростной параметр для решателя SGDM.
velocity = [];
Используйте minibatchqueue
обработать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessData
(заданный в конце этого примера) к одногорячему кодируют метки класса.
По умолчанию, minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Формат изображения с размерностью маркирует 'SSCB'
(пространственный, пространственный, канал, пакет), и углы с размерностью маркирует 'CB'
(образуйте канал, пакет). Не добавляйте формат в метки класса.
Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue
объект преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
mbq = minibatchqueue(dsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB','CB',''});
В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. В конце каждой эпохи отобразите прогресс обучения. Для каждого мини-пакета:
Оцените градиенты модели, состояние и потерю с помощью dlfeval
и modelGradients
функционируйте и обновите сетевое состояние.
Обновите сетевые параметры с помощью sgdmupdate
функция.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Обучите модель.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq) % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX1,dlX2,dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX1,dlX2,dlY); dlnet.State = state; % Update the network parameters using the SGDM optimizer. [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum); if plots == "training-progress" % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); %completionPercentage = round(iteration/numIterations*100,0); title("Epoch: " + epoch + ", Elapsed: " + string(D)); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) drawnow end end end
Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками. Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками и углами. Управляйте набором тестовых данных с помощью minibatchqueue
объект с той же установкой как обучающие данные.
[XTest,YTest,anglesTest] = digitTest4DArrayData; dsXTest = arrayDatastore(XTest,'IterationDimension',4); dsAnglesTest = arrayDatastore(anglesTest); dsYTest = arrayDatastore(YTest); dsTest = combine(dsXTest,dsAnglesTest,dsYTest); mbqTest = minibatchqueue(dsTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB','CB',''});
Цикл по мини-пакетам и классифицирует изображения с помощью modelPredictions
функция, перечисленная в конце примера.
[predictions,predCorr] = modelPredictions(dlnet,mbqTest,classes);
Оцените точность классификации.
accuracy = mean(predCorr)
accuracy = 0.9818
Просмотрите некоторые изображения с их предсказаниями.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) label = string(predictions(idx(i))); title("Predicted Label: " + label) end
modelGradients
функционируйте берет в качестве входа dlnetwork
объект dlnet
, мини-пакет входных данных изображения dlX1
, мини-пакет входа показывает данные dlX2
, и соответствующий маркирует Y
, и возвращает градиенты потери относительно настраиваемых параметров в dlnet
, сетевое состояние и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
Когда вы используете forward
функция на dlnetwork
объект, входные параметры должны совпадать с распоряжением, данным InputNames
свойство dlnetwork
объект.
function [gradients,state,loss] = modelGradients(dlnet,dlX1,dlX2,Y) [dlYPred,state] = forward(dlnet,dlX1,dlX2); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables); end
modelPredictions
функционируйте берет в качестве входа dlnetwork
объект dlnet
, minibatchqueue
из входных данных mbq
, и сетевые классы, и вычисляют предсказания модели путем итерации по всем данным в minibatchqueue
объект. Функция использует onehotdecode
функционируйте, чтобы найти предсказанный класс с самым высоким счетом и затем сравнить предсказание с истинной меткой. Функция возвращает предсказания и вектор из единиц и нулей, который представляет правильные и неправильные предсказания.
function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes) classesPredictions = []; classCorr = []; while hasdata(mbq) [dlX1,dlX2,dlY] = next(mbq); % Make predictions using the model function. dlYPred = predict(dlnet,dlX1,dlX2); % Determine predicted classes. YPredBatch = onehotdecode(dlYPred,classes,1); classesPredictions = [classesPredictions YPredBatch]; % Compare predicted and true classes. Y = onehotdecode(dlY,classes,1); classCorr = [classCorr YPredBatch == Y]; end end
preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.
Извлеките метку и угловые данные из массивов входящей ячейки и конкатенируйте вдоль второго измерения в категориальный массив и числовой массив, соответственно.
Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.
function [X,angle,Y] = preprocessMiniBatch(XCell,angleCell,YCell) % Extract image data from cell and concatenate. X = cat(4,XCell{:}); % Extract angle data from cell and concatenate. angle = cat(2,angleCell{:}); % Extract label data from cell and concatenate. Y = cat(2,YCell{:}); % One-hot encode labels. Y = onehotencode(Y,1); end
dlnetwork
| dlfeval
| dlarray
| fullyConnectedLayer
| Deep Network Designer | featureInputLayer
| minibatchqueue
| onehotencode
| onehotdecode