Визуализация классификаций изображений с помощью максимальных и минимальных активирующих изображений

В этом примере показано, как использовать набор данных, чтобы выяснить, что активирует каналы глубокой нейронной сети. Это позволяет понять, как работает нейронная сеть, и диагностировать потенциальные проблемы с обучающими данными набором.

Этот пример охватывает ряд простых методов визуализации, используя передачу GoogLeNet, полученную на наборе данных о пищевых продуктах. Глядя на изображения, которые максимально или минимально активируют классификатор, можно узнать, почему нейронная сеть получает неправильные классификации.

Загрузка и предварительная обработка данных

Загрузите изображения как image datastore. Этот небольшой набор данных содержит в общей сложности 978 наблюдений с 9 классами пищи.

Разделите эти данные на наборы для обучения, валидации и тестирования, чтобы подготовиться к передаче обучения с помощью GoogLeNet. Отобразите выбор изображений из набора данных.

rng default
dataDir = fullfile(tempdir,"Food Dataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir,"dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.
imds = imageDatastore(dataDir, ...
    "IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2);

rnd = randperm(numel(imds.Files),9);
for i = 1:numel(rnd)
subplot(3,3,i)
imshow(imread(imds.Files{rnd(i)}))
label = imds.Labels(rnd(i));
title(label,"Interpreter","none")
end

Обучите сеть классифицировать изображения продуктов питания

Используйте предварительно обученную сеть GoogLeNet и обучите ее снова, чтобы классифицировать 9 видов пищи. Если у вас нет установленного пакета поддержки Deep Learning Toolbox™ Model для GoogLeNet Network, то программное обеспечение предоставляет ссылку загрузки.

Чтобы попробовать другую предварительно обученную сеть, откройте этот пример в MATLAB ® и выберите другую сеть, такую как squeezenet, сеть, которая даже быстрее, чем googlenet. Список всех доступных сетей см. в разделе «Предварительно обученные глубокие нейронные сети».

net = googlenet;

Первый элемент Layers свойство сети является входным слоем изображения. Этот слой требует входа изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize;

Сетевая архитектура

Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы обучить предварительно обученную сеть классификации новых изображений, замените эти два слоя новыми слоями, адаптированными к новому набору данных.

Извлеките график слоев из обученной сети.

lgraph = layerGraph(net);

В большинстве сетей последний слой с усвояемыми весами является полносвязным слоем. Замените этот полностью соединенный слой новым полностью соединенным слоем с количеством выходов, равным количеству классов в новом наборе данных (9, в этом примере).

numClasses = numel(categories(imdsTrain.Labels));

newfclayer = fullyConnectedLayer(numClasses,...
    'Name','new_fc',...
    'WeightLearnRateFactor',10,...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);

Слой классификации задает выходные классы сети. Замените слой классификации новым слоем без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newclasslayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);

Обучите сеть

Сеть требует изображений входа размера 224 на 224 на 3, но у изображений в изображении datastore есть различные размеры. Используйте хранилище данных дополненных изображений, чтобы автоматически изменить размер обучающих изображений. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом разверните обучающие изображения вдоль вертикальной оси, случайным образом перемещайте их до 30 пикселей и масштабируйте до 10% по горизонтали и вертикали. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Задайте InitialLearnRate к небольшому значению для замедления обучения в перенесенных слоях, которые еще не заморожены. На предыдущем шаге вы увеличили коэффициенты скорости обучения для последнего обучаемого слоя, чтобы ускорить обучение в новых конечных слоях. Эта комбинация настроек скорости обучения приводит к быстрому обучению в новых слоях, более медленному обучению в средних слоях и отсутствию обучения в более ранних, замороженных слоях.

Укажите количество эпох для обучения. При выполнении передачи обучения вам не нужно тренироваться на столько эпох. Эпоха - это полный цикл обучения на целом наборе обучающих данных. Укажите размер мини-пакета и данные валидации. Вычислите точность валидации один раз в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',4, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть с помощью обучающих данных. По умолчанию trainNetwork использует графический процессор, если он доступен. Для этого требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork использует центральный процессор. Можно также задать окружение выполнения с помощью 'ExecutionEnvironment' Аргумент пары "имя-значение" из trainingOptions. Поскольку этот набор данных является маленьким, обучение происходит быстро. Если запустить этот пример и обучить сеть самостоятельно, вы получите различные результаты и неправильные классификации, вызванные случайностью, связанной с процессом обучения.

net = trainNetwork(augimdsTrain,lgraph,options);

Классификация тестовых изображений

Классификация тестовых изображений с помощью тонкой настройки сети и вычисление точности классификации.

augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);
[predictedClasses,predictedScores] = classify(net,augimdsTest);

accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8418

Матрица неточностей для тестового набора

Постройте матрицу неточностей предсказаний тестового набора. Это подчеркивает, какие именно классы вызывают большинство проблем для сети.

figure;
confusionchart(imdsTest.Labels,predictedClasses,'Normalization',"row-normalized");

Матрица неточностей показывает, что сеть имеет плохую эффективность для некоторых классов, таких как греческий салат, сашими, хот-дог и суши. Эти классы недостаточно представлены в наборе данных, который может повлиять на эффективность сети. Исследуйте один из этих классов, чтобы лучше понять, почему сеть борется.

figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";

Изучение классификаций

Исследуйте сетевую классификацию для класса суши.

Суши больше всего нравится суши

Во-первых, найдите, какие изображения суши наиболее сильно активируют сеть для класса суши. Это отвечает на вопрос «Какие изображения, по мнению сети, наиболее похожи на суши?».

Постройте максимально активирующие изображения, это входные изображения, которые сильно активируют «суши» нейрон полносвязного слоя. Этот рисунок показывает 4 лучшие изображения в нисходящем счете класса.

chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);

numImgsToShow = 4;

[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

Визуализация сигналов для класса суши

Сеть смотрит на правильную вещь для суши? Максимально активирующие для сети изображения класса суши все выглядят похожими друг на друга - множество круглых форм сгруппированы тесно.

Сеть хорошо классифицирует такие виды суши. Однако, чтобы убедиться, что это правда, и лучше понять, почему сеть принимает свои решения, используйте метод визуализации, такой как Grad-CAM. Для получения дополнительной информации об использовании Grad-CAM, смотрите Grad-CAM Раскрывает Почему За Решения Глубокого Обучения.

Считайте первое измененное изображение из дополненного datastore, затем постройте график визуализации Grad-CAM с помощью gradCAM.

imageNumber = 1;

observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Карта Grad-CAM подтверждает, что сеть фокусируется на суши в изображении. Однако можно также увидеть, что сеть смотрит на части пластины и таблицы.

Второе изображение имеет кластер суши слева и одинокий суши справа. Чтобы увидеть, на чем ориентирована сеть, прочитайте второе изображение и постройте график Grad-CAM.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть классифицирует это изображение как суши, потому что видит группу суши. Однако способен ли он классифицировать один суши самостоятельно? Протестируйте это, посмотрев на фотографию всего одного суши.

img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg"));
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть способна правильно классифицировать эти одинокие суши. Однако GradCAM показывает, что сеть фокусируется на верхней части суши и кластере огурцов, а не на целом кусочке вместе.

Запустите метод визуализации Grad-CAM на одиноком суши, который не содержит никаких сложенных небольших кусочков ингредиентов.

img = imread("crop__sushi34-copy.jpg");
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (score: "+ max(score)+")")

В этом случае метод визуализации подчеркивает, почему сеть работает плохо. Он неправильно классифицирует изображение суши как гамбургер.

Чтобы решить этот вопрос, необходимо накормить сеть большим количеством изображений одиноких суши в процессе обучения.

Суши Меньше Всего Нравится Суши

Теперь найдите, какие изображения суши активируют сеть для класса суши меньше всего. Это отвечает на вопрос «Какие изображения, по мнению сети, менее похожи на суши?».

Это полезно, потому что находит изображения, на которых плохо работает сеть, и дает некоторое представление о своем решении.

chosenClass = "sushi";
numImgsToShow = 9;

[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

Исследуйте Суши неправильно классифицированный как Сасими

Почему сеть классифицирует суши как сасими? Сеть классифицирует 3 из 9 изображений как сашими. Некоторые из этих изображений, например изображения 4 и 9, на самом деле содержат сашими, что означает, что сеть на самом деле не неправильно классифицирует их. Эти изображения неправильно маркированы.

Чтобы увидеть, на чем ориентирована сеть, запустите метод Grad-CAM на одном из этих изображений.

imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Как и ожидалось, сеть фокусируется на сасими вместо суши.

Исследуйте суши неправильно классифицированные как пицца

Почему сеть классифицирует суши как пиццу? Сеть классифицирует четыре изображения как пиццу вместо суши. Рассмотрим изображение 1, это изображение имеет красочное покрытие, которое может запутать сеть.

Чтобы увидеть, на какую часть изображения смотрит сеть, запустите метод Grad-CAM на одном из этих изображений.

imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Сеть сильно фокусируется на верхних точках. Чтобы помочь сети отличить пиццу от суши с топпингами, добавьте больше обучающих изображений суши с топпингами. Сеть также фокусируется на диске. Это может быть, так как сеть научилась связывать определенные продукты с определенными типами тарелок, что также подсвечивается при взгляде на изображения суши. Чтобы улучшить эффективность сети, обучайте, используя больше примеров еды на разных типах тарелок.

Исследуйте суши неправильно классифицированный как гамбургер

Почему сеть классифицирует суши как гамбургер? Чтобы увидеть, на чем ориентирована сеть, запустите метод Grad-CAM на неправильно классифицированном изображении.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Сеть фокусируется на цветке в изображении. Красочный фиолетовый цветок и коричневый стебель перепутали сеть, идентифицируя это изображение как гамбургер.

Расследуйте суши, неправильно классифицированные как французские фри

Почему сеть классифицирует суши как картофель фри? Сеть классифицирует 3-е изображение как картофель фри вместо суши. Этот специфический суши имеет желтый топпинг, и сеть может связать этот цвет с картофелем фри.

Запустите Grad-CAM на этом изображении.

imageNumber = 3;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")","Interpreter","none")

Сети фокусируются на желтом суши, классифицируя его как картофель фри. Как и с гамбургером, необычный цвет заставил сеть неправильно классифицировать суши.

Чтобы помочь сети в этом конкретном случае, обучите ее большему количеству изображений желтой пищи, которая не является картофелем фри.

Заключения

Исследование точек данных, которые вызывают большие или маленькие счета класса, и точек данных, которые сеть классифицирует уверенно, но неправильно, является простым методом, который может предоставить полезное представление о том, как функционирует обученная сеть. В случае набора данных о пищевых продуктах в этом примере подчеркивалось, что:

  • Тестовые данные содержат несколько изображений с неправильными истинными метками, такими как «сашими», который на самом деле является «суши». Данные также содержат неполные метки, такие как изображения, которые содержат как суши, так и сасими.

  • Сеть считает «суши» «множественными, кластерными, кругловидными вещами». Однако он должен уметь различать и одиноких суши.

  • Любые суши или сашими с топингами или необычными цветами путают сеть. Чтобы решить эту проблему, данные должны иметь более широкое разнообразие суши и сасими.

  • Для повышения эффективности сети необходимо видеть больше изображения из недостаточно представленных классов.

Вспомогательные функции

function downloadExampleFoodImagesData(url,dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir,fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath,"file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath,url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir,"pizza");
if ~exist(exampleFolderFullPath,"dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath,dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

function [sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in descending order
[sortedScores,idx] = sort(scoresForChosenClass,'descend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in ascending order
[sortedScores,idx] = sort(scoresForChosenClass,'ascend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores)
% Find the index of className (e.g. "sushi" is the 9th class)
uniqueClasses = unique(imds.Labels);
chosenClassIdx = string(uniqueClasses) == className;

% Find the indices in imageDatastore that are images of label "className"
% (e.g. find all images of class sushi)
imgsOfClassIdxs = find(imds.Labels == className);

% Find the predicted scores of the chosen class on all the images of the
% chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx);
end

function plotImages(imds,imgIdx,sortedScores,predictedClasses,numImgsToShow)

for i=1:numImgsToShow
    score = sortedScores(i);
    sortedImgIdx = imgIdx(i);
    predClass = predictedClasses(sortedImgIdx); 
    correctClass = imds.Labels(sortedImgIdx);
        
    imgPath = imds.Files{sortedImgIdx};
    
    if predClass == correctClass
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    
    predClassTitle = strrep(string(predClass),'_',' ');
    correctClassTitle = strrep(string(correctClass),'_',' ');
    
    subplot(3,ceil(numImgsToShow./3),i)
    imshow(imread(imgPath));
    title("Predicted: " + color + predClassTitle + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + correctClassTitle);
end

end

function plotGradCAM(img,gradcamMap,alpha)

subplot(1,2,1)
imshow(img);

h = subplot(1,2,2);
imshow(img)
hold on;
imagesc(gradcamMap,'AlphaData',alpha);

originalSize2 = get(h,'Position');

colormap jet
colorbar

set(h,'Position',originalSize2);
hold off;
end

См. также

| | | | | | | |

Похожие примеры

Подробнее о