Преобразуйте классификационную сеть в регрессионую

В этом примере показано, как преобразовать обученную классификационную сеть в регрессионую сеть.

Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети узнали представления богатых функций для широкой области значений изображений. Сеть принимает изображение как вход, а затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. В этом примере показано, как взять предварительно обученную классификационную сеть и переобучить ее для задач регрессии.

Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и переобучает сеть, чтобы предсказать углы повернутых рукописных цифр. Опционально можно использовать imrotate (Image Processing Toolbox™), чтобы исправить вращение изображения с помощью предсказанных значений.

Загрузка предварительно обученной сети

Загрузите предварительно обученную сеть из вспомогательного файла digitsNet.mat. Этот файл содержит сеть классификации, которая классифицирует рукописные цифры.

load digitsNet
layers = net.Layers
layers = 
  15×1 Layer array with layers:

     1   'imageinput'    Image Input             28×28×1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3×3×1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2×2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3×3×8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2×2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3×3×16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в степенях), на которые поворачивается каждое изображение.

Загрузите изображения для обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные выходы YTrain и YValidation являются углами поворота в степенях. Каждый набор обучающих и валидационных данных содержит 5000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных обучающих изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Замена конечных слоев

Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'fc' и 'classoutput' в digitsNet, содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для регрессии, замените эти два слоя новыми слоями, адаптированными к задаче.

Замените конечный полносвязный слой, слой softmax и выходной слой классификации на полносвязный слой размера 1 (количество откликов) и регрессионный слой.

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

Замораживание начальных слоев

Теперь сеть готова к переобучению на новых данных. Опционально можно «заморозить» веса более ранних слоев в сети, установив скорости обучения в этих слоях на нуль. Во время обучения, trainNetwork не обновляет параметры замороженных слоев. Поскольку градиенты замороженных слоев не нужно вычислять, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных является маленьким, то замораживание ранее слоев сети может также предотвратить сверхподбор кривой этих слоев к новому набору данных.

Используйте вспомогательную функцию freezeWeights установить скорости обучения на нуле в первых 12 слоях.

layers(1:12) = freezeWeights(layers(1:12));

Обучите сеть

Создайте опции обучения. Установите начальный темп обучения равным 0,001. Отслеживайте точность сети во время обучения путем определения данных валидации. Включите график процесса обучения и отключите вывод командного окна.

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork использует центральный процессор.

net = trainNetwork(XTrain,YTrain,layers,options);

Тестирование сети

Проверьте эффективность сети путем оценки точности данных валидации.

Использование predict для предсказания углов поворота изображений валидации.

YPred = predict(net,XValidation);

Оцените эффективность модели путем вычисления:

  1. Процент предсказаний в пределах допустимого запаса по ошибке

  2. Среднеквадратичная ошибка (RMSE) предсказанного и фактического углов поворота

Вычислите ошибку предсказания между предсказанным и фактическим углами поворота.

predictionError = YValidation - YPred;

Вычислим количество предсказаний в пределах допустимого запаса по ошибке из истинных углов. Установите порог равным 10 степеням. Вычислите процент предсказаний в пределах этого порога.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

Используйте среднеквадратическую ошибку (RMSE), чтобы измерить различия между предсказанным и фактическим углами поворота.

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0270

Правильное вращение цифр

Можно использовать функции из Image Processing Toolbox, чтобы выпрямить цифры и отобразить их вместе. Поверните 49 цифр выборки согласно их предсказанным углам поворота с помощью imrotate (Image Processing Toolbox).

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

Отображение исходных цифр с исправленными поворотами. Использование montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

См. также

|

Похожие темы