В этом примере показов, как обучить нейронную сеть для глубокого обучения с несколькими выходами, которые предсказывают и метки, и углы поворотов рукописных цифр.
Чтобы обучить сеть с несколькими выходами, необходимо обучить сеть с помощью пользовательского цикла обучения.
The digitTrain4DArrayData
функция загружает изображения, их метки цифр и углы поворота от вертикали. Создайте arrayDatastore
объект для изображений, меток и углов, а затем используйте combine
функция, чтобы создать один datastore, который содержит все обучающие данные. Извлеките имена классов и количество недискретных ответов.
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsYTrain = arrayDatastore(YTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);
dsTrain = combine(dsXTrain,dsYTrain,dsAnglesTrain);
classNames = categories(YTrain);
numClasses = numel(classNames);
numObservations = numel(YTrain);
Просмотрите некоторые изображения из обучающих данных.
idx = randperm(numObservations,64); I = imtile(XTrain(:,:,:,idx)); figure imshow(I)
Задайте следующую сеть, которая предсказывает и метки, и углы поворота.
Блок свертки-batchnorm-ReLU с 16 фильтрами 5 на 5.
Два блока свертки-batchnorm-ReLU каждый с 32 фильтрами 3 на 3.
Пропускает соединение вокруг предыдущих двух блоков, содержащее блок свертки-batchnorm-ReLU с 32 свертками 1 на 1.
Объедините пропущенное соединение с помощью сложения.
Для вывода классификации выход ветвь с полностью связанной операцией размера 10 (количество классов) и операцией softmax.
Для вывода регрессии выход ветвь с полностью связанной операцией размера 1 (количество откликов).
Определите основной блок слоев как график слоев.
layers = [ imageInputLayer([28 28 1],'Normalization','none','Name','in') convolution2dLayer(5,16,'Padding','same','Name','conv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') convolution2dLayer(3,32,'Padding','same','Stride',2,'Name','conv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') convolution2dLayer(3,32,'Padding','same','Name','conv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu4') additionLayer(2,'Name','addition') fullyConnectedLayer(numClasses,'Name','fc1') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);
Добавьте пропущенное соединение.
layers = [ convolution2dLayer(1,32,'Stride',2,'Name','convSkip') batchNormalizationLayer('Name','bnSkip') reluLayer('Name','reluSkip')]; lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'relu1','convSkip'); lgraph = connectLayers(lgraph,'reluSkip','addition/in2');
Добавьте полносвязного слоя для регрессии.
layers = fullyConnectedLayer(1,'Name','fc2'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'addition','fc2');
Просмотр графика слоев на графике.
figure plot(lgraph)
Создайте dlnetwork
объект из графика слоев.
dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [17×1 nnet.cnn.layer.Layer] Connections: [17×2 table] Learnables: [20×3 table] State: [8×3 table] InputNames: {'in'} OutputNames: {'softmax' 'fc2'}
Создайте функцию modelGradients
, перечисленный в конце примера, который принимает как вход, dlnetwork
dlnet объекта
мини-пакет входных данных dlX
с соответствующими целями T1
и T2
содержит метки и углы, соответственно, и возвращает градиенты потерь относительно настраиваемых параметров, обновленного состояния сети и соответствующих потерь.
Задайте опции обучения. Обучите 30 эпох, используя мини-партию размером 128.
numEpochs = 30; miniBatchSize = 128;
Визуализируйте процесс обучения на графике.
plots = "training-progress";
Использование minibatchqueue
для обработки и управления мини-пакетами изображений. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы закодировать метки классов с одним «горячим» кодом.
Форматируйте данные изображения с помощью меток размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат к меткам классов или углам.
Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue
объект преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
mbq = minibatchqueue(dsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessData,... 'MiniBatchFormat',{'SSCB','',''});
Обучите модель с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. В конце каждой итерации отобразите процесс обучения. Для каждого мини-пакета:
Оцените градиенты модели и потери с помощью dlfeval
и modelGradients
функция.
Обновляйте параметры сети с помощью adamupdate
функция.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Инициализируйте параметры для Adam.
trailingAvg = []; trailingAvgSq = [];
Обучите модель.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq) % Loop over mini-batches while hasdata(mbq) iteration = iteration + 1; [dlX,dlY1,dlY2] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function. [gradients,state,loss] = dlfeval(@modelGradients, dlnet, dlX, dlY1, dlY2); dlnet.State = state; % Update the network parameters using the Adam optimizer. [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ... trailingAvg,trailingAvgSq,iteration); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками и углами. Управление набором тестовых данных с помощью minibatchqueue
объект с той же настройкой, что и обучающие данные.
[XTest,Y1Test,anglesTest] = digitTest4DArrayData; dsXTest = arrayDatastore(XTest,'IterationDimension',4); dsYTest = arrayDatastore(Y1Test); dsAnglesTest = arrayDatastore(anglesTest); dsTest = combine(dsXTest,dsYTest,dsAnglesTest); mbqTest = minibatchqueue(dsTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessData,... 'MiniBatchFormat',{'SSCB','',''});
Чтобы предсказать метки и углы данных валидации, закольцовывайте мини-пакеты и используйте predict
функция. Сохраните предсказанные классы и углы. Сравните предсказанные и истинные классы и углы и сохраните результаты.
classesPredictions = []; anglesPredictions = []; classCorr = []; angleDiff = []; % Loop over mini-batches. while hasdata(mbqTest) % Read mini-batch of data. [dlXTest,dlY1Test,dlY2Test] = next(mbqTest); % Make predictions using the predict function. [dlY1Pred,dlY2Pred] = predict(dlnet,dlXTest,'Outputs',["softmax" "fc2"]); % Determine predicted classes. Y1PredBatch = onehotdecode(dlY1Pred,classNames,1); classesPredictions = [classesPredictions Y1PredBatch]; % Dermine predicted angles Y2PredBatch = extractdata(dlY2Pred); anglesPredictions = [anglesPredictions Y2PredBatch]; % Compare predicted and true classes Y1Test = onehotdecode(dlY1Test,classNames,1); classCorr = [classCorr Y1PredBatch == Y1Test]; % Compare predicted and true angles angleDiffBatch = Y2PredBatch - dlY2Test; angleDiff = [angleDiff extractdata(gather(angleDiffBatch))]; end
Оцените точность классификации.
accuracy = mean(classCorr)
accuracy = 0.9814
Оцените точность регрессии.
angleRMSE = sqrt(mean(angleDiff.^2))
angleRMSE = single
7.7431
Просмотрите некоторые изображения с их предсказаниями. Отображение прогнозируемых углов красного цвета и правильных меток зеленого цвета.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = anglesPredictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = anglesTest(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(classesPredictions(idx(i))); title("Label: " + label) end
The modelGradients
function, принимает как вход, dlnetwork
dlnet объекта
мини-пакет входных данных dlX
с соответствующими целями T1
и T2
содержит метки и углы, соответственно, и возвращает градиенты потерь относительно настраиваемых параметров, обновленного состояния сети и соответствующих потерь.
function [gradients,state,loss] = modelGradients(dlnet,dlX,T1,T2) [dlY1,dlY2,state] = forward(dlnet,dlX,'Outputs',["softmax" "fc2"]); lossLabels = crossentropy(dlY1,T1); lossAngles = mse(dlY2,T2); loss = lossLabels + 0.1*lossAngles; gradients = dlgradient(loss,dlnet.Learnables); end
The preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входящего массива ячеек и соедините в числовой массив. Конкатенация данных изображения по четвертому измерению добавляет третье измерение к каждому изображению, которое используется в качестве размерности одинарного канала.
Извлеките данные о метках и углах из входящих массивов ячеек и соедините вдоль второго измерения в категориальный массив и числовой массив, соответственно.
Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.
function [X,Y,angle] = preprocessData(XCell,YCell,angleCell) % Extract image data from cell and concatenate X = cat(4,XCell{:}); % Extract label data from cell and concatenate Y = cat(2,YCell{:}); % Extract angle data from cell and concatenate angle = cat(2,angleCell{:}); % One-hot encode labels Y = onehotencode(Y,1); end
batchNormalizationLayer
| convolution2dLayer
| dlarray
| dlfeval
| dlgradient
| fullyConnectedLayer
| minibatchqueue
| onehotdecode
| onehotencode
| reluLayer
| sgdmupdate
| softmaxLayer