exponenta event banner

Конволюционная нейронная сеть поезда для регрессии

В этом примере показано, как подогнать регрессионную модель с использованием сверточных нейронных сетей для прогнозирования углов поворота рукописных цифр.

Сверточные нейронные сети (CNN, или ConvNets) являются важными инструментами для глубокого обучения и особенно подходят для анализа данных изображения. Например, для классификации изображений можно использовать CNN. Для прогнозирования непрерывных данных, таких как углы и расстояния, можно включить уровень регрессии в конце сети.

Пример строит архитектуру сверточной нейронной сети, обучает сеть и использует обученную сеть для прогнозирования углов повернутых рукописных цифр. Эти прогнозы полезны для оптического распознавания символов.

Дополнительно можно использовать imrotate( Toolbox™ обработки изображений) для поворота изображений и boxplot (Статистика и Toolbox™ машинного обучения) для создания графика остаточного поля.

Загрузить данные

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), на которые поворачивается каждое изображение.

Загрузка образов обучения и проверки в виде массивов 4-D с помощью digitTrain4DArrayData и digitTest4DArrayData. Продукция YTrain и YValidation - углы поворота в градусах. Каждый набор данных обучения и проверки содержит 5000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отображение 20 случайных обучающих изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Проверка нормализации данных

При обучении нейронных сетей это часто помогает убедиться, что ваши данные нормализованы на всех этапах сети. Нормализация помогает стабилизировать и ускорить обучение сети с использованием градиентного спуска. Если ваши данные плохо масштабированы, то потеря может стать NaN и параметры сети могут расходиться во время обучения. Общие способы нормализации данных включают в себя масштабирование данных так, чтобы их диапазон стал равным [0,1] или чтобы они имели среднее значение нуля и стандартное отклонение единицы. Можно нормализовать следующие данные:

  • Входные данные. Нормализуйте предикторы перед вводом их в сеть. В этом примере входные изображения уже нормализованы к диапазону [0,1].

  • Выходы слоев. Можно нормализовать выходы каждого сверточного и полностью соединенного уровня с помощью уровня пакетной нормализации.

  • Ответы. При использовании уровней пакетной нормализации для нормализации выходов уровня в конце сети прогнозы сети нормализуются при запуске обучения. Если масштаб ответа очень отличается от этих прогнозов, то обучение сети может не сойтись. Если ваш ответ плохо масштабирован, попробуйте нормализовать его и проверьте, улучшилось ли обучение сети. Если вы нормализуете ответ перед тренировкой, то вы должны преобразовать прогнозы обученной сети, чтобы получить прогнозы исходного ответа.

Постройте график распределения ответа. Отклик (угол поворота в градусах) примерно равномерно распределён между -45 и 45, что хорошо работает без необходимости нормализации. В проблемах классификации выходами являются вероятности классов, которые всегда нормализуются.

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

В общем случае данные не обязательно должны быть точно нормализованы. Однако при обучении сети в этом примере прогнозированию 100*YTrain или YTrain+500 вместо YTrain, то потеря становится NaN и параметры сети расходятся при начале обучения. Эти результаты имеют место, даже если единственное различие между предсказанием сети aY + b и предсказанием сети Y заключается в простом масштабировании весов и смещений конечного полностью соединенного слоя.

Если распределение входных данных или откликов очень неравномерно или искажено, перед обучением сети можно также выполнить нелинейные преобразования (например, брать логарифмы) в данные.

Создание сетевых слоев

Чтобы решить проблему регрессии, создайте слои сети и включите уровень регрессии в конце сети.

Первый слой определяет размер и тип входных данных. Входные изображения составляют 28 на 28 на 1. Создайте слой ввода изображения того же размера, что и учебные изображения.

Средние уровни сети определяют основную архитектуру сети, где происходит большая часть вычислений и обучения.

Конечные слои определяют размер и тип выходных данных. Для проблем регрессии полностью подключенный уровень должен предшествовать уровню регрессии в конце сети. Создайте полностью подключенный выходной слой размера 1 и регрессионный слой.

Объединение всех слоев в Layer массив.

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

Железнодорожная сеть

Создайте параметры сетевого обучения. Поезд на 30 эпох. Установите начальную скорость обучения на 0,001 и понизите скорость обучения после 20 периодов. Контроль точности сети во время обучения путем указания данных проверки и частоты проверки. Программное обеспечение обучает сеть данным обучения и вычисляет точность данных проверки через регулярные интервалы во время обучения. Данные проверки не используются для обновления весов сети. Включите график хода обучения и выключите вывод в командное окно.

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

Создание сети с помощью trainNetwork. Эта команда использует совместимый графический процессор, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). В противном случае trainNetwork использует ЦП.

net = trainNetwork(XTrain,YTrain,layers,options);

Проверьте детали сетевой архитектуры, содержащейся в Layers имущество net.

net.Layers
ans = 
  18×1 Layer array with layers:

     1   'imageinput'         Image Input           28×28×1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3×3×1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2×2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3×3×8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2×2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3×3×16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3×3×32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

Тестовая сеть

Проверка производительности сети путем оценки точности данных проверки.

Использовать predict для прогнозирования углов поворота изображений проверки.

YPredicted = predict(net,XValidation);

Оценка производительности

Оцените производительность модели, рассчитав:

  1. Процент прогнозов в пределах допустимого запаса ошибок

  2. Среднеквадратическая погрешность (RMSE) прогнозируемого и фактического углов поворота

Вычислите ошибку прогнозирования между прогнозируемым и фактическим углами поворота.

predictionError = YValidation - YPredicted;

Вычислите количество прогнозов в пределах допустимого запаса погрешности от истинных углов. Установите порог равным 10 градусам. Вычислите процент прогнозов в пределах этого порога.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9690

Используйте среднеквадратическую ошибку (RMSE) для измерения разности между прогнозируемым и фактическим углами поворота.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.6062

Визуализация прогнозов

Визуализация прогнозов на графике рассеяния. Постройте график прогнозируемых значений относительно истинных значений.

figure
scatter(YPredicted,YValidation,'+')
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],'r--')

Правильное вращение цифр

Для выпрямления цифр и их отображения можно использовать функции панели инструментов обработки изображений. Поверните 49 цифр образца в соответствии с прогнозируемыми углами поворота с помощью imrotate(Панель инструментов обработки изображений).

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

Отображение исходных цифр с исправленными поворотами. Вы можете использовать montage(Панель инструментов обработки изображений) для отображения цифр в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

См. также

|

Связанные темы