Визуализируйте классификации изображений Используя максимальные и минимальные изображения активации

В этом примере показано, как использовать набор данных, чтобы узнать то, что активирует каналы глубокой нейронной сети. Это позволяет вам изучать, как работает нейронная сеть, и также диагностируйте потенциальные проблемы с обучающим набором данных.

Этот пример покрывает много простых методов визуализации, с помощью GoogLeNet, изученного передаче на продовольственном наборе данных.

Путем рассмотрения изображений, которые максимально или минимально активируют классификатор, вы изучаете, как диагностировать, почему нейронная сеть получает классификации неправильно с помощью простого метода, базирующегося вокруг музыки класса к различным классам изображений.

Загрузите и предварительно обработайте данные

Загрузите изображения как datastore изображений. Этот небольшой набор данных содержит в общей сложности 978 наблюдений с 9 классами еды.

Разделите эти данные в обучение, валидацию и набор тестов, чтобы подготовиться к передаче обучения на GoogLeNet. Отобразите несколько изображений от набора данных.

rng default
dataDir = fullfile(tempdir, "Food Dataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir, "dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.
imds = imageDatastore(dataDir, ...
    "IncludeSubfolders", true, "LabelSource", "foldernames");
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds, 0.6, 0.2);


rnd = randperm(numel(imds.Files), 9);
for i = 1:numel(rnd)
subplot(3,3,i)
imshow(imread(imds.Files{rnd(i)}))
label = imds.Labels(rnd(i));
title(label, "Interpreter", "none")
end

Обучите сеть, чтобы классифицировать продовольственные изображения

Используйте предварительно обученную сеть GoogLeNet и обучите ее снова классифицировать 9 классов еды. Если у вас нет Модели Deep Learning Toolbox™ для пакета Сетевой поддержки GoogLeNet установленной, то программное обеспечение обеспечивает ссылку на загрузку.

Чтобы попробовать различную предварительно обученную сеть, откройте этот пример в MATLAB® и выберите различную сеть, такую как squeezenet, сеть, которая является четной быстрее, чем googlenet. Для списка всех доступных сетей смотрите Предварительно обученные сети Загрузки.

net = googlenet;

Первый элемент Layers свойство сети является входным слоем изображений. Этот слой требует входных изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize;

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы обучить снова предварительно обученную сеть классифицировать новые изображения, замените эти два слоя на новые слои, адаптированные к новому набору данных.

Извлеките график слоев из обучившего сеть.

lgraph = layerGraph(net);

В большинстве сетей последний слой с learnable весами является полносвязным слоем. Замените этот полносвязный слой на новый полносвязный слой с количеством выходных параметров, равных количеству классов в новом наборе данных (9 в этом примере).

numClasses = numel(categories(imdsTrain.Labels));

newfclayer = fullyConnectedLayer(numClasses,...
    'Name', 'new_fc',...
    'WeightLearnRateFactor',10,...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph, net.Layers(end-2).Name, newfclayer);

Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork автоматически устанавливает выходные классы слоя в учебное время.

newclasslayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, net.Layers(end).Name, newclasslayer);

Обучение сети

Сеть требует входных изображений размера 224 224 3, но изображения в datastore изображений имеют различные размеры. Используйте увеличенный datastore изображений, чтобы автоматически изменить размер учебных изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом инвертируйте учебные изображения вдоль вертикальной оси и случайным образом переведите их до 30 пикселей и масштабируйте их до 10% горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Установите InitialLearnRate к маленькому значению, чтобы замедлить изучение в переданных слоях, которые уже не замораживаются. На предыдущем шаге вы увеличили факторы скорости обучения для последнего learnable слоя, чтобы ускорить изучение в новых последних слоях. Эта комбинация настроек скорости обучения приводит к быстрому изучению в новых слоях, медленнее учась в средних слоях и никакое изучение в ранее, блокированные слои.

Задайте номер эпох, чтобы обучаться для. При использовании обучение с переносом вы не должны обучаться для как много эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Задайте мини-пакетный размер и данные о валидации. Вычислите точность валидации однажды в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть с помощью обучающих данных. По умолчанию, trainNetwork использует графический процессор, если вы доступны. Это требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше. В противном случае, trainNetwork использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment' аргумент пары "имя-значение" trainingOptions. Поскольку этот набор данных мал, обучение быстро. Если вы запустите этот пример и обучите сеть сами, вы получите различные результаты и misclassifications, вызванный случайностью, вовлеченной в учебный процесс.

net = trainNetwork(augimdsTrain,lgraph,options);

Классифицируйте тестовые изображения

Классифицируйте тестовые изображения с помощью подстроенной сети и вычислите точность классификации.

augimdsTest = augmentedImageDatastore(inputSize(1:2), imdsTest);
[predictedClasses, predictedScores] = classify(net, augimdsTest);

accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8622

Матрица беспорядка для набора тестов

Постройте матрицу беспорядка предсказаний набора тестов. Это подсвечивает, какие конкретные классы вызывают большинство проблем для сети.

figure;
confusionchart(imdsTest.Labels, predictedClasses, 'Normalization', "row-normalized");

Матрица беспорядка показывает, что наиболее часто сеть неправильно классифицирует некоторые классы, такие как греческий салат, сашими и суши. Чтобы лучше изучить, почему это происходит, запускает некоторые тесты на этих 3 классах, при учете того факта, что набор данных использовал в этом примере underrepresents их.

figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";

Sushis больше всего как суши

Во-первых, найдите, какие изображения суши наиболее строго активируют сеть для класса суши. Это отвечает на вопрос, "Какие изображения делает сеть, думают, являются самыми подобными суши?".

Для этого постройте максимально активирующиеся изображения или входные изображения, которые строго активируют нейрон "суши" полносвязного слоя. Этот рисунок показывает лучшие 4 изображения в убывающем счете класса.

chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);

numImgsToShow = 4;

[sortedScores, imgIdx] = findMaxActivatingImages(imdsTest, chosenClass, predictedScores, numImgsToShow);

figure
plotImages(imdsTest, imgIdx, sortedScores, predictedClasses, numImgsToShow)

Визуализируйте сигналы для класса суши

Сеть смотрит на правильную вещь для суши? Максимально активирующиеся изображения класса суши для сети весь взгляд, похожий друг на друга - много круглых форм, кластеризируемых тесно вместе.

Сеть преуспевает при классификации тех видов sushis. Однако, чтобы проверить, что это верно и лучше изучать, почему сеть принимает свои решения, используйте метод визуализации как CAM градиента. Для получения дополнительной информации об использовании CAM градиента смотрите, что CAM градиента Показывает Почему Позади Решений Глубокого обучения.

Чтобы использовать CAM градиента, создайте dlnetwork от сети GoogLeNet. Задайте имя softmax слоя, 'prob'. Задайте имя итогового сверточного слоя, 'inception_5b-output'.

lgraph = layerGraph(net);
lgraph = removeLayers(lgraph, lgraph.Layers(end).Name);
dlnet = dlnetwork(lgraph);
softmaxName = 'prob';
convLayerName = 'inception_5b-output';

Считайте первое измененное изображение из увеличенного datastore изображений, затем постройте визуализацию CAM градиента.

imageNumber = 1;

observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

График подтверждает, что сеть правильно фокусируется на суши вместо пластины или таблицы. Сеть классифицирует это изображение как суши, потому что это видит группу sushis. Однако может он, чтобы классифицировать суши самостоятельно?

Второе изображение имеет кластер суши слева и одиноких суши справа. Чтобы видеть что сетевое особое внимание на, считайте второе изображение и постройте CAM градиента.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

alpha = 0.5;

figure
plotGradCAM(img, gradcamMap, alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть сопоставляет класс суши с несколькими sushis, сложенными вместе. Протестируйте это путем рассмотрения изображения всего суши.

img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg"));
img = imresize(img, net.Layers(1).InputSize(1:2), "Method", "bilinear", "AntiAliasing", true);

[label,score] = classify(net, img);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть может классифицировать эти одинокие суши правильно. Однако GradCAM показывает, что сеть фокусируется больше на кластере огурца в суши, а не целой части вместе.

Запустите метод визуализации CAM градиента на одиноких суши, которые не содержат сложенных маленьких частей компонентов.

img = imread("crop__sushi34-copy.jpg");
img = imresize(img, net.Layers(1).InputSize(1:2), "Method", "bilinear", "AntiAliasing", true);

[label,score] = classify(net, img);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (score: "+ max(score)+")")

В этом случае метод визуализации подсвечивает, почему сеть выполняет плохо. Это неправильно классифицирует изображение суши как гамбургер.

Чтобы решить эту проблему, необходимо питать сеть большим количеством изображений одиноких суши во время учебного процесса.

Sushis меньше всего как Sushis

Теперь найдите, какие изображения суши активируются, сеть для суши классифицируют наименьшее. Это отвечает на вопрос, "Какие изображения делает сеть, думают, менее подобны суши?".

Это полезно, потому что это находит изображения, на которых сеть выполняет плохо, и это обеспечивает некоторое понимание ее решения.

chosenClass = "sushi";
numImgsToShow = 9;

[sortedScores, imgIdx] = findMinActivatingImages(imdsTest, chosenClass, predictedScores, numImgsToShow);

figure
plotImages(imdsTest, imgIdx, sortedScores, predictedClasses, numImgsToShow)

Исследуйте суши, неправильно классифицированные как сашими

Почему сеть классифицирует суши как сашими? Сеть классифицирует 3 из 9 изображений как сашими, соответственно изображения 2, 4, и 9. Два из них, изображения 4 и 9, на самом деле содержите сашими, что означает, что сеть на самом деле не неправильно классифицирует их. Эти изображения являются mislabeled.

Чтобы видеть, на чем фокусируется сеть, запустите метод CAM градиента на одном из этих изображений.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Как ожидалось сеть фокусируется на сашими вместо суши.

Исследуйте суши, неправильно классифицированные как пицца

Почему сеть классифицирует суши как пиццу? Сеть классифицирует 4 из 9 изображений как пицца вместо суши. Эти изображения имеют много красочных начинок.

Чтобы видеть, на который смотрит часть изображения сеть, запустите метод CAM градиента на одном из этих изображений.

imageNumber = 5;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

Сеть строго фокусируется на начинках.

Чтобы помочь сети отличить пиццу от суши с начинками, добавьте больше учебных изображений суши с начинками.

Исследуйте суши, неправильно классифицированные как картофель фри

Почему сеть классифицирует суши как картофель фри? Сеть классифицирует 7-е изображение как картофель фри вместо суши. Эти определенные суши являются желтыми, и сетевая сила сопоставляют этот цвет с картофелем фри.

Запустите CAM градиента на этом изображении.

imageNumber = 7;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (sushi score: "+ max(score)+")", "Interpreter", "none")

Сети фокусируются на желтых суши, классифицирующих его как картофель фри.

Чтобы помочь сети в этом конкретном случае, обучите его с большим количеством изображений желтых продуктов, которые не являются картофелем фри.

Большинство и наименьшее как сашими

Продовольственный набор данных включает только 8 наблюдений за сашими.

Подобно классу суши отобразите изображения сашими в порядке большинства как сашими, чтобы меньше всего любить сашими.

chosenClass = "sashimi";
numImgsToShow = 8;

[sortedScores, imgIdx] = findMaxActivatingImages(imdsTest, chosenClass, predictedScores, numImgsToShow);

figure
plotImages(imdsTest, imgIdx, sortedScores, predictedClasses, numImgsToShow)

Этот рисунок показывает, что сеть не уверена в сашими, сопровождаемом определенной другой едой, такова как сашими с растительностью. Однако это также показывает, что существуют некоторые дефекты в данных, где изображения суши являются mislabeled как сашими как основная истина, такая как номер изображения 5.

Визуализируйте сигналы для класса сашими

Сеть смотрит на правильную вещь для сашими? Рассмотрите только изображения, которые сеть правильно классифицирует как сашими, и проверяйте то, на чем это фокусируется во время процесса классификации.

Считайте самый высокий счет подобное сашими изображение и запустите CAM градиента на нем.

imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (score: "+ max(score)+")")

Это изображение содержит и сашими и суши. Сеть правильно фокусируется на сашими и игнорирует суши. Однако набор данных содержит другие изображения, которые также имеют смесь суши и сашими, но они помечены как только для суши вместо этого. Из-за этого сложно сети изучать, как правильно классифицировать эти изображения, как показано в "суши меньше всего как суши" случай. Это также объясняет, почему счет предсказания не очень высок в данном случае.

Исследуйте сашими, неправильно классифицированное как гамбургер

Почему сеть классифицирует сашими как гамбургер? Сеть классифицирует 3 из 8 изображений сашими как гамбургер. Набор данных содержит намного больше наблюдений за гамбургером, который может сместить сеть.

Считайте одно из этих изображений и запустите CAM градиента на нем.

imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (sashimi score: "+ max(score)+")")

Здесь, перенесенная в морскую водоросль еда путает сеть, когда это напоминает характеристики булочки сути булочки humburger. Это снова - дефект в данных о сашими, которые не варьируются достаточно.

Исследуйте сашими, неправильно классифицированное как греческий салат

Почему сеть классифицирует сашими как греческий салат? Сеть классифицирует 6-е изображение как греческий язык вместо сашими.

Считайте изображение и запустите метод CAM градиента.

imageNumber = 6;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (sashimi score: "+ max(score)+")", "Interpreter", "none")

Здесь, растительность вводит в заблуждение сеть в размышление, что изображения изображают тарелку сашими. Чтобы помочь сети распознать различие между этими двумя классами, добавьте более варьировавшийся набор данных сашими, сопровождаемого листами и овощами.

Большинство и наименьшее как греческий салат

Продовольственный набор данных только включает 5 наблюдений за греческим салатом.

Подобно вышеупомянутым случаям отобразите греческие изображения салата в порядке наиболее вероятного греческого салата к наименее вероятному греческому салату.

chosenClass = "greek_salad";
numImgsToShow = 5;

[sortedScores, imgIdx] = findMaxActivatingImages(imdsTest, chosenClass, predictedScores, numImgsToShow);

figure
plotImages(imdsTest, imgIdx, sortedScores, predictedClasses, numImgsToShow)

Этот рисунок показывает, что сеть не уверена в греческом салате, сопровождаемом определенной другой едой, такова как суть.

Исследуйте сигналы для греческого класса салата

Сеть смотрит на правильную вещь для греческого салата? Рассмотрите только изображения, которые сеть правильно классифицирует как греческий салат, и проверяйте то, на чем это фокусируется во время процесса классификации.

Считайте самое высокое греческое изображение салата счета согласно сети и запустите CAM градиента.

imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (score: "+ max(score)+")", "Interpreter", "none")

Сеть игнорирует фактический греческий салат и скорее смотрит на вещи в фоновом режиме вместо этого. Это - индикация, что сеть классифицирует это изображение правильно, но по неправильным причинам.

Считайте второе изображение, которое сеть классифицирует правильно как греческий салат и CAM градиента запуска.

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (score: "+ max(score)+")", "Interpreter", "none")

Даже при том, что греческий салат обычно не имеет яиц, сетевого особого внимания на помидоре и маслинах, который указывает, что это смотрит на правильные вещи для греческого салата.

Исследуйте греческий салат, неправильно классифицированный как гамбургер

Почему сеть классифицирует греческий салат как гамбургер? Сеть классифицирует четвертое изображение как гамбургер. Это вызывается фрагментами сути в салате?

Считайте изображение и запустите CAM градиента.

imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label);

figure
alpha = 0.5;
plotGradCAM(img, gradcamMap, alpha);
title(string(label)+" (greek_salad score: "+ max(score)+")", "Interpreter", "none")

Сеть фокусируется на областях с сутью.

Заключение

При исследовании точек данных, которые дают начало большим или маленьким баллам класса, и точки данных, которые сеть классифицирует уверенно, но неправильно, являются простым методом, который может обеспечить полезное понимание, как функционирует обучивший сеть. В случае продовольственного набора данных этот пример подсветил что:

  • Тестовые данные содержат много изображений с неправильными истинными метками, такими как "сашими", которое является на самом деле "суши" вместо этого.

  • Сеть полагает, что "суши" "несколько, кластеризируемые, вещи круглой формы". Однако это должно смочь отличить одинокие суши также.

  • Любые суши или сашими с начинками или необычными цветами путают сеть. Чтобы разрешить эту проблему, данные должны иметь более широкое разнообразие суши и сашими.

  • Любые суши или сашими, сопровождаемое подобными салату компонентами, вероятно, будут перепутаны с классами салата. Поэтому необходимо обучить сеть с большим количеством изображений суши и сашими с небольшим количеством салата на стороне.

  • Учиться, что делает греческий салат, выделяется, сети нужно больше наблюдений за греческими салатами.

Функции помощника

function downloadExampleFoodImagesData(url, dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir, fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath, "file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath, url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir, "pizza");
if ~exist(exampleFolderFullPath, "dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath, dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

function [sortedScores, imgIdx] = findMaxActivatingImages(imds, className, predictedScores, numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass(imds, className, predictedScores);

% Sort the scores in descending order
[sortedScores, idx] = sort(scoresForChosenClass, 'descend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [sortedScores, imgIdx] = findMinActivatingImages(imds, className, predictedScores, numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass(imds, className, predictedScores);

% Sort the scores in ascending order
[sortedScores, idx] = sort(scoresForChosenClass, 'ascend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass(imds, className, predictedScores)
% Find the index of className (e.g. "sushi" is the 9th class)
uniqueClasses = unique(imds.Labels);
chosenClassIdx = string(uniqueClasses) == className;

% Find the indices in imageDatastore that are images of label "className"
% (e.g. find all images of class sushi)
imgsOfClassIdxs = find(imds.Labels == className);

% Find the predicted scores of the chosen class on all the images of the
% chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx);
end

function plotImages(imds, imgIdx, sortedScores, predictedClasses,numImgsToShow)

for i=1:numImgsToShow
    score = sortedScores(i);
    sortedImgIdx = imgIdx(i);
    predClass = predictedClasses(sortedImgIdx);
    correctClass = imds.Labels(sortedImgIdx);
    imgPath = imds.Files{sortedImgIdx};
    
    if predClass == correctClass
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    
    subplot(3,ceil(numImgsToShow./3),i)
    imshow(imread(imgPath));
    title("Predicted: " + color + string(predClass) + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + string(correctClass));
end

end

function [convMap,dScoresdMap] = gradcam(dlnet, dlImg, softmaxName, convLayerName, classfn)
% Computes the Grad-CAM map for a dlnetwork, taking the derivative of the softmax layer score
% for a given class with respect to a convolutional feature map.
[scores,convMap] = predict(dlnet, dlImg, 'Outputs', {softmaxName, convLayerName});
classScore = scores(classfn);
dScoresdMap = dlgradient(classScore,convMap);
end

function gradcamMap = computeGradCAM(dlnet, img, softmaxName, convLayerName, label)
% For automatic differentiation, the input image img must be a dlarray.
dlImg = dlarray(single(img),'SSC');

% Compute the gradCAM map by passing the dlarray image
[convMap, dScoresdMap] = dlfeval(@gradcam, dlnet, dlImg, softmaxName, convLayerName, label);

% Resize the gradient map to the net image size, and scale the scores to the appropriate levels for display.
gradcamMap = sum(convMap .* sum(dScoresdMap, [1 2]), 3);
gradcamMap = extractdata(gradcamMap);
gradcamMap = rescale(gradcamMap);
gradcamMap = imresize(gradcamMap, dlnet.Layers(1).InputSize(1:2), 'Method', 'bicubic');
end

function plotGradCAM(img, gradcamMap, alpha)

subplot(1,2,1)
imshow(img);

h= subplot(1,2,2);
imshow(img)
hold on;
imagesc(gradcamMap,'AlphaData', alpha);

originalSize2 = get(h, 'Position');

colormap jet
colorbar

set(h, 'Position', originalSize2);
hold off;
end

Смотрите также

| | | | |

Связанные примеры

Больше о