Извлеките функции изображений Используя предварительно обученную сеть

В этом примере показано, как извлечь изученные функции изображений из предварительно обученной сверточной нейронной сети и использовать те функции, чтобы обучить классификатор изображений. Извлечение признаков является самым легким и самым быстрым способом использовать представительную степень предварительно обученных глубоких сетей. Например, можно обучить машину опорных векторов (SVM) с помощью fitcecoc (Statistics and Machine Learning Toolbox™) на извлеченных функциях. Поскольку извлечение признаков только требует одного прохода через данные, это - хорошая начальная точка, если у вас нет графического процессора, чтобы ускорить сетевое обучение с.

Загрузка данных

Разархивируйте и загрузите демонстрационные изображения как datastore изображений. imageDatastore автоматически помечает изображения на основе имен папок и хранит данные как ImageDatastore объект. Datastore изображений позволяет вам сохранить большие данные изображения, включая данные, которые не умещаются в памяти. Разделите данные в 70%-е обучение и 30% тестовых данных.

unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,0.7,'randomized');

Существует теперь 55 учебных изображений и 20 изображений валидации в этом очень небольшом наборе данных. Отобразите некоторые демонстрационные изображения.

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

Figure contains 16 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть ResNet-18. Если Модель Deep Learning Toolbox для пакета Сетевой поддержки ResNet-18 не установлена, то программное обеспечение обеспечивает ссылку на загрузку. ResNet-18 обучен больше чем на миллионе изображений и может классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, мышь, карандаш и многие животные. В результате модель изучила богатые представления функции для широкого спектра изображений.

net = resnet18
net = 
  DAGNetwork with properties:

         Layers: [71x1 nnet.cnn.layer.Layer]
    Connections: [78x2 table]
     InputNames: {'data'}
    OutputNames: {'ClassificationLayer_predictions'}

Анализируйте сетевую архитектуру. Первый слой, изображение ввело слой, требует входных изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize;
analyzeNetwork(net)

Извлеките функции изображений

Сеть требует входных изображений размера 224 224 3, но изображения в хранилищах данных изображений имеют различные размеры. Чтобы автоматически изменить размер обучения и тестовых изображений, прежде чем они будут введены к сети, создайте увеличенные хранилища данных изображений, задайте желаемый размер изображения и используйте эти хранилища данных в качестве входных параметров к activations.

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

Сеть создает иерархическое представление входных изображений. Более глубокие слои содержат высокоуровневые функции, созданное использование функций низшего уровня более ранних слоев. Чтобы получить представления функции обучения и тестовых изображений, используйте activations на глобальном слое объединения, 'pool5', в конце сети. Глобальный слой объединения объединяет входные функции по всем пространственным местоположениям, давая 512 функций всего.

layer = 'pool5';
featuresTrain = activations(net,augimdsTrain,layer,'OutputAs','rows');
featuresTest = activations(net,augimdsTest,layer,'OutputAs','rows');

whos featuresTrain
  Name                Size              Bytes  Class     Attributes

  featuresTrain      55x512            112640  single              

Извлеките метки класса из обучения и тестовых данных.

YTrain = imdsTrain.Labels;
YTest = imdsTest.Labels;

Подходящий классификатор изображений

Используйте функции, извлеченные из учебных изображений как переменные предикторы, и соответствуйте машине опорных векторов (SVM) мультикласса с помощью fitcecoc (Statistics and Machine Learning Toolbox).

classifier = fitcecoc(featuresTrain,YTrain);

Классифицируйте тестовые изображения

Классифицируйте тестовые изображения с помощью обученной модели SVM, использующей функции, извлеченные из тестовых изображений.

YPred = predict(classifier,featuresTest);

Отобразите четыре демонстрационных тестовых изображения с их предсказанными метками.

idx = [1 5 10 15];
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    I = readimage(imdsTest,idx(i));
    label = YPred(idx(i));
    imshow(I)
    title(char(label))
end

Figure contains 4 axes objects. Axes object 1 with title MathWorks Cap contains an object of type image. Axes object 2 with title MathWorks Cube contains an object of type image. Axes object 3 with title MathWorks Playing Cards contains an object of type image. Axes object 4 with title MathWorks Screwdriver contains an object of type image.

Вычислите точность классификации на набор тестов. Точность является частью меток, которые сеть предсказывает правильно.

accuracy = mean(YPred == YTest)
accuracy = 1

Обучите классификатор на более мелких функциях

Можно также извлечь функции из более раннего слоя в сети и обучить классификатор на тех функциях. Более ранние слои обычно извлекают меньше, более мелкие функции, имеют более высокое пространственное разрешение и большее общее количество активаций. Извлеките функции из 'res3b_relu' слой. Это - последний слой, из которого выходные параметры 128 функций и активации имеют пространственный размер 28 28.

layer = 'res3b_relu';
featuresTrain = activations(net,augimdsTrain,layer);
featuresTest = activations(net,augimdsTest,layer);
whos featuresTrain
  Name                Size                      Bytes  Class     Attributes

  featuresTrain      28x28x128x55            22077440  single              

Извлеченные функции, использованные в первой части этого примера, были объединены по всем пространственным местоположениям глобальным слоем объединения. Чтобы достигнуть того же результата при извлечении функций в более ранних слоях, вручную насчитайте активации по всем пространственным местоположениям. Получить функции на форме N-by-C, где N является количеством наблюдений, и C является количеством функций, удалите одноэлементные размерности и транспонируйте.

featuresTrain = squeeze(mean(featuresTrain,[1 2]))';
featuresTest = squeeze(mean(featuresTest,[1 2]))';
whos featuresTrain
  Name                Size             Bytes  Class     Attributes

  featuresTrain      55x128            28160  single              

Обучите классификатор SVM на более мелких функциях. Вычислите тестовую точность.

classifier = fitcecoc(featuresTrain,YTrain);
YPred = predict(classifier,featuresTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9500

Оба обучались, SVMs имеют высокую точность. Если точность высоко достаточно не использует извлечение признаков, то попробуйте передачу обучения вместо этого. Для примера смотрите, Обучают Нейронную сеть для глубокого обучения Классифицировать Новые Изображения. Для списка и сравнения предварительно обученных сетей, смотрите Предварительно обученные Глубокие нейронные сети.

Смотрите также

(Statistics and Machine Learning Toolbox) |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте