Визуализируйте классификации изображений Используя максимальные и минимальные изображения активации

В этом примере показано, как использовать набор данных, чтобы узнать то, что активирует каналы глубокой нейронной сети. Это позволяет вам изучать, как нейронная сеть работает, и диагностируйте потенциальные проблемы с обучающим набором данных.

Этот пример покрывает много простых методов визуализации, с помощью GoogLeNet, изученного передаче на продовольственном наборе данных. Путем рассмотрения изображений, которые максимально или минимально активируют классификатор, можно обнаружить, почему нейронная сеть понимает классификации превратно.

Загрузите и предварительно обработайте данные

Загрузите изображения как datastore изображений. Этот небольшой набор данных содержит в общей сложности 978 наблюдений с 9 классами еды.

Разделите эти данные в обучение, валидацию и наборы тестов, чтобы подготовить к использованию передачи обучения GoogLeNet. Отобразите выбор изображений от набора данных.

rng default
dataDir = fullfile(tempdir,"Food Dataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir,"dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.
imds = imageDatastore(dataDir, ...
    "IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2);

rnd = randperm(numel(imds.Files),9);
for i = 1:numel(rnd)
subplot(3,3,i)
imshow(imread(imds.Files{rnd(i)}))
label = imds.Labels(rnd(i));
title(label,"Interpreter","none")
end

Обучите сеть, чтобы классифицировать продовольственные изображения

Используйте предварительно обученную сеть GoogLeNet и обучите ее снова классифицировать 9 типов еды. Если у вас нет Модели Deep Learning Toolbox™ для пакета Сетевой поддержки GoogLeNet установленной, то программное обеспечение обеспечивает ссылку на загрузку.

Чтобы попробовать различную предварительно обученную сеть, откройте этот пример в MATLAB® и выберите различную сеть, такую как squeezenet, сеть, которая является четной быстрее, чем googlenet. Для списка всех доступных сетей смотрите Предварительно обученные Глубокие нейронные сети.

net = googlenet;

Первый элемент Layers свойство сети является входным слоем изображений. Этот слой требует входных изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize;

Сетевая архитектура

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы обучить предварительно обученную сеть классифицировать новые изображения, замените эти два слоя на новые слои, адаптированные к новому набору данных.

Извлеките график слоев из обучившего сеть.

lgraph = layerGraph(net);

В большинстве сетей последний слой с learnable весами является полносвязным слоем. Замените этот полносвязный слой на новый полносвязный слой с количеством выходных параметров, равных количеству классов в новом наборе данных (9 в этом примере).

numClasses = numel(categories(imdsTrain.Labels));

newfclayer = fullyConnectedLayer(numClasses,...
    'Name','new_fc',...
    'WeightLearnRateFactor',10,...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);

Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork автоматически устанавливает выходные классы слоя в учебное время.

newclasslayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);

Обучение сети

Сеть требует входных изображений размера 224 224 3, но изображения в datastore изображений имеют различные размеры. Используйте увеличенный datastore изображений, чтобы автоматически изменить размер учебных изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом инвертируйте учебные изображения вдоль вертикальной оси, случайным образом переведите их до 30 пикселей и масштабируйте их до 10% горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Установите InitialLearnRate к маленькому значению, чтобы замедлить изучение в переданных слоях, которые уже не замораживаются. На предыдущем шаге вы увеличили факторы скорости обучения для последнего learnable слоя, чтобы ускорить изучение в новых последних слоях. Эта комбинация настроек скорости обучения приводит к быстрому изучению в новых слоях, медленнее учась в средних слоях и никакое изучение в ранее, блокированные слои.

Задайте номер эпох, чтобы обучаться для. При использовании обучение с переносом вы не должны обучаться для как много эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Задайте мини-пакетный размер и данные о валидации. Вычислите точность валидации однажды в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',4, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть с помощью обучающих данных. По умолчанию, trainNetwork использует графический процессор, если вы доступны. Это требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В противном случае, trainNetwork использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment' аргумент пары "имя-значение" trainingOptions. Поскольку этот набор данных мал, обучение быстро. Если вы запустите этот пример и обучите сеть сами, вы получите различные результаты и misclassifications, вызванный случайностью, вовлеченной в учебный процесс.

net = trainNetwork(augimdsTrain,lgraph,options);

Классифицируйте тестовые изображения

Классифицируйте тестовые изображения с помощью подстроенной сети и вычислите точность классификации.

augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);
[predictedClasses,predictedScores] = classify(net,augimdsTest);

accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8418

Матрица беспорядка для набора тестов

Постройте матрицу беспорядка предсказаний набора тестов. Это подсвечивает, какие конкретные классы вызывают большинство проблем для сети.

figure;
confusionchart(imdsTest.Labels,predictedClasses,'Normalization',"row-normalized");

Матрица беспорядка показывает, что сеть имеет низкую производительность для некоторых классов, таких как греческий салат, сашими, хот-дог и суши. Эти классы недостаточно представлены в наборе данных, который может влиять на производительность сети. Исследуйте один из этих классов, чтобы лучше изучить, почему сеть борется.

figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";

Исследуйте классификации

Исследуйте сетевую классификацию для класса суши.

Суши больше всего как суши

Во-первых, найдите, какие изображения суши наиболее строго активируют сеть для класса суши. Это отвечает на вопрос, "Какие изображения делает сеть, думают, являются самыми подобными суши?".

Постройте максимально активирующиеся изображения, это входные изображения, которые строго активируют нейрон "суши" полносвязного слоя. Этот рисунок показывает лучшие 4 изображения в убывающем счете класса.

chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);

numImgsToShow = 4;

[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

Визуализируйте сигналы для класса суши

Сеть смотрит на правильную вещь для суши? Максимально активирующиеся изображения класса суши для сети весь взгляд, похожий друг на друга - много круглых форм, кластеризируемых тесно вместе.

Сеть преуспевает при классификации тех видов суши. Однако, чтобы проверить, что это верно и лучше изучать, почему сеть принимает свои решения, используйте метод визуализации как CAM градиента. Для получения дополнительной информации об использовании CAM градиента смотрите, что CAM градиента Показывает Почему Позади Решений Глубокого обучения.

Считайте первое измененное изображение из увеличенного datastore изображений, затем постройте визуализацию CAM градиента с помощью gradCAM.

imageNumber = 1;

observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Карта CAM градиента подтверждает, что сеть фокусируется на суши в изображении. Однако можно также видеть, что сеть смотрит на части пластины и таблицы.

Второе изображение имеет кластер суши слева и одиноких суши справа. Чтобы видеть что сетевое особое внимание на, считайте второе изображение и постройте CAM градиента.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть классифицирует это изображение как суши, потому что это видит группу суши. Однако может он, чтобы классифицировать суши самостоятельно? Протестируйте это путем рассмотрения изображения всего суши.

img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg"));
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть может классифицировать эти одинокие суши правильно. Однако GradCAM показывает, что сеть фокусируется на верхней части суши и кластере огурца, а не целой части вместе.

Запустите метод визуализации CAM градиента на одиноких суши, которые не содержат сложенных маленьких частей компонентов.

img = imread("crop__sushi34-copy.jpg");
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (score: "+ max(score)+")")

В этом случае метод визуализации подсвечивает, почему сеть выполняет плохо. Это неправильно классифицирует изображение суши как гамбургер.

Чтобы решить эту проблему, необходимо питать сеть большим количеством изображений одиноких суши во время учебного процесса.

Суши меньше всего как суши

Теперь найдите, какие изображения суши активируются, сеть для суши классифицируют наименьшее. Это отвечает на вопрос, "Какие изображения делает сеть, думают, менее подобны суши?".

Это полезно, потому что это находит изображения, на которых сеть выполняет плохо, и это обеспечивает некоторое понимание ее решения.

chosenClass = "sushi";
numImgsToShow = 9;

[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

Исследуйте суши, неправильно классифицированные как сашими

Почему сеть классифицирует суши как сашими? Сеть классифицирует 3 из 9 изображений как сашими. Некоторые из этих изображений, например, изображения 4 и 9, на самом деле содержите сашими, что означает, что сеть на самом деле не неправильно классифицирует их. Эти изображения являются mislabeled.

Чтобы видеть, на чем фокусируется сеть, запустите метод CAM градиента на одном из этих изображений.

imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Как ожидалось сеть фокусируется на сашими вместо суши.

Исследуйте суши, неправильно классифицированные как пицца

Почему сеть классифицирует суши как пиццу? Сеть классифицирует четыре из изображений как пицца вместо суши. Рассмотрите изображение 1, это изображение имеет красочное положение во главе, которое может сбить с толку сеть.

Чтобы видеть, на который смотрит часть изображения сеть, запустите метод CAM градиента на одном из этих изображений.

imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Сеть строго фокусируется на начинках. Чтобы помочь сети отличить пиццу от суши с начинками, добавьте больше учебных изображений суши с начинками. Сеть также фокусируется на пластине. Это может быть, когда сеть училась сопоставлять определенные продукты с определенными типами пластин, равно как и подсвечено при рассмотрении изображений суши. Чтобы улучшать производительность сети, обучите использование большего количества примеров еды на различных типах пластин.

Исследуйте суши, неправильно классифицированные как гамбургер

Почему сеть классифицирует суши как гамбургер? Чтобы видеть, на чем фокусируется сеть, запустите метод CAM градиента на неправильно классифицированном изображении.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Сеть фокусируется на цветке в изображении. Красочный фиолетовый цветок и коричневое основание перепутали сеть в идентификацию этого изображения как гамбургер.

Исследуйте суши, неправильно классифицированные как картофель фри

Почему сеть классифицирует суши как картофель фри? Сеть классифицирует 3-е изображение как картофель фри вместо суши. Эти определенные суши имеют желтое положение во главе, и сетевая сила сопоставляют этот цвет с картофелем фри.

Запустите CAM градиента на этом изображении.

imageNumber = 3;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")","Interpreter","none")

Сети фокусируются на желтых суши, классифицирующих его как картофель фри. Как с гамбургером, необычный цвет заставил сеть неправильно классифицировать суши.

Чтобы помочь сети в этом конкретном случае, обучите его с большим количеством изображений желтых продуктов, которые не являются картофелем фри.

Заключения

При исследовании точек данных, которые дают начало большим или маленьким баллам класса, и точки данных, которые сеть классифицирует уверенно, но неправильно, являются простым методом, который может обеспечить полезное понимание, как функционирует обучивший сеть. В случае продовольственного набора данных этот пример подсветил что:

  • Тестовые данные содержат несколько изображений с неправильными истинными метками, такими как "сашими", которое является на самом деле "суши". Данные также содержат неполные метки, такие как изображения, которые содержат и суши и сашими.

  • Сеть полагает, что "суши" "несколько, кластеризируемые, вещи круглой формы". Однако это должно смочь отличить одинокие суши также.

  • Любые суши или сашими с начинками или необычными цветами путают сеть. Чтобы разрешить эту проблему, данные должны иметь более широкое разнообразие суши и сашими.

  • Чтобы улучшать производительность, сеть должна видеть больше изображений от недостаточно представленных классов.

Функции помощника

function downloadExampleFoodImagesData(url,dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir,fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath,"file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath,url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir,"pizza");
if ~exist(exampleFolderFullPath,"dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath,dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

function [sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in descending order
[sortedScores,idx] = sort(scoresForChosenClass,'descend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in ascending order
[sortedScores,idx] = sort(scoresForChosenClass,'ascend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores)
% Find the index of className (e.g. "sushi" is the 9th class)
uniqueClasses = unique(imds.Labels);
chosenClassIdx = string(uniqueClasses) == className;

% Find the indices in imageDatastore that are images of label "className"
% (e.g. find all images of class sushi)
imgsOfClassIdxs = find(imds.Labels == className);

% Find the predicted scores of the chosen class on all the images of the
% chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx);
end

function plotImages(imds,imgIdx,sortedScores,predictedClasses,numImgsToShow)

for i=1:numImgsToShow
    score = sortedScores(i);
    sortedImgIdx = imgIdx(i);
    predClass = predictedClasses(sortedImgIdx); 
    correctClass = imds.Labels(sortedImgIdx);
        
    imgPath = imds.Files{sortedImgIdx};
    
    if predClass == correctClass
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    
    predClassTitle = strrep(string(predClass),'_',' ');
    correctClassTitle = strrep(string(correctClass),'_',' ');
    
    subplot(3,ceil(numImgsToShow./3),i)
    imshow(imread(imgPath));
    title("Predicted: " + color + predClassTitle + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + correctClassTitle);
end

end

function plotGradCAM(img,gradcamMap,alpha)

subplot(1,2,1)
imshow(img);

h = subplot(1,2,2);
imshow(img)
hold on;
imagesc(gradcamMap,'AlphaData',alpha);

originalSize2 = get(h,'Position');

colormap jet
colorbar

set(h,'Position',originalSize2);
hold off;
end

Смотрите также

| | | | | | | |

Связанные примеры

Больше о