Обучите сверточную нейронную сеть для регрессии

Этот пример показывает аппроксимацию регрессионной модели с помощью сверточных нейронных сетей для предсказания углов поворота рукописных цифр.

Сверточные нейронные сети (CNNs, или ConvNets) являются особыми инструментами для глубокого обучения и особенно подходят для анализа данных изображений. Для примера можно использовать CNNs для классификации изображений. Чтобы предсказать непрерывные данные, такие как углы и расстояния, можно включить регрессионый слой в конец сети.

Пример создает архитектуру сверточной нейронной сети, обучает сеть и использует обученную сеть для предсказания углов повернутых рукописных цифр. Эти предсказания полезны для оптического распознавания символов.

Опционально можно использовать imrotate (Image Processing Toolbox™), чтобы повернуть изображения, и boxplot (Statistics and Machine Learning Toolbox™), чтобы создать остаточный прямоугольный график.

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в степенях), на которые поворачивается каждое изображение.

Загрузите изображения для обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные выходы YTrain и YValidation являются углами поворота в степенях. Каждый набор обучающих и валидационных данных содержит 5000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных обучающих изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Проверяйте нормализацию данных

При обучении нейронных сетей часто помогает убедиться, что ваши данные нормализованы на всех этапах сети. Нормализация помогает стабилизировать и ускорить обучение сети с помощью градиентного спуска. Если ваши данные плохо масштабированы, то потеря может стать NaN и параметры сети могут различаться во время обучения. Общие способы нормализации данных включают в себя перемасштабирование данных так, чтобы их область значений стала [0,1] или чтобы он имел среднее значение нуля и стандартное отклонение единицы. Можно нормализовать следующие данные:

  • Входные данные. Нормализуйте предикторы перед тем, как вы вводите их в сеть. В этом примере входные изображения уже нормированы к области значений [0,1].

  • Выходы слоя. Можно нормализовать выходы каждого сверточного и полносвязного слоя с помощью слоя нормализации партии ..

  • Ответы. Если вы используете нормализацию партии . слои, чтобы нормализовать слой, выходы в конце сети, то предсказания сети нормализуются, когда начинается обучение. Если ответ имеет очень отличную шкалу от этих предсказаний, то сетевое обучение может не сходиться. Если ваш ответ плохо масштабирован, попробуйте нормализовать его и проверьте, улучшается ли сетевое обучение. Если вы нормализуете ответ перед обучением, то вы должны преобразовать предсказания обученной сети, чтобы получить предсказания исходного ответа.

Постройте график распределения отклика. Реакция (угол поворота в степенях) приблизительно равномерно распределена между -45 и 45, что хорошо работает, не нуждаясь в нормализации. В задачах классификации выходы являются вероятностями классов, которые всегда нормализуются.

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

В целом данные не должны быть точно нормированы. Однако, если вы обучаете сеть в этом примере прогнозировать 100*YTrain или YTrain+500 вместо YTrain, затем потеря становится NaN и параметры сети различаются, когда начинается обучение. Эти результаты происходят, хотя единственное различие между сетью, предсказывающей aY + b, и сетью, предсказывающей Y, является простым перемасштабированием весов и смещений конечного полносвязного слоя.

Если распределение входа или отклика очень неравномерно или искривлено, можно также выполнить нелинейные преобразования (для примера, взятие логарифмов) к данным перед обучением сети.

Создание слоев сети

Чтобы решить задачу регрессии, создайте слои сети и включите регрессионый слой в конец сети.

Первый слой определяет размер и тип входных данных. Изображения входа составляют 28 на 28 на 1. Создайте вход изображений того же размера, что и обучающие изображения.

Средние слои сети определяют основную архитектуру сети, где происходит большая часть расчетов и обучения.

Конечные слои определяют размер и тип выхода данных. Для регрессионных задач полностью соединенный слой должен предшествовать регрессионому слою в конце сети. Создайте полностью связанный выходной слой размера 1 и регрессионый слой.

Объедините все слои в Layer массив.

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

Обучите сеть

Создайте опции обучения. Обучайте на 30 эпох. Установите начальный темп обучения равный 0,001 и опустите скорость обучения после 20 эпох. Отслеживайте точность сети во время обучения путем определения данных валидации и частоты валидации. Программное обеспечение обучает сеть по обучающим данным и вычисляет точность по данным валидации через регулярные интервалы во время обучения. Данные валидации не используются для обновления весов сети. Включите график процесса обучения и отключите вывод командного окна.

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork использует центральный процессор.

net = trainNetwork(XTrain,YTrain,layers,options);

Исследуйте детали сетевой архитектуры, содержащиеся в Layers свойство net.

net.Layers
ans = 
  18×1 Layer array with layers:

     1   'imageinput'         Image Input           28×28×1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3×3×1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2×2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3×3×8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2×2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3×3×16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3×3×32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

Тестирование сети

Проверьте эффективность сети путем оценки точности данных валидации.

Использование predict для предсказания углов поворота изображений валидации.

YPredicted = predict(net,XValidation);

Оценка эффективности

Оцените эффективность модели путем вычисления:

  1. Процент предсказаний в пределах допустимого запаса по ошибке

  2. Среднеквадратичная ошибка (RMSE) предсказанного и фактического углов поворота

Вычислите ошибку предсказания между предсказанным и фактическим углами поворота.

predictionError = YValidation - YPredicted;

Вычислим количество предсказаний в пределах допустимого запаса по ошибке из истинных углов. Установите порог равным 10 степеням. Вычислите процент предсказаний в пределах этого порога.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9690

Используйте среднеквадратическую ошибку (RMSE), чтобы измерить различия между предсказанным и фактическим углами поворота.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.6062

Визуализация предсказаний

Визуализируйте предсказания на графике поля точек. Постройте график предсказанных значений относительно истинных значений.

figure
scatter(YPredicted,YValidation,'+')
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],'r--')

Правильное вращение цифр

Можно использовать функции из Image Processing Toolbox, чтобы выпрямить цифры и отобразить их вместе. Поверните 49 цифр выборки согласно их предсказанным углам поворота с помощью imrotate (Image Processing Toolbox).

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

Отображение исходных цифр с исправленными поворотами. Можно использовать montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

См. также

|

Похожие темы