exponenta event banner

Преобразовать классификационную сеть в регрессионную сеть

В этом примере показано, как преобразовать обученную классификационную сеть в регрессионную.

Предварительно подготовленные сети классификации изображений были обучены более чем миллиону изображений и могут классифицировать изображения на 1000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функций для широкого спектра изображений. Сеть принимает изображение в качестве входного, а затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Transfer learning обычно используется в приложениях для глубокого обучения. Предварительно подготовленную сеть можно использовать в качестве отправной точки для изучения новой задачи. В этом примере показано, как взять предварительно подготовленную классификационную сеть и переобучить ее для выполнения регрессионных задач.

Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и перепрофилирует сеть для прогнозирования углов повернутых рукописных цифр. Дополнительно можно использовать imrotate( Toolbox™ обработки изображения) для коррекции вращений изображения с использованием прогнозируемых значений.

Загрузить предварительно обученную сеть

Загрузка предварительно обученной сети из поддерживающего файла digitsNet.mat. Этот файл содержит классификационную сеть, которая классифицирует рукописные цифры.

load digitsNet
layers = net.Layers
layers = 
  15×1 Layer array with layers:

     1   'imageinput'    Image Input             28×28×1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3×3×1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2×2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3×3×8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2×2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3×3×16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

Загрузить данные

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), на которые поворачивается каждое изображение.

Загрузка образов обучения и проверки в виде массивов 4-D с помощью digitTrain4DArrayData и digitTest4DArrayData. Продукция YTrain и YValidation - углы поворота в градусах. Каждый набор данных обучения и проверки содержит 5000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отображение 20 случайных обучающих изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Заменить конечные слои

Сверточные уровни сетевого извлеченного изображения отличаются тем, что последний обучаемый уровень и конечный уровень классификации используют для классификации входного изображения. Эти два слоя, 'fc' и 'classoutput' в digitsNet, содержат информацию о том, как объединить функции, извлекаемые сетью, в вероятности классов, значение потерь и прогнозируемые метки. Для переподготовки заранее обученной сети для регрессии замените эти два слоя на новые, адаптированные к задаче.

Замените конечный полностью подключенный уровень, уровень softmax и уровень классификационного выхода на полностью подключенный уровень размера 1 (количество откликов) и уровень регрессии.

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

Заморозить начальные слои

Сеть теперь готова переквалифицироваться на новые данные. Дополнительно можно «заморозить» веса более ранних слоев в сети, установив нулевые показатели обучения в этих слоях. Во время обучения, trainNetwork не обновляет параметры замороженных слоев. Поскольку градиенты замороженных слоев не нуждаются в вычислении, замораживание весов многих начальных слоев может значительно ускорить тренировку сети. Если новый набор данных мал, то замораживание более ранних сетевых уровней также может предотвратить перенастройку этих уровней в новый набор данных.

Использовать вспомогательную функцию freezeWeights установка нулевых показателей обучения на первых 12 уровнях.

layers(1:12) = freezeWeights(layers(1:12));

Железнодорожная сеть

Создайте параметры сетевого обучения. Установите начальную скорость обучения на 0,001. Контроль точности сети во время обучения путем указания данных проверки. Включите график хода обучения и выключите вывод в командное окно.

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

Создание сети с помощью trainNetwork. Эта команда использует совместимый графический процессор, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). В противном случае trainNetwork использует ЦП.

net = trainNetwork(XTrain,YTrain,layers,options);

Тестовая сеть

Проверка производительности сети путем оценки точности данных проверки.

Использовать predict для прогнозирования углов поворота изображений проверки.

YPred = predict(net,XValidation);

Оцените производительность модели, рассчитав:

  1. Процент прогнозов в пределах допустимого запаса ошибок

  2. Среднеквадратическая погрешность (RMSE) прогнозируемого и фактического углов поворота

Вычислите ошибку прогнозирования между прогнозируемым и фактическим углами поворота.

predictionError = YValidation - YPred;

Вычислите количество прогнозов в пределах допустимого запаса погрешности от истинных углов. Установите порог равным 10 градусам. Вычислите процент прогнозов в пределах этого порога.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

Используйте среднеквадратическую ошибку (RMSE) для измерения разности между прогнозируемым и фактическим углами поворота.

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0270

Правильное вращение цифр

Для выпрямления цифр и их отображения можно использовать функции панели инструментов обработки изображений. Поверните 49 цифр образца в соответствии с прогнозируемыми углами поворота с помощью imrotate(Панель инструментов обработки изображений).

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

Отображение исходных цифр с исправленными поворотами. Использовать montage(Панель инструментов обработки изображений) для отображения цифр в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

См. также

|

Связанные темы