Обучите сверточную нейронную сеть регрессии

В этом примере показано, как подбирать модель регрессии использование сверточных нейронных сетей, чтобы предсказать углы вращения рукописных цифр.

Сверточные нейронные сети (CNNs или ConvNets) являются особыми инструментами для глубокого обучения и особенно подходят для анализа данных изображения. Например, можно использовать CNNs, чтобы классифицировать изображения. Чтобы предсказать текущие данные, такие как углы и расстояния, можно включать слой регрессии в конце сети.

Пример создает архитектуру сверточной нейронной сети, обучает сеть и использует обучивший сеть, чтобы предсказать углы вращаемых рукописных цифр. Эти предсказания полезны для оптического распознавания символов.

Опционально, можно использовать imrotate (Image Processing Toolbox™), чтобы вращать изображения и boxplot (Statistics and Machine Learning Toolbox™), чтобы создать остаточную диаграмму.

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.

Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные параметры YTrain и YValidation углы поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных учебных изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

Проверяйте нормализацию данных

При обучении нейронных сетей это часто помогает убедиться, что данные нормированы на всех этапах сети. Нормализация помогает стабилизироваться и ускорить сетевое обучение с помощью градиентного спуска. Если ваши данные плохо масштабируются, то потеря может стать NaN и сетевые параметры могут отличаться во время обучения. Распространенные способы нормировать данные включают перемасштабирование данных так, чтобы его область значений стала [0,1] или так, чтобы это имело среднее значение нулевого и стандартного отклонения одного. Можно нормировать следующие данные:

  • Входные данные. Нормируйте предикторы, прежде чем вы введете их к сети. В этом примере входные изображения уже нормированы к области значений [0,1].

  • Слой выходные параметры. Можно нормировать выходные параметры каждого сверточного и полносвязный слой при помощи слоя нормализации партии.

  • Ответы. Если вы используете слои нормализации партии., чтобы нормировать слой выходные параметры в конце сети, то предсказания сети нормированы, когда обучение запускается. Если ответ имеет совсем другую шкалу из этих предсказаний, то сетевое обучение может не сходиться. Если ваш ответ плохо масштабируется, то попытайтесь нормировать его и смотрите, улучшается ли сетевое обучение. Если вы нормируете ответ перед обучением, то необходимо преобразовать предсказания обучившего сеть, чтобы получить предсказания исходного ответа.

Постройте распределение ответа. Ответ (угол поворота в градусах) приблизительно равномерно распределен между-45 и 45, который работает хорошо, не нуждаясь в нормализации. В проблемах классификации выходные параметры являются вероятностями класса, которые всегда нормируются.

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

Figure contains an axes object. The axes object contains an object of type histogram.

В общем случае данные не должны быть точно нормированы. Однако, если вы обучаете сеть в этом примере, чтобы предсказать 100*YTrain или YTrain+500 вместо YTrain, затем потеря становится NaN и сетевые параметры отличаются, когда обучение запускается. Эти результаты происходят даже при том, что единственной разницей между сетью, предсказывающей да + b и сетью, предсказывая Y, является простое перемасштабирование весов и смещения итогового полносвязного слоя.

Если распределение входа или ответ являются очень неравномерными или скошенными, можно также выполнить нелинейные преобразования (например, беря логарифмы) к данным прежде, чем обучить сеть.

Создайте слоя сети

Чтобы решить задачу регрессии, создайте слои сети и включайте слой регрессии в конце сети.

Первый слой задает размер и тип входных данных. Входные изображения 28 28 1. Создайте входной слой изображений одного размера с учебными изображениями.

Средние слои сети define базовая архитектура сети, где большая часть расчета и изучения происходят.

Последние слои задают размер и тип выходных данных. Для проблем регрессии полносвязный слой должен предшествовать слою регрессии в конце сети. Создайте полностью связанный выходной слой размера 1 и слой регрессии.

Объедините все слои вместе в Layer массив.

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

Обучение сети

Создайте сетевые опции обучения. Обучайтесь в течение 30 эпох. Установите начальную букву, изучают уровень 0,001 и понижают скорость обучения после 20 эпох. Контролируйте сетевую точность во время обучения путем определения данных о валидации и частоты валидации. Программное обеспечение обучает сеть на обучающих данных и вычисляет точность на данные о валидации равномерно во время обучения. Данные о валидации не используются, чтобы обновить сетевые веса. Включите график процесса обучения и выключите командное окно выход.

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор при наличии. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В противном случае, trainNetwork использует центральный процессор.

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (25-Aug-2021 07:24:47) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.

Исследуйте детали сетевой архитектуры, содержавшейся в Layers свойство net.

net.Layers
ans = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

Тестирование сети

Проверьте производительность сети путем оценки точности на данных о валидации.

Используйте predict предсказать углы вращения изображений валидации.

YPredicted = predict(net,XValidation);

Оцените эффективность

Оцените эффективность модели путем вычисления:

  1. Процент предсказаний в приемлемом допуске на погрешность

  2. Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения

Вычислите ошибку предсказания между предсказанными и фактическими углами вращения.

predictionError = YValidation - YPredicted;

Вычислите количество предсказаний в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент предсказаний в этом пороге.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9700

Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.5296

Визуализируйте предсказания

Визуализируйте предсказания в графике рассеивания. Постройте ожидаемые значения против истинных значений.

figure
scatter(YPredicted,YValidation,'+')
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],'r--')

Figure contains an axes object. The axes object contains 2 objects of type scatter, line.

Правильные вращения цифры

Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate (Image Processing Toolbox).

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

Отобразите исходные цифры с их откорректированными вращениями. Можно использовать montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

Смотрите также

|

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте