Этот пример показывает, как соответствовать модели регрессии использование сверточных нейронных сетей, чтобы предсказать углы вращения рукописных цифр.
Сверточные нейронные сети (CNNs или ConvNets) являются особыми инструментами для глубокого обучения и особенно подходят для анализа данных изображения. Например, можно использовать CNNs, чтобы классифицировать изображения. Чтобы предсказать текущие данные, такие как углы и расстояния, можно включать слой регрессии в конце сети.
Пример создает сверточную архитектуру нейронной сети, обучает сеть и использует обучивший сеть, чтобы предсказать углы вращаемых рукописных цифр. Эти прогнозы полезны для оптического распознавания символов.
Опционально, можно использовать imrotate
(Image Processing Toolbox™), чтобы вращать изображения и boxplot
(Statistics and Machine Learning Toolbox™), чтобы создать остаточную диаграмму.
Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.
Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData
и digitTest4DArrayData
. Выходные параметры YTrain
и YValidation
являются углами поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
Отобразите 20 случайных учебных изображений с помощью imshow
.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) drawnow end
При обучении нейронных сетей это часто помогает убедиться, что данные нормированы на всех этапах сети. Нормализация помогает стабилизироваться и ускорить сетевое обучение с помощью спуска градиента. Если ваши данные плохо масштабируются, то потеря может стать NaN
, и сетевые параметры могут отличаться во время обучения. Распространенные способы нормировать данные включают перемасштабирование данных так, чтобы его область значений стала [0,1] или так, чтобы это имело среднее значение нулевого и стандартного отклонения одного. Можно нормировать следующие данные:
Входные данные. Нормируйте предикторы, прежде чем вы введете их к сети. В этом примере входные изображения уже нормированы к области значений [0,1].
Слой выходные параметры. Можно нормировать выходные параметры каждого сверточного и полносвязный слой при помощи пакетного слоя нормализации.
Ответы. Если вы используете пакетные слои нормализации, чтобы нормировать слой выходные параметры в конце сети, то прогнозы сети нормированы, когда обучение запускается. Если ответ имеет совсем другую шкалу из этих прогнозов, то сетевое обучение может не сходиться. Если ваш ответ плохо масштабируется, то попытайтесь нормировать его и смотрите, улучшается ли сетевое обучение. Если вы нормируете ответ перед обучением, то необходимо преобразовать прогнозы обучившего сеть, чтобы получить прогнозы исходного ответа.
Постройте распределение ответа. Ответ (угол поворота в градусах) приблизительно равномерно распределен между-45 и 45, который работает хорошо, не нуждаясь в нормализации. В проблемах классификации выходные параметры являются вероятностями класса, которые всегда нормируются.
figure histogram(YTrain) axis tight ylabel('Counts') xlabel('Rotation Angle')
В целом данные не должны быть точно нормированы. Однако, если вы обучаете сеть в этом примере, чтобы предсказать 100*YTrain
или YTrain+500
вместо YTrain
, затем потеря становится NaN
, и сетевые параметры отличаются, когда обучение запускается. Эти результаты происходят даже при том, что единственной разницей между сетью, предсказывающей да + b и сетью, предсказывая Y, является простое перемасштабирование весов и смещения итогового полносвязного слоя.
Если распределение входа или ответ являются очень неравномерными или скошенными, можно также выполнить нелинейные преобразования (например, беря логарифмы) к данным прежде, чем обучить сеть.
Чтобы решить проблему регрессии, создайте слои сети и включайте слой регрессии в конце сети.
Первый слой задает размер и тип входных данных. Входные изображения 28 28 1. Создайте входной слой изображений, одного размера как учебные изображения.
Средние слои сети define базовая архитектура сети, где большая часть вычисления и изучения происходят.
Последние слои задают размер и тип выходных данных. Для проблем регрессии полносвязный слой должен предшествовать слою регрессии в конце сети. Создайте полностью связанный выходной слой размера 1 и слой регрессии.
Объедините все слои вместе в массиве Layer
.
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer dropoutLayer(0.2) fullyConnectedLayer(1) regressionLayer];
Создайте сетевые опции обучения. Обучайтесь в течение 30 эпох. Установите начальную букву, изучают уровень 0,001 и понижают темп обучения после 20 эпох. Контролируйте сетевую точность во время обучения путем определения данных о валидации и частоты валидации. Программное обеспечение обучает сеть на данных тренировки и вычисляет точность на данные о валидации равномерно во время обучения. Данные о валидации не используются, чтобы обновить сетевые веса. Включите учебный график прогресса и выключите командное окно вывод.
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',30, ... 'InitialLearnRate',1e-3, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',20, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'Verbose',false);
Создайте сеть с помощью trainNetwork
. Эта команда использует совместимый графический процессор при наличии. В противном случае trainNetwork
использует центральный процессор. CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0, или выше требуется для обучения на графическом процессоре.
net = trainNetwork(XTrain,YTrain,layers,options);
Исследуйте детали сетевой архитектуры, содержавшейся в свойстве Layers
net
.
net.Layers
ans = 18x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'avgpool2d_1' Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'avgpool2d_2' Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' Convolution 32 3x3x32 convolutions with stride [1 1] and padding 'same' 14 'batchnorm_4' Batch Normalization Batch normalization with 32 channels 15 'relu_4' ReLU ReLU 16 'dropout' Dropout 20% dropout 17 'fc' Fully Connected 1 fully connected layer 18 'regressionoutput' Regression Output mean-squared-error with response 'Response'
Проверьте производительность сети путем оценки точности на данных о валидации.
Используйте predict
, чтобы предсказать углы вращения изображений валидации.
YPredicted = predict(net,XValidation);
Оцените производительность
Оцените производительность модели путем вычисления:
Процент прогнозов в приемлемом допуске на погрешность
Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения
Вычислите ошибку прогноза между предсказанными и фактическими углами вращения.
predictionError = YValidation - YPredicted;
Вычислите количество прогнозов в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент прогнозов в этом пороге.
thr = 10; numCorrect = sum(abs(predictionError) < thr); numValidationImages = numel(YValidation); accuracy = numCorrect/numValidationImages
accuracy = 0.9684
Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.
squares = predictionError.^2; rmse = sqrt(mean(squares))
rmse = single
4.5336
Отобразите диаграмму невязок для каждого класса цифры
Функция boxplot
требует матрицы, где каждый столбец соответствует невязкам для каждого класса цифры.
Группы данных о валидации отображают классами 0-9 цифры с 500 примерами каждого. Используйте reshape
, чтобы сгруппировать невязки классом цифры.
residualMatrix = reshape(predictionError,500,10);
Каждый столбец residualMatrix
соответствует невязкам каждой цифры. Создайте остаточную диаграмму для каждой цифры с помощью boxplot
(Statistics and Machine Learning Toolbox).
figure boxplot(residualMatrix,... 'Labels',{'0','1','2','3','4','5','6','7','8','9'}) xlabel('Digit Class') ylabel('Degrees Error') title('Residuals')
Классы цифры с самой высокой точностью имеют среднее значение близко к нулю и небольшому отклонению.
Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate
(Image Processing Toolbox).
idx = randperm(numValidationImages,49); for i = 1:numel(idx) image = XValidation(:,:,:,idx(i)); predictedAngle = YPredicted(idx(i)); imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop'); end
Отобразите исходные цифры с их исправленными вращениями. Можно использовать montage
(Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(imagesRotated) title('Corrected')
classificationLayer
| regressionLayer