В этом примере показано, как использовать набор данных, чтобы выяснить, что активирует каналы глубокой нейронной сети. Это позволяет понять, как работает нейронная сеть, и диагностировать потенциальные проблемы с обучающими данными набором.
Этот пример охватывает ряд простых методов визуализации, используя передачу GoogLeNet, полученную на наборе данных о пищевых продуктах. Глядя на изображения, которые максимально или минимально активируют классификатор, можно узнать, почему нейронная сеть получает неправильные классификации.
Загрузите изображения как image datastore. Этот небольшой набор данных содержит в общей сложности 978 наблюдений с 9 классами пищи.
Разделите эти данные на наборы для обучения, валидации и тестирования, чтобы подготовиться к передаче обучения с помощью GoogLeNet. Отобразите выбор изображений из набора данных.
rng default dataDir = fullfile(tempdir,"Food Dataset"); url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip"; if ~exist(dataDir,"dir") mkdir(dataDir); end downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset... This can take several minutes to download... Download finished... Unzipping file... Unzipping finished... Done.
imds = imageDatastore(dataDir, ... "IncludeSubfolders",true,"LabelSource","foldernames"); [imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2); rnd = randperm(numel(imds.Files),9); for i = 1:numel(rnd) subplot(3,3,i) imshow(imread(imds.Files{rnd(i)})) label = imds.Labels(rnd(i)); title(label,"Interpreter","none") end
Используйте предварительно обученную сеть GoogLeNet и обучите ее снова, чтобы классифицировать 9 видов пищи. Если у вас нет установленного пакета поддержки Deep Learning Toolbox™ Model для GoogLeNet Network, то программное обеспечение предоставляет ссылку загрузки.
Чтобы попробовать другую предварительно обученную сеть, откройте этот пример в MATLAB ® и выберите другую сеть, такую как squeezenet
, сеть, которая даже быстрее, чем googlenet
. Список всех доступных сетей см. в разделе «Предварительно обученные глубокие нейронные сети».
net = googlenet;
Первый элемент Layers
свойство сети является входным слоем изображения. Этот слой требует входа изображений размера 224 224 3, где 3 количество цветовых каналов.
inputSize = net.Layers(1).InputSize;
Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'loss3-classifier'
и 'output'
в GoogLeNet содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы обучить предварительно обученную сеть классификации новых изображений, замените эти два слоя новыми слоями, адаптированными к новому набору данных.
Извлеките график слоев из обученной сети.
lgraph = layerGraph(net);
В большинстве сетей последний слой с усвояемыми весами является полносвязным слоем. Замените этот полностью соединенный слой новым полностью соединенным слоем с количеством выходов, равным количеству классов в новом наборе данных (9, в этом примере).
numClasses = numel(categories(imdsTrain.Labels)); newfclayer = fullyConnectedLayer(numClasses,... 'Name','new_fc',... 'WeightLearnRateFactor',10,... 'BiasLearnRateFactor',10); lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);
Слой классификации задает выходные классы сети. Замените слой классификации новым слоем без меток классов. trainNetwork
автоматически устанавливает выходные классы слоя во время обучения.
newclasslayer = classificationLayer('Name','new_classoutput'); lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);
Сеть требует изображений входа размера 224 на 224 на 3, но у изображений в изображении datastore есть различные размеры. Используйте хранилище данных дополненных изображений, чтобы автоматически изменить размер обучающих изображений. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом разверните обучающие изображения вдоль вертикальной оси, случайным образом перемещайте их до 30 пикселей и масштабируйте до 10% по горизонтали и вертикали. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange, ... 'RandXScale',scaleRange, ... 'RandYScale',scaleRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ... 'DataAugmentation',imageAugmenter);
Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Задайте опции обучения. Задайте InitialLearnRate
к небольшому значению для замедления обучения в перенесенных слоях, которые еще не заморожены. На предыдущем шаге вы увеличили коэффициенты скорости обучения для последнего обучаемого слоя, чтобы ускорить обучение в новых конечных слоях. Эта комбинация настроек скорости обучения приводит к быстрому обучению в новых слоях, более медленному обучению в средних слоях и отсутствию обучения в более ранних, замороженных слоях.
Укажите количество эпох для обучения. При выполнении передачи обучения вам не нужно тренироваться на столько эпох. Эпоха - это полный цикл обучения на целом наборе обучающих данных. Укажите размер мини-пакета и данные валидации. Вычислите точность валидации один раз в эпоху.
miniBatchSize = 10; valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',4, ... 'InitialLearnRate',3e-4, ... 'Shuffle','every-epoch', ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',valFrequency, ... 'Verbose',false, ... 'Plots','training-progress');
Обучите сеть с помощью обучающих данных. По умолчанию trainNetwork
использует графический процессор, если он доступен. Для этого требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork
использует центральный процессор. Можно также задать окружение выполнения с помощью 'ExecutionEnvironment'
Аргумент пары "имя-значение" из trainingOptions
. Поскольку этот набор данных является маленьким, обучение происходит быстро. Если запустить этот пример и обучить сеть самостоятельно, вы получите различные результаты и неправильные классификации, вызванные случайностью, связанной с процессом обучения.
net = trainNetwork(augimdsTrain,lgraph,options);
Классификация тестовых изображений с помощью тонкой настройки сети и вычисление точности классификации.
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest); [predictedClasses,predictedScores] = classify(net,augimdsTest); accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8418
Постройте матрицу неточностей предсказаний тестового набора. Это подчеркивает, какие именно классы вызывают большинство проблем для сети.
figure; confusionchart(imdsTest.Labels,predictedClasses,'Normalization',"row-normalized");
Матрица неточностей показывает, что сеть имеет плохую эффективность для некоторых классов, таких как греческий салат, сашими, хот-дог и суши. Эти классы недостаточно представлены в наборе данных, который может повлиять на эффективность сети. Исследуйте один из этих классов, чтобы лучше понять, почему сеть борется.
figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";
Исследуйте сетевую классификацию для класса суши.
Во-первых, найдите, какие изображения суши наиболее сильно активируют сеть для класса суши. Это отвечает на вопрос «Какие изображения, по мнению сети, наиболее похожи на суши?».
Постройте максимально активирующие изображения, это входные изображения, которые сильно активируют «суши» нейрон полносвязного слоя. Этот рисунок показывает 4 лучшие изображения в нисходящем счете класса.
chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);
numImgsToShow = 4;
[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);
figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)
Сеть смотрит на правильную вещь для суши? Максимально активирующие для сети изображения класса суши все выглядят похожими друг на друга - множество круглых форм сгруппированы тесно.
Сеть хорошо классифицирует такие виды суши. Однако, чтобы убедиться, что это правда, и лучше понять, почему сеть принимает свои решения, используйте метод визуализации, такой как Grad-CAM. Для получения дополнительной информации об использовании Grad-CAM, смотрите Grad-CAM Раскрывает Почему За Решения Глубокого Обучения.
Считайте первое измененное изображение из дополненного datastore, затем постройте график визуализации Grad-CAM с помощью gradCAM
.
imageNumber = 1; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); sgtitle(string(label)+" (score: "+ max(score)+")")
Карта Grad-CAM подтверждает, что сеть фокусируется на суши в изображении. Однако можно также увидеть, что сеть смотрит на части пластины и таблицы.
Второе изображение имеет кластер суши слева и одинокий суши справа. Чтобы увидеть, на чем ориентирована сеть, прочитайте второе изображение и постройте график Grad-CAM.
imageNumber = 2; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure plotGradCAM(img,gradcamMap,alpha); sgtitle(string(label)+" (score: "+ max(score)+")")
Сеть классифицирует это изображение как суши, потому что видит группу суши. Однако способен ли он классифицировать один суши самостоятельно? Протестируйте это, посмотрев на фотографию всего одного суши.
img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg")); img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true); [label,score] = classify(net,img); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); sgtitle(string(label)+" (score: "+ max(score)+")")
Сеть способна правильно классифицировать эти одинокие суши. Однако GradCAM показывает, что сеть фокусируется на верхней части суши и кластере огурцов, а не на целом кусочке вместе.
Запустите метод визуализации Grad-CAM на одиноком суши, который не содержит никаких сложенных небольших кусочков ингредиентов.
img = imread("crop__sushi34-copy.jpg"); img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true); [label,score] = classify(net,img); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (score: "+ max(score)+")")
В этом случае метод визуализации подчеркивает, почему сеть работает плохо. Он неправильно классифицирует изображение суши как гамбургер.
Чтобы решить этот вопрос, необходимо накормить сеть большим количеством изображений одиноких суши в процессе обучения.
Теперь найдите, какие изображения суши активируют сеть для класса суши меньше всего. Это отвечает на вопрос «Какие изображения, по мнению сети, менее похожи на суши?».
Это полезно, потому что находит изображения, на которых плохо работает сеть, и дает некоторое представление о своем решении.
chosenClass = "sushi";
numImgsToShow = 9;
[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);
figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)
Почему сеть классифицирует суши как сасими? Сеть классифицирует 3 из 9 изображений как сашими. Некоторые из этих изображений, например изображения 4 и 9, на самом деле содержат сашими, что означает, что сеть на самом деле не неправильно классифицирует их. Эти изображения неправильно маркированы.
Чтобы увидеть, на чем ориентирована сеть, запустите метод Grad-CAM на одном из этих изображений.
imageNumber = 4; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (sushi score: "+ max(score)+")")
Как и ожидалось, сеть фокусируется на сасими вместо суши.
Почему сеть классифицирует суши как пиццу? Сеть классифицирует четыре изображения как пиццу вместо суши. Рассмотрим изображение 1, это изображение имеет красочное покрытие, которое может запутать сеть.
Чтобы увидеть, на какую часть изображения смотрит сеть, запустите метод Grad-CAM на одном из этих изображений.
imageNumber = 1; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (sushi score: "+ max(score)+")")
Сеть сильно фокусируется на верхних точках. Чтобы помочь сети отличить пиццу от суши с топпингами, добавьте больше обучающих изображений суши с топпингами. Сеть также фокусируется на диске. Это может быть, так как сеть научилась связывать определенные продукты с определенными типами тарелок, что также подсвечивается при взгляде на изображения суши. Чтобы улучшить эффективность сети, обучайте, используя больше примеров еды на разных типах тарелок.
Почему сеть классифицирует суши как гамбургер? Чтобы увидеть, на чем ориентирована сеть, запустите метод Grad-CAM на неправильно классифицированном изображении.
imageNumber = 2; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (sushi score: "+ max(score)+")")
Сеть фокусируется на цветке в изображении. Красочный фиолетовый цветок и коричневый стебель перепутали сеть, идентифицируя это изображение как гамбургер.
Почему сеть классифицирует суши как картофель фри? Сеть классифицирует 3-е изображение как картофель фри вместо суши. Этот специфический суши имеет желтый топпинг, и сеть может связать этот цвет с картофелем фри.
Запустите Grad-CAM на этом изображении.
imageNumber = 3; observation = augimdsTest.readByIndex(imgIdx(imageNumber)); img = observation.input{1}; label = predictedClasses(imgIdx(imageNumber)); score = sortedScores(imageNumber); gradcamMap = gradCAM(net,img,label); figure alpha = 0.5; plotGradCAM(img,gradcamMap,alpha); title(string(label)+" (sushi score: "+ max(score)+")","Interpreter","none")
Сети фокусируются на желтом суши, классифицируя его как картофель фри. Как и с гамбургером, необычный цвет заставил сеть неправильно классифицировать суши.
Чтобы помочь сети в этом конкретном случае, обучите ее большему количеству изображений желтой пищи, которая не является картофелем фри.
Исследование точек данных, которые вызывают большие или маленькие счета класса, и точек данных, которые сеть классифицирует уверенно, но неправильно, является простым методом, который может предоставить полезное представление о том, как функционирует обученная сеть. В случае набора данных о пищевых продуктах в этом примере подчеркивалось, что:
Тестовые данные содержат несколько изображений с неправильными истинными метками, такими как «сашими», который на самом деле является «суши». Данные также содержат неполные метки, такие как изображения, которые содержат как суши, так и сасими.
Сеть считает «суши» «множественными, кластерными, кругловидными вещами». Однако он должен уметь различать и одиноких суши.
Любые суши или сашими с топингами или необычными цветами путают сеть. Чтобы решить эту проблему, данные должны иметь более широкое разнообразие суши и сасими.
Для повышения эффективности сети необходимо видеть больше изображения из недостаточно представленных классов.
function downloadExampleFoodImagesData(url,dataDir) % Download the Example Food Image data set, containing 978 images of % different types of food split into 9 classes. % Copyright 2019 The MathWorks, Inc. fileName = "ExampleFoodImageDataset.zip"; fileFullPath = fullfile(dataDir,fileName); % Download the .zip file into a temporary directory. if ~exist(fileFullPath,"file") fprintf("Downloading MathWorks Example Food Image dataset...\n"); fprintf("This can take several minutes to download...\n"); websave(fileFullPath,url); fprintf("Download finished...\n"); else fprintf("Skipping download, file already exists...\n"); end % Unzip the file. % % Check if the file has already been unzipped by checking for the presence % of one of the class directories. exampleFolderFullPath = fullfile(dataDir,"pizza"); if ~exist(exampleFolderFullPath,"dir") fprintf("Unzipping file...\n"); unzip(fileFullPath,dataDir); fprintf("Unzipping finished...\n"); else fprintf("Skipping unzipping, file already unzipped...\n"); end fprintf("Done.\n"); end function [sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow) % Find the predicted scores of the chosen class on all the images of the chosen class % (e.g. predicted scores for sushi on all the images of sushi) [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores); % Sort the scores in descending order [sortedScores,idx] = sort(scoresForChosenClass,'descend'); % Return the indices of only the first few imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow)); end function [sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow) % Find the predicted scores of the chosen class on all the images of the chosen class % (e.g. predicted scores for sushi on all the images of sushi) [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores); % Sort the scores in ascending order [sortedScores,idx] = sort(scoresForChosenClass,'ascend'); % Return the indices of only the first few imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow)); end function [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores) % Find the index of className (e.g. "sushi" is the 9th class) uniqueClasses = unique(imds.Labels); chosenClassIdx = string(uniqueClasses) == className; % Find the indices in imageDatastore that are images of label "className" % (e.g. find all images of class sushi) imgsOfClassIdxs = find(imds.Labels == className); % Find the predicted scores of the chosen class on all the images of the % chosen class % (e.g. predicted scores for sushi on all the images of sushi) scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx); end function plotImages(imds,imgIdx,sortedScores,predictedClasses,numImgsToShow) for i=1:numImgsToShow score = sortedScores(i); sortedImgIdx = imgIdx(i); predClass = predictedClasses(sortedImgIdx); correctClass = imds.Labels(sortedImgIdx); imgPath = imds.Files{sortedImgIdx}; if predClass == correctClass color = "\color{green}"; else color = "\color{red}"; end predClassTitle = strrep(string(predClass),'_',' '); correctClassTitle = strrep(string(correctClass),'_',' '); subplot(3,ceil(numImgsToShow./3),i) imshow(imread(imgPath)); title("Predicted: " + color + predClassTitle + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + correctClassTitle); end end function plotGradCAM(img,gradcamMap,alpha) subplot(1,2,1) imshow(img); h = subplot(1,2,2); imshow(img) hold on; imagesc(gradcamMap,'AlphaData',alpha); originalSize2 = get(h,'Position'); colormap jet colorbar set(h,'Position',originalSize2); hold off; end
augmentedImageDatastore
| classify
| confusionchart
| dlnetwork
| googlenet
| gradCAM
| imageDatastore
| imageLIME
| occlusionSensitivity