В этом примере показано, как преобразовать обученную классификационную сеть в регрессионую сеть.
Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети узнали представления богатых функций для широкой области значений изображений. Сеть принимает изображение как вход, а затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.
Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. В этом примере показано, как взять предварительно обученную классификационную сеть и переобучить ее для задач регрессии.
Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и переобучает сеть, чтобы предсказать углы повернутых рукописных цифр. Опционально можно использовать imrotate
(Image Processing Toolbox™), чтобы исправить вращение изображения с помощью предсказанных значений.
Загрузите предварительно обученную сеть из вспомогательного файла digitsNet.mat
. Этот файл содержит сеть классификации, которая классифицирует рукописные цифры.
load digitsNet
layers = net.Layers
layers = 15×1 Layer array with layers: 1 'imageinput' Image Input 28×28×1 images with 'zerocenter' normalization 2 'conv_1' Convolution 8 3×3×1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'maxpool_1' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' Convolution 16 3×3×8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'maxpool_2' Max Pooling 2×2 max pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' Convolution 32 3×3×16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'fc' Fully Connected 10 fully connected layer 14 'softmax' Softmax softmax 15 'classoutput' Classification Output crossentropyex with '0' and 9 other classes
Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в степенях), на которые поворачивается каждое изображение.
Загрузите изображения для обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData
и digitTest4DArrayData
. Выходные выходы YTrain
и YValidation
являются углами поворота в степенях. Каждый набор обучающих и валидационных данных содержит 5000 изображений.
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
Отобразите 20 случайных обучающих изображений с помощью imshow
.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'fc'
и 'classoutput'
в digitsNet
, содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для регрессии, замените эти два слоя новыми слоями, адаптированными к задаче.
Замените конечный полносвязный слой, слой softmax и выходной слой классификации на полносвязный слой размера 1 (количество откликов) и регрессионный слой.
numResponses = 1; layers = [ layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];
Теперь сеть готова к переобучению на новых данных. Опционально можно «заморозить» веса более ранних слоев в сети, установив скорости обучения в этих слоях на нуль. Во время обучения, trainNetwork
не обновляет параметры замороженных слоев. Поскольку градиенты замороженных слоев не нужно вычислять, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных является маленьким, то замораживание ранее слоев сети может также предотвратить сверхподбор кривой этих слоев к новому набору данных.
Используйте вспомогательную функцию freezeWeights
установить скорости обучения на нуле в первых 12 слоях.
layers(1:12) = freezeWeights(layers(1:12));
Создайте опции обучения. Установите начальный темп обучения равным 0,001. Отслеживайте точность сети во время обучения путем определения данных валидации. Включите график процесса обучения и отключите вывод командного окна.
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001, ... 'ValidationData',{XValidation,YValidation},... 'Plots','training-progress',... 'Verbose',false);
Создайте сеть с помощью trainNetwork
. Эта команда использует совместимый графический процессор, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork
использует центральный процессор.
net = trainNetwork(XTrain,YTrain,layers,options);
Проверьте эффективность сети путем оценки точности данных валидации.
Использование predict
для предсказания углов поворота изображений валидации.
YPred = predict(net,XValidation);
Оцените эффективность модели путем вычисления:
Процент предсказаний в пределах допустимого запаса по ошибке
Среднеквадратичная ошибка (RMSE) предсказанного и фактического углов поворота
Вычислите ошибку предсказания между предсказанным и фактическим углами поворота.
predictionError = YValidation - YPred;
Вычислим количество предсказаний в пределах допустимого запаса по ошибке из истинных углов. Установите порог равным 10 степеням. Вычислите процент предсказаний в пределах этого порога.
thr = 10; numCorrect = sum(abs(predictionError) < thr); numImagesValidation = numel(YValidation); accuracy = numCorrect/numImagesValidation
accuracy = 0.7532
Используйте среднеквадратическую ошибку (RMSE), чтобы измерить различия между предсказанным и фактическим углами поворота.
rmse = sqrt(mean(predictionError.^2))
rmse = single
9.0270
Можно использовать функции из Image Processing Toolbox, чтобы выпрямить цифры и отобразить их вместе. Поверните 49 цифр выборки согласно их предсказанным углам поворота с помощью imrotate
(Image Processing Toolbox).
idx = randperm(numImagesValidation,49); for i = 1:numel(idx) I = XValidation(:,:,:,idx(i)); Y = YPred(idx(i)); XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop'); end
Отображение исходных цифр с исправленными поворотами. Использование montage
(Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(XValidationCorrected) title('Corrected')
classificationLayer
| regressionLayer