exponenta event banner

Извлечение функций изображения с помощью предварительно обученной сети

В этом примере показано, как извлечь изученные особенности изображения из предварительно обученной сверточной нейронной сети и использовать их для обучения классификатора изображений. Извлечение характеристик - самый простой и быстрый способ использования репрезентативной силы предварительно подготовленных глубоких сетей. Например, можно обучить вспомогательную векторную машину (SVM) с помощью fitcecoc (Статистика и Toolbox™ машинного обучения) по извлеченным функциям. Поскольку извлечение функций требует только одного прохода через данные, это хорошая отправная точка, если у вас нет графического процессора для ускорения обучения сети с помощью.

Загрузить данные

Распакуйте и загрузите образцы изображений как хранилище данных изображений. imageDatastore автоматически помечает изображения на основе имен папок и сохраняет данные в виде ImageDatastore объект. Хранилище данных изображения позволяет хранить большие данные изображения, включая данные, которые не помещаются в память. Разбейте данные на 70% обучающих и 30% тестовых данных.

unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,0.7,'randomized');

В настоящее время в этом очень небольшом наборе данных имеется 55 обучающих изображений и 20 подтверждающих изображений. Отображение некоторых образцов изображений.

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

Figure contains 16 axes. Axes 1 contains an object of type image. Axes 2 contains an object of type image. Axes 3 contains an object of type image. Axes 4 contains an object of type image. Axes 5 contains an object of type image. Axes 6 contains an object of type image. Axes 7 contains an object of type image. Axes 8 contains an object of type image. Axes 9 contains an object of type image. Axes 10 contains an object of type image. Axes 11 contains an object of type image. Axes 12 contains an object of type image. Axes 13 contains an object of type image. Axes 14 contains an object of type image. Axes 15 contains an object of type image. Axes 16 contains an object of type image.

Загрузить предварительно обученную сеть

Загрузите предварительно подготовленную сеть ResNet-18. Если пакет поддержки Deep Learning Toolbox Model для ResNet-18 Network не установлен, программа предоставляет ссылку для загрузки. ResNet-18 обучается на более чем миллионе изображений и может классифицировать изображения на 1000 категорий объектов, таких как клавиатура, мышь, карандаш и многие животные. В результате модель получила богатые представления элементов для широкого спектра изображений.

net = resnet18
net = 
  DAGNetwork with properties:

         Layers: [71x1 nnet.cnn.layer.Layer]
    Connections: [78x2 table]
     InputNames: {'data'}
    OutputNames: {'ClassificationLayer_predictions'}

Проанализируйте архитектуру сети. Первый слой, изображение ввело слой, требует входных изображений размера 224 на 224 на 3, где 3 количество цветных каналов.

inputSize = net.Layers(1).InputSize;
analyzeNetwork(net)

Извлечь элементы изображения

Сеть требует входных изображений размером 224-на-224-на-3, но изображения в хранилищах данных изображений имеют разные размеры. Чтобы автоматически изменить размер обучающих и тестовых изображений перед их вводом в сеть, создайте хранилища данных дополненных изображений, укажите требуемый размер изображения и используйте эти хранилища данных в качестве входных аргументов для activations.

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

Сеть создает иерархическое представление входных изображений. Более глубокие слои содержат элементы более высокого уровня, построенные с использованием элементов более низкого уровня более ранних слоев. Для получения функциональных представлений обучающих и тестовых изображений используйте activations на уровне глобального объединения, 'pool5', в конце сети. Уровень глобального объединения объединяет входные элементы во всех пространственных местоположениях, давая в общей сложности 512 элементов.

layer = 'pool5';
featuresTrain = activations(net,augimdsTrain,layer,'OutputAs','rows');
featuresTest = activations(net,augimdsTest,layer,'OutputAs','rows');

whos featuresTrain
  Name                Size              Bytes  Class     Attributes

  featuresTrain      55x512            112640  single              

Извлеките метки класса из данных обучения и тестирования.

YTrain = imdsTrain.Labels;
YTest = imdsTest.Labels;

Классификатор изображения подгонки

Используйте функции, извлеченные из обучающих изображений, в качестве переменных предиктора и поместите многоклассную машину векторов поддержки (SVM), используя fitcecoc (Статистика и инструментарий машинного обучения).

classifier = fitcecoc(featuresTrain,YTrain);

Классификация тестовых изображений

Классифицируйте тестовые изображения с использованием обученной модели SVM, используя функции, извлеченные из тестовых изображений.

YPred = predict(classifier,featuresTest);

Отображение четырех образцов тестовых изображений с их прогнозируемыми метками.

idx = [1 5 10 15];
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    I = readimage(imdsTest,idx(i));
    label = YPred(idx(i));
    imshow(I)
    title(char(label))
end

Figure contains 4 axes. Axes 1 with title MathWorks Cap contains an object of type image. Axes 2 with title MathWorks Cube contains an object of type image. Axes 3 with title MathWorks Playing Cards contains an object of type image. Axes 4 with title MathWorks Screwdriver contains an object of type image.

Рассчитайте точность классификации на тестовом наборе. Точность - это доля меток, которую сеть предсказывает правильно.

accuracy = mean(YPred == YTest)
accuracy = 1

Классификатор поездов по мелким характеристикам

Можно также извлечь элементы из более раннего уровня сети и обучить классификатор этим элементам. Более ранние слои обычно извлекают меньшее количество более мелких элементов, имеют более высокое пространственное разрешение и большее общее число активаций. Извлеките элементы из 'res3b_relu' слой. Это конечный уровень, который выводит 128 элементы, и активизации имеют пространственный размер 28 на 28.

layer = 'res3b_relu';
featuresTrain = activations(net,augimdsTrain,layer);
featuresTest = activations(net,augimdsTest,layer);
whos featuresTrain
  Name                Size                      Bytes  Class     Attributes

  featuresTrain      28x28x128x55            22077440  single              

Извлеченные элементы, использованные в первой части этого примера, были объединены во все пространственные местоположения на уровне глобального объединения. Чтобы достичь того же результата при извлечении элементов на более ранних слоях, вручную усредните активации по всем пространственным расположениям. Чтобы получить функции в форме N-by-C, где N - количество наблюдений, а C - количество признаков, удалите одиночные размеры и транспонируйте.

featuresTrain = squeeze(mean(featuresTrain,[1 2]))';
featuresTest = squeeze(mean(featuresTest,[1 2]))';
whos featuresTrain
  Name                Size             Bytes  Class     Attributes

  featuresTrain      55x128            28160  single              

Обучение классификатора SVM по более мелким функциям. Рассчитайте точность теста.

classifier = fitcecoc(featuresTrain,YTrain);
YPred = predict(classifier,featuresTest);
accuracy = mean(YPred == YTest)
accuracy = 0.9500

Обе обученные SVM имеют высокую точность. Если точность недостаточно высока с помощью извлечения элементов, попробуйте перенести обучение. Пример см. в разделе Обучение сети глубокого обучения классификации новых изображений. Список и сравнение предварительно обученных сетей см. в разделе Предварительно обученные глубокие нейронные сети.

См. также

| (инструментарий для статистики и машинного обучения)

Связанные темы