Преобразуйте сеть классификации в сеть регрессии

В этом примере показано, как преобразовать обученную сеть классификации в сеть регрессии.

Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа, и затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Изучение передачи обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. В этом примере показано, как взять предварительно обученную сеть классификации и переобучить ее для задач регрессии.

Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и переобучает сеть, чтобы предсказать углы вращаемых рукописных цифр. Опционально, можно использовать imrotate (Image Processing Toolbox™), чтобы откорректировать повороты изображения с помощью ожидаемых значений.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть от вспомогательного файла digitsNet.mat. Этот файл содержит сеть классификации, которая классифицирует рукописные цифры.

load digitsNet
layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.

Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные параметры YTrain и YValidation углы поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных учебных изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow
end

Замените последние слои

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'fc' и 'classoutput' в digitsNet, содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для регрессии, замените эти два слоя на новые слои, адаптированные к задаче.

Замените итоговый полносвязный слой, softmax слой и классификацию выходной слой с полносвязным слоем размера 1 (количество ответов) и слой регрессии.

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

Заморозьте начальные слои

Сеть теперь готова быть переобученной на новых данных. Опционально, можно "заморозить" веса более ранних слоев в сети, установив темпы обучения в тех слоях обнулить. Во время обучения, trainNetwork не обновляет параметры блокированных слоев. Поскольку градиенты блокированных слоев не должны быть вычислены, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных мал, то замораживание более ранних сетевых слоев может также препятствовать тому, чтобы те слои сверхсоответствовали к новому набору данных.

Используйте функцию поддержки freezeWeights обнулять темпы обучения в первых 12 слоях.

layers(1:12) = freezeWeights(layers(1:12));

Обучение сети

Создайте сетевые опции обучения. Установите начальную букву, изучают уровень 0,001. Контролируйте сетевую точность во время обучения путем определения данных о валидации. Включите график процесса обучения и выключите командное окно выход.

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор при наличии. В противном случае, trainNetwork использует центральный процессор. Графический процессор CUDA®-enabled NVIDIA® с вычисляет возможность 3.0, или выше требуется для обучения на графическом процессоре.

net = trainNetwork(XTrain,YTrain,layers,options);

Тестирование сети

Проверьте производительность сети путем оценки точности на данных о валидации.

Используйте predict предсказать углы вращения изображений валидации.

YPred = predict(net,XValidation);

Оцените производительность модели путем вычисления:

  1. Процент прогнозов в приемлемом допуске на погрешность

  2. Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения

Вычислите ошибку прогноза между предсказанными и фактическими углами вращения.

predictionError = YValidation - YPred;

Вычислите количество прогнозов в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент прогнозов в этом пороге.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7716

Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    8.7786

Правильные вращения цифры

Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate (Image Processing Toolbox).

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

Отобразите исходные цифры с их откорректированными вращениями. Используйте montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

Смотрите также

|

Похожие темы