В этом примере показано, как использовать набор данных, чтобы выяснить, что активирует каналы глубокой нейронной сети. Это позволяет понять, как работает нейронная сеть, и диагностировать потенциальные проблемы с помощью обучающего набора данных.
В этом примере рассматривается ряд простых методов визуализации с использованием информации о переносе GoogLeNet в наборе данных о продуктах питания. Просматривая изображения, которые максимально или минимально активируют классификатор, можно обнаружить, почему нейронная сеть получает классификации неправильно.
Загрузите изображения как хранилище данных изображений. Этот небольшой набор данных содержит в общей сложности 978 наблюдений с 9 классами пищи.
Разбейте эти данные на обучающие, проверочные и тестовые наборы для подготовки к обучению с использованием GoogLeNet. Отображение выбранных изображений из набора данных.
rng default dataDir = fullfile(tempdir,"Food Dataset"); url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip"; if ~exist(dataDir,"dir") mkdir(dataDir); end downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset... This can take several minutes to download... Download finished... Unzipping file... Unzipping finished... Done.
imds = imageDatastore(dataDir, ... "IncludeSubfolders",true,"LabelSource","foldernames"); [imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2); rnd = randperm(numel(imds.Files),9); for i = 1:numel(rnd) subplot(3,3,i) imshow(imread(imds.Files{rnd(i)})) label = imds.Labels(rnd(i)); title(label,"Interpreter","none") end

Используйте предварительно подготовленную сеть GoogLeNet и обучите ее снова, чтобы классифицировать 9 типов продуктов питания. Если у вас нет установленного пакета поддержки Deep Learning Toolbox™ Model для сети GoogLeNet, то программное обеспечение предоставляет ссылку для загрузки.
Чтобы попробовать другую предварительно подготовленную сеть, откройте этот пример в MATLAB ® и выберите другую сеть, например squeezenet, сеть, которая даже быстрее, чем googlenet. Список всех доступных сетей см. в разделе Предварительно обученные глубокие нейронные сети.
net = googlenet;
Первый элемент Layers свойство сети - уровень ввода изображения. Этот слой требует входных изображений размера 224 на 224 на 3, где 3 количество цветных каналов.
inputSize = net.Layers(1).InputSize;
Сверточные уровни сетевого извлеченного изображения отличаются тем, что последний обучаемый уровень и конечный уровень классификации используют для классификации входного изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, извлекаемые сетью, в вероятности классов, значение потерь и прогнозируемые метки. Чтобы обучить предварительно подготовленную сеть классифицировать новые изображения, замените эти два слоя новыми слоями, адаптированными к новому набору данных.
Извлеките график уровня из обученной сети.
lgraph = layerGraph(net);
В большинстве сетей последний уровень с обучаемыми весами является полностью подключенным. Замените этот полностью подключенный уровень новым полностью подключенным уровнем с количеством выходов, равным количеству классов в новом наборе данных (9, в данном примере).
numClasses = numel(categories(imdsTrain.Labels)); newfclayer = fullyConnectedLayer(numClasses,... 'Name','new_fc',... 'WeightLearnRateFactor',10,... 'BiasLearnRateFactor',10); lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);
Уровень классификации определяет выходные классы сети. Замените классификационный слой новым без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.
newclasslayer = classificationLayer('Name','new_classoutput'); lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);
Сеть требует входных изображений размера 224 на 224 на 3, но у изображений в хранилище данных изображения есть различные размеры. Используйте хранилище данных дополненного изображения для автоматического изменения размеров обучающих изображений. Укажите дополнительные операции увеличения, выполняемые на обучающих изображениях: случайное разворот обучающих изображений вдоль вертикальной оси, случайное перемещение до 30 пикселей и масштабирование до 10% по горизонтали и вертикали. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange, ... 'RandXScale',scaleRange, ... 'RandYScale',scaleRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ... 'DataAugmentation',imageAugmenter);
Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Укажите параметры обучения. Набор InitialLearnRate до небольшого значения, чтобы замедлить обучение в перенесенных слоях, которые еще не заморожены. На предыдущем шаге вы увеличили коэффициенты скорости обучения для последнего обучаемого уровня, чтобы ускорить обучение на новых конечных уровнях. Это сочетание настроек скорости обучения приводит к быстрому обучению в новых слоях, более медленному обучению в средних слоях и отсутствию обучения в более ранних замороженных слоях.
Укажите количество эпох для обучения. При выполнении трансферного обучения не нужно тренироваться на столько же эпох. Эпоха - это полный цикл обучения по всему набору данных обучения. Укажите размер мини-партии и данные проверки. Вычислите точность проверки один раз в эпоху.
miniBatchSize = 10; valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',4, ... 'InitialLearnRate',3e-4, ... 'Shuffle','every-epoch', ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',valFrequency, ... 'Verbose',false, ... 'Plots','training-progress');
Обучение сети с использованием данных обучения. По умолчанию trainNetwork использует графический процессор, если он доступен. Для этого требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). В противном случае trainNetwork использует ЦП. Можно также указать среду выполнения с помощью 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions. Поскольку этот набор данных невелик, обучение идет быстро. Если запустить этот пример и обучить сеть самостоятельно, будут получены различные результаты и неправильные классификации, вызванные случайностью, связанной с процессом обучения.
net = trainNetwork(augimdsTrain,lgraph,options);

Классифицируйте тестовые изображения с помощью отлаженной сети и рассчитайте точность классификации.
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest); [predictedClasses,predictedScores] = classify(net,augimdsTest); accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8418
Постройте график матрицы путаницы прогнозов тестового набора. Это подчеркивает, какие конкретные классы вызывают большинство проблем в сети.
figure; confusionchart(imdsTest.Labels,predictedClasses,'Normalization',"row-normalized");

Таблица путаницы показывает, что сеть имеет плохую производительность для некоторых классов, таких как греческий салат, сашими, хот-дог и суши. Эти классы недопредставлены в наборе данных, что может повлиять на производительность сети. Исследуйте один из этих классов, чтобы лучше понять, почему сеть борется.
figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";
Изучите классификацию сети для класса суши.
Во-первых, найдите, какие изображения суши наиболее сильно активируют сеть для класса суши. Это отвечает на вопрос «Какие изображения, по мнению сети, больше всего напоминают суши?».
Постройте график максимально активирующих изображений, это входные изображения, которые сильно активируют «суши» нейрон полностью связанного слоя. На этом рисунке показаны 4 лучших изображения в убывающем классе.
chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);
numImgsToShow = 4;
[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);
figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)
Сеть смотрит на правильную вещь для суши? Максимально активирующие изображения класса суши для сети все выглядят похожими друг на друга - множество круглых форм, тесно сгруппированных вместе.
Сеть успешно классифицирует эти виды суши. Однако, чтобы убедиться, что это правда, и лучше понять, почему сеть принимает свои решения, используйте технику визуализации, такую как Grad-CAM. Дополнительные сведения об использовании Grad-CAM см. в разделе Описание решений Grad-CAM, лежащих в основе глубокого обучения.
Прочитайте первое изображение с измененным размером из хранилища данных дополненного изображения, а затем постройте график визуализации Grad-CAM с помощью gradCAM.
imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")
Карта Grad-CAM подтверждает, что сеть фокусируется на суши на изображении. Однако также видно, что сеть смотрит на части пластины и таблицы.
Второе изображение имеет скопление суши слева и одинокие суши справа. Чтобы увидеть, на чем фокусируется сеть, прочитайте второе изображение и постройте график Grad-CAM.
imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")
Сеть классифицирует это изображение как суши, потому что она видит группу суши. Однако способен ли он классифицировать одно суши самостоятельно? Проверьте это, посмотрев на фотографию только одного суши.
img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg")); img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true); [label,score] = classify(net,img); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); sgtitle(string(label)+" (score: "+ max(score)+")")

Сеть способна правильно классифицировать эти одинокие суши. Тем не менее, GradCAM показывает, что сеть фокусируется на верхней части суши и скоплении огурца, а не на всей части вместе.
Запустите метод визуализации Grad-CAM на одиноких суши, которые не содержат никаких сложенных небольших кусочков ингредиентов.
img = imread("crop__sushi34-copy.jpg"); img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true); [label,score] = classify(net,img); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (score: "+ max(score)+")")

В этом случае метод визуализации подчеркивает, почему сеть работает плохо. Неправильно классифицирует изображение суши как гамбургер.
Чтобы решить эту проблему, необходимо накормить сеть большим количеством изображений одиноких суши в процессе обучения.
Теперь найдите, какие изображения суши активировать сеть для класса суши меньше всего. Это отвечает на вопрос «Какие изображения в сети считают менее суши-подобными?».
Это полезно, потому что он находит образы, на которых сеть работает плохо, и дает некоторое представление о своем решении.
chosenClass = "sushi";
numImgsToShow = 9;
[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);
figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)
Почему сеть классифицирует суши как сашими? Сеть классифицирует 3 из 9 изображений как сашими. Некоторые из этих изображений, например изображения 4 и 9, фактически содержат сашими, что означает, что сеть на самом деле не неправильно классифицирует их. Эти изображения имеют неправильную маркировку.
Чтобы увидеть, на чем фокусируется сеть, запустите технику Grad-CAM на одном из этих изображений.
imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")
Как и ожидалось, сеть ориентируется на сашими вместо суши.
Почему сеть классифицирует суши как пиццу? Сеть классифицирует четыре изображения как пиццу вместо суши. Рассмотрим изображение 1, это изображение имеет цветную вершину, которая может запутать сеть.
Чтобы увидеть, какую часть изображения просматривает сеть, выполните метод Grad-CAM на одном из этих изображений.
imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")
Сеть сильно фокусируется на топингах. Чтобы помочь сети отличить пиццу от суши с топингами, добавьте больше обучающих изображений суши с топингами. Сеть также фокусируется на плите. Это может быть так, как сеть научилась связывать определенные продукты с определенными типами тарелок, как также подчеркивается при взгляде на изображения суши. Чтобы улучшить работу сети, тренируйтесь, используя больше примеров пищи на различных типах тарелок.
Почему сеть классифицирует суши как гамбургер? Чтобы увидеть, на чем фокусируется сеть, выполните метод Grad-CAM на неправильно классифицированном изображении.
imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")
Сеть фокусируется на цветке на изображении. Красочный фиолетовый цветок и коричневый стебель спутали сеть в идентификации этого изображения как гамбургера.
Почему сеть классифицирует суши как картофель фри? Сеть классифицирует 3-е изображение как картофель фри вместо суши. Этот конкретный суши имеет желтый верх, и сеть может связать этот цвет с картофелем фри.
Запустите программу Grad-CAM для этого образа.
imageNumber = 3;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};
label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);
gradcamMap = gradCAM(net,img,label);
figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")","Interpreter","none")
Сети фокусируются на желтых суши, классифицирующих его как картофель фри. Как и в случае с гамбургером, необычный цвет вызвал неправильную классификацию суши в сети.
Чтобы помочь сети в этом конкретном случае, тренируйте ее с большим количеством изображений желтых продуктов, которые не являются картофелем фри.
Исследование точек данных, которые дают большие или малые оценки классов, и точек данных, которые сеть классифицирует уверенно, но неправильно, является простым методом, который может дать полезное представление о том, как работает обученная сеть. В случае набора данных о продуктах питания этот пример показал, что:
Данные теста содержат несколько изображений с неправильными истинными метками, например «сашими», который на самом деле является «суши». Данные также содержат неполные метки, такие как изображения, которые содержат как суши, так и сашими.
Сеть считает «суши» «множественными, кластеризованными, круглыми вещами». Однако он должен уметь различать и одинокие суши.
Любые суши или сашими с топпингами или необычными цветами путают сеть. Для решения этой проблемы данные должны иметь более широкое разнообразие суши и сашими.
Для повышения производительности сети необходимо видеть больше образов из недопредставленных классов.
function downloadExampleFoodImagesData(url,dataDir) % Download the Example Food Image data set, containing 978 images of % different types of food split into 9 classes. % Copyright 2019 The MathWorks, Inc. fileName = "ExampleFoodImageDataset.zip"; fileFullPath = fullfile(dataDir,fileName); % Download the .zip file into a temporary directory. if ~exist(fileFullPath,"file") fprintf("Downloading MathWorks Example Food Image dataset...\n"); fprintf("This can take several minutes to download...\n"); websave(fileFullPath,url); fprintf("Download finished...\n"); else fprintf("Skipping download, file already exists...\n"); end % Unzip the file. % % Check if the file has already been unzipped by checking for the presence % of one of the class directories. exampleFolderFullPath = fullfile(dataDir,"pizza"); if ~exist(exampleFolderFullPath,"dir") fprintf("Unzipping file...\n"); unzip(fileFullPath,dataDir); fprintf("Unzipping finished...\n"); else fprintf("Skipping unzipping, file already unzipped...\n"); end fprintf("Done.\n"); end function [sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow) % Find the predicted scores of the chosen class on all the images of the chosen class % (e.g. predicted scores for sushi on all the images of sushi) [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores); % Sort the scores in descending order [sortedScores,idx] = sort(scoresForChosenClass,'descend'); % Return the indices of only the first few imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow)); end function [sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow) % Find the predicted scores of the chosen class on all the images of the chosen class % (e.g. predicted scores for sushi on all the images of sushi) [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores); % Sort the scores in ascending order [sortedScores,idx] = sort(scoresForChosenClass,'ascend'); % Return the indices of only the first few imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow)); end function [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores) % Find the index of className (e.g. "sushi" is the 9th class) uniqueClasses = unique(imds.Labels); chosenClassIdx = string(uniqueClasses) == className; % Find the indices in imageDatastore that are images of label "className" % (e.g. find all images of class sushi) imgsOfClassIdxs = find(imds.Labels == className); % Find the predicted scores of the chosen class on all the images of the % chosen class % (e.g. predicted scores for sushi on all the images of sushi) scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx); end function plotImages(imds,imgIdx,sortedScores,predictedClasses,numImgsToShow) for i=1:numImgsToShow score = sortedScores(i); sortedImgIdx = imgIdx(i); predClass = predictedClasses(sortedImgIdx); correctClass = imds.Labels(sortedImgIdx); imgPath = imds.Files{sortedImgIdx}; if predClass == correctClass color = "\color{green}"; else color = "\color{red}"; end predClassTitle = strrep(string(predClass),'_',' '); correctClassTitle = strrep(string(correctClass),'_',' '); subplot(3,ceil(numImgsToShow./3),i) imshow(imread(imgPath)); title("Predicted: " + color + predClassTitle + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + correctClassTitle); end end function plotGradCAM(img,gradcamMap,alpha) subplot(1,2,1) imshow(img); h = subplot(1,2,2); imshow(img) hold on; imagesc(gradcamMap,'AlphaData',alpha); originalSize2 = get(h,'Position'); colormap jet colorbar set(h,'Position',originalSize2); hold off; end
augmentedImageDatastore | classify | confusionchart | dlnetwork | googlenet | gradCAM | imageDatastore | imageLIME | occlusionSensitivity