В этом примере показано, как преобразовать обученную сеть классификации в сеть регрессии.
Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа, и затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.
Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. В этом примере показано, как взять предварительно обученную сеть классификации и переобучить ее для задач регрессии.
Пример загружает предварительно обученную архитектуру сверточной нейронной сети для классификации, заменяет слои для классификации и переобучает сеть, чтобы предсказать углы вращаемых рукописных цифр. Опционально, можно использовать imrotate
(Image Processing Toolbox™), чтобы откорректировать повороты изображения с помощью ожидаемых значений.
Загрузите предварительно обученную сеть от вспомогательного файла digitsNet.mat
. Этот файл содержит сеть классификации, которая классифицирует рукописные цифры.
load digitsNet
layers = net.Layers
layers = 15x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'maxpool_1' Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'maxpool_2' Max Pooling 2x2 max pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'fc' Fully Connected 10 fully connected layer 14 'softmax' Softmax softmax 15 'classoutput' Classification Output crossentropyex with '0' and 9 other classes
Набор данных содержит синтетические изображения рукописных цифр вместе с соответствующими углами (в градусах), которыми вращается каждое изображение.
Загрузите изображения обучения и валидации как 4-D массивы с помощью digitTrain4DArrayData
и digitTest4DArrayData
. Выходные параметры YTrain
и YValidation
углы поворота в градусах. Наборы данных обучения и валидации каждый содержит 5 000 изображений.
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
Отобразите 20 случайных учебных изображений с помощью imshow
.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'fc'
и 'classoutput'
в digitsNet
, содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для регрессии, замените эти два слоя на новые слои, адаптированные к задаче.
Замените итоговый полносвязный слой, softmax слой и классификацию выходной слой с полносвязным слоем размера 1 (количество ответов) и слой регрессии.
numResponses = 1; layers = [ layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];
Сеть теперь готова быть переобученной на новых данных. Опционально, можно "заморозить" веса более ранних слоев в сети, установив скорости обучения в тех слоях обнулить. Во время обучения, trainNetwork
не обновляет параметры блокированных слоев. Поскольку градиенты блокированных слоев не должны быть вычислены, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных мал, то замораживание более ранних слоев сети может также препятствовать тому, чтобы те слои сверхсоответствовали к новому набору данных.
Используйте функцию поддержки freezeWeights
обнулять скорости обучения в первых 12 слоях.
layers(1:12) = freezeWeights(layers(1:12));
Создайте сетевые опции обучения. Установите начальную букву, изучают уровень 0,001. Контролируйте сетевую точность во время обучения путем определения данных о валидации. Включите график процесса обучения и выключите командное окно выход.
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001, ... 'ValidationData',{XValidation,YValidation},... 'Plots','training-progress',... 'Verbose',false);
Создайте сеть с помощью trainNetwork
. Эта команда использует совместимый графический процессор при наличии. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В противном случае, trainNetwork
использует центральный процессор.
net = trainNetwork(XTrain,YTrain,layers,options);
Проверьте производительность сети путем оценки точности на данных о валидации.
Используйте predict
предсказать углы вращения изображений валидации.
YPred = predict(net,XValidation);
Оцените эффективность модели путем вычисления:
Процент предсказаний в приемлемом допуске на погрешность
Среднеквадратичная ошибка (RMSE) предсказанных и фактических углов вращения
Вычислите ошибку предсказания между предсказанными и фактическими углами вращения.
predictionError = YValidation - YPred;
Вычислите количество предсказаний в приемлемом допуске на погрешность от истинных углов. Установите порог, чтобы быть 10 градусами. Вычислите процент предсказаний в этом пороге.
thr = 10; numCorrect = sum(abs(predictionError) < thr); numImagesValidation = numel(YValidation); accuracy = numCorrect/numImagesValidation
accuracy = 0.7532
Используйте среднеквадратичную ошибку (RMSE), чтобы измерить различия между предсказанными и фактическими углами вращения.
rmse = sqrt(mean(predictionError.^2))
rmse = single
9.0271
Можно использовать функции из Image Processing Toolbox, чтобы выправить цифры и отобразить их вместе. Вращайте 49 демонстрационных цифр согласно их предсказанным углам вращения с помощью imrotate
(Image Processing Toolbox).
idx = randperm(numImagesValidation,49); for i = 1:numel(idx) I = XValidation(:,:,:,idx(i)); Y = YPred(idx(i)); XValidationCorrected(:,:,:,i) = imrotate(I,Y,'bicubic','crop'); end
Отобразите исходные цифры с их откорректированными вращениями. Используйте montage
(Image Processing Toolbox), чтобы отобразить цифры вместе в одном изображении.
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(XValidationCorrected) title('Corrected')
regressionLayer
| classificationLayer