Обучите сверточную нейронную сеть для регрессии

Этот пример показывает, как соответствовать модели регрессии использование сверточных нейронных сетей, чтобы предсказать углы вращения рукописных цифр.

Сверточные нейронные сети (CNNs или ConvNets) являются особыми инструментами для глубокого обучения и особенно подходят для анализа данных изображения. Например, можно использовать CNNs, чтобы классифицировать изображения. Чтобы предсказать текущие данные, такие как углы и расстояния, можно включать слой регрессии в конце сети.

Пример создает сверточную архитектуру нейронной сети, обучает сеть и использует обучивший сеть, чтобы предсказать углы вращаемых рукописных цифр. Эти прогнозы полезны для оптического распознавания символов.

Опционально, можно использовать imrotate (Image Processing Toolbox™), чтобы вращать изображения и boxplot (Statistics and Machine Learning Toolbox™), чтобы создать остаточную диаграмму.

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.

Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные параметры YTrain и YValidation являются углами поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных учебных изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow
end

Проверяйте нормализацию данных

При обучении нейронных сетей это часто помогает убедиться, что данные нормированы на всех этапах сети. Нормализация помогает стабилизироваться и ускорить сетевое обучение с помощью спуска градиента. Если ваши данные плохо масштабируются, то потеря может стать NaN, и сетевые параметры могут отличаться во время обучения. Распространенные способы нормировать данные включают перемасштабирование данных так, чтобы его область значений стала [0,1] или так, чтобы это имело среднее значение нулевого и стандартного отклонения одного. Можно нормировать следующие данные:

  • Входные данные. Нормируйте предикторы, прежде чем вы введете их к сети. В этом примере входные изображения уже нормированы к области значений [0,1].

  • Слой выходные параметры. Можно нормировать выходные параметры каждого сверточного и полносвязный слой при помощи пакетного слоя нормализации.

  • Ответы. Если вы используете пакетные слои нормализации, чтобы нормировать слой выходные параметры в конце сети, то прогнозы сети нормированы, когда обучение запускается. Если ответ имеет совсем другую шкалу из этих прогнозов, то сетевое обучение может не сходиться. Если ваш ответ плохо масштабируется, то попытайтесь нормировать его и смотрите, улучшается ли сетевое обучение. Если вы нормируете ответ перед обучением, то необходимо преобразовать прогнозы обучившего сеть, чтобы получить прогнозы исходного ответа.

Постройте распределение ответа. Ответ (угол поворота в градусах) приблизительно равномерно распределен между-45 и 45, который работает хорошо, не нуждаясь в нормализации. В проблемах классификации выходные параметры являются вероятностями класса, которые всегда нормируются.

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

В целом данные не должны быть точно нормированы. Однако, если вы обучаете сеть в этом примере, чтобы предсказать 100*YTrain или YTrain+500 вместо YTrain, затем потеря становится NaN, и сетевые параметры отличаются, когда обучение запускается. Эти результаты происходят даже при том, что единственной разницей между сетью, предсказывающей да + b и сетью, предсказывая Y, является простое перемасштабирование весов и смещения итогового полносвязного слоя.

Если распределение входа или ответ являются очень неравномерными или скошенными, можно также выполнить нелинейные преобразования (например, беря логарифмы) к данным прежде, чем обучить сеть.

Создайте сетевые слои

Чтобы решить проблему регрессии, создайте слои сети и включайте слой регрессии в конце сети.

Первый слой задает размер и тип входных данных. Входные изображения 28 28 1. Создайте входной слой изображений, одного размера как учебные изображения.

Средние слои сети define базовая архитектура сети, где большая часть вычисления и изучения происходят.

Последние слои задают размер и тип выходных данных. Для проблем регрессии полносвязный слой должен предшествовать слою регрессии в конце сети. Создайте полностью связанный выходной слой размера 1 и слой регрессии.

Объедините все слои вместе в массиве Layer.

layers = [
    imageInputLayer([28 28 1])

    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    averagePooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    averagePooling2dLayer(2,'Stride',2)
  
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

Обучение сети

Создайте сетевые опции обучения. Обучайтесь в течение 30 эпох. Установите начальную букву, изучают уровень 0,001 и понижают темп обучения после 20 эпох. Контролируйте сетевую точность во время обучения путем определения данных о валидации и частоты валидации. Программное обеспечение обучает сеть на данных тренировки и вычисляет точность на данные о валидации равномерно во время обучения. Данные о валидации не используются, чтобы обновить сетевые веса. Включите учебный график прогресса и выключите командное окно вывод.

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор при наличии. В противном случае trainNetwork использует центральный процессор. CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0, или выше требуется для обучения на графическом процессоре.

net = trainNetwork(XTrain,YTrain,layers,options);

Исследуйте детали сетевой архитектуры, содержавшейся в свойстве Layers net.

net.Layers
ans = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             Convolution           8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             Convolution           16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        Average Pooling       2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             Convolution           32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             Convolution           32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

Тестирование сети

Проверьте производительность сети путем оценки точности на данных о валидации.

Используйте predict, чтобы предсказать углы вращения изображений валидации.

YPredicted = predict(net,XValidation);

Оцените производительность

Оцените производительность модели путем вычисления:

  1. Процент прогнозов в приемлемом допуске на погрешность

  2. Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения

Вычислите ошибку прогноза между предсказанными и фактическими углами вращения.

predictionError = YValidation - YPredicted;

Вычислите количество прогнозов в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент прогнозов в этом пороге.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9684

Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.5336

Отобразите диаграмму невязок для каждого класса цифры

Функция boxplot требует матрицы, где каждый столбец соответствует невязкам для каждого класса цифры.

Группы данных о валидации отображают классами 0-9 цифры с 500 примерами каждого. Используйте reshape, чтобы сгруппировать невязки классом цифры.

residualMatrix = reshape(predictionError,500,10);

Каждый столбец residualMatrix соответствует невязкам каждой цифры. Создайте остаточную диаграмму для каждой цифры с помощью boxplot (Statistics and Machine Learning Toolbox).

figure
boxplot(residualMatrix,...
    'Labels',{'0','1','2','3','4','5','6','7','8','9'})
xlabel('Digit Class')
ylabel('Degrees Error')
title('Residuals')

Классы цифры с самой высокой точностью имеют среднее значение близко к нулю и небольшому отклонению.

Правильные вращения цифры

Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate (Image Processing Toolbox).

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

Отобразите исходные цифры с их исправленными вращениями. Можно использовать montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

Смотрите также

|

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте