Преобразуйте сеть классификации в сеть регрессии

В этом примере показано, как преобразовать обученную сеть классификации в сеть регрессии.

Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа, и затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. В этом примере показано, как взять предварительно обученную сеть классификации и переобучить ее для задач регрессии.

Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и переобучает сеть, чтобы предсказать углы вращаемых рукописных цифр. Опционально, можно использовать imrotate (Image Processing Toolbox™), чтобы откорректировать повороты изображения с помощью ожидаемых значений.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть от вспомогательного файла digitsNet.mat. Этот файл содержит сеть классификации, которая классифицирует рукописные цифры.

load digitsNet
layers = net.Layers
layers = 
  15x1 Layer array with layers:

     1   'imageinput'    Image Input             28x28x1 images with 'zerocenter' normalization
     2   'conv_1'        Convolution             8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'   Batch Normalization     Batch normalization with 8 channels
     4   'relu_1'        ReLU                    ReLU
     5   'maxpool_1'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'        Convolution             16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'   Batch Normalization     Batch normalization with 16 channels
     8   'relu_2'        ReLU                    ReLU
     9   'maxpool_2'     Max Pooling             2x2 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'        Convolution             32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'   Batch Normalization     Batch normalization with 32 channels
    12   'relu_3'        ReLU                    ReLU
    13   'fc'            Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   crossentropyex with '0' and 9 other classes

Загрузка данных

Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.

Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData и digitTest4DArrayData. Выходные параметры YTrain и YValidation углы поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Отобразите 20 случайных учебных изображений с помощью imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

Замените последние слои

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'fc' и 'classoutput' в digitsNet, содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для регрессии, замените эти два слоя на новые слои, адаптированные к задаче.

Замените итоговый полносвязный слой, softmax слой и классификацию выходной слой с полносвязным слоем размера 1 (количество ответов) и слой регрессии.

numResponses = 1;
layers = [
    layers(1:12)
    fullyConnectedLayer(numResponses)
    regressionLayer];

Заморозьте начальные слои

Сеть теперь готова быть переобученной на новых данных. Опционально, можно "заморозить" веса более ранних слоев в сети, установив скорости обучения в тех слоях обнулить. Во время обучения, trainNetwork не обновляет параметры блокированных слоев. Поскольку градиенты блокированных слоев не должны быть вычислены, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных мал, то замораживание более ранних слоев сети может также препятствовать тому, чтобы те слои сверхсоответствовали к новому набору данных.

Используйте функцию поддержки freezeWeights обнулять скорости обучения в первых 12 слоях.

layers(1:12) = freezeWeights(layers(1:12));

Обучение сети

Создайте сетевые опции обучения. Установите начальную букву, изучают уровень 0,001. Контролируйте сетевую точность во время обучения путем определения данных о валидации. Включите график процесса обучения и выключите командное окно выход.

options = trainingOptions('sgdm',...
    'InitialLearnRate',0.001, ...
    'ValidationData',{XValidation,YValidation},...
    'Plots','training-progress',...
    'Verbose',false);

Создайте сеть с помощью trainNetwork. Эта команда использует совместимый графический процессор при наличии. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В противном случае, trainNetwork использует центральный процессор.

net = trainNetwork(XTrain,YTrain,layers,options);

Figure Training Progress (25-Aug-2021 07:32:47) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.

Тестирование сети

Проверьте производительность сети путем оценки точности на данных о валидации.

Используйте predict предсказать углы вращения изображений валидации.

YPred = predict(net,XValidation);

Оцените эффективность модели путем вычисления:

  1. Процент предсказаний в приемлемом допуске на погрешность

  2. Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения

Вычислите ошибку предсказания между предсказанными и фактическими углами вращения.

predictionError = YValidation - YPred;

Вычислите количество предсказаний в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент предсказаний в этом пороге.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numImagesValidation = numel(YValidation);

accuracy = numCorrect/numImagesValidation
accuracy = 0.7532

Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.

 rmse = sqrt(mean(predictionError.^2))
rmse = single
    9.0271

Правильные вращения цифры

Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate (Image Processing Toolbox).

idx = randperm(numImagesValidation,49);
for i = 1:numel(idx)
    I = XValidation(:,:,:,idx(i));
    Y = YPred(idx(i));  
    XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop');
end

Отобразите исходные цифры с их откорректированными вращениями. Используйте montage (Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(XValidationCorrected)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

Смотрите также

|

Похожие темы