Исследуйте сеть Семантической Сегментации Используя CAM градиента

В этом примере показано, как исследовать предсказания сети семантической сегментации использование CAM градиента.

Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Можно использовать CAM градиента, метод визуализации глубокого обучения, чтобы видеть, какие области изображения важны для решения классификации пикселей.

Загрузите набор данных

Этот пример использует набор данных CamVid [1] из Кембриджского университета для обучения. Этот набор данных является набором изображений, содержащих представления уличного уровня, полученные при управлении. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов, включая автомобиль, пешехода и дорогу.

Загрузите набор данных CamVid

Загрузите набор данных CamVid.

rng('default')

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

outputFolder = fullfile(tempdir,'CamVid'); 
labelsZip = fullfile(outputFolder,'labels.zip');
imagesZip = fullfile(outputFolder,'images.zip');

if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file')   
    mkdir(outputFolder)
       
    disp('Downloading 16 MB CamVid data set labels...'); 
    websave(labelsZip, labelURL);
    unzip(labelsZip, fullfile(outputFolder,'labels'));
    
    disp('Downloading 557 MB CamVid data set images...');  
    websave(imagesZip, imageURL);       
    unzip(imagesZip, fullfile(outputFolder,'images'));    
end
Downloading 16 MB CamVid data set labels...
Downloading 557 MB CamVid data set images...

Загрузите изображения CamVid

Используйте imageDatastore загружать изображения CamVid. imageDatastore позволяет вам эффективно загрузить большое количество изображений на диске.

imgDir = fullfile(outputFolder,'images','701_StillsRaw_full');
imds = imageDatastore(imgDir);

Набор данных содержит 32 класса. Чтобы сделать обучение легче, сократите количество классов к 11 путем группировки нескольких классов от исходного набора данных вместе. Например, создайте "Car"класс, который комбинирует "Car", "SUVPickupTruck", "Truck_Bus", "Train", и "OtherMoving"классы от исходного набора данных. Возвратите сгруппированную метку IDs при помощи функции поддержки camvidPixelLabelIDs, который перечислен в конце этого примера.

classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];

labelIDs = camvidPixelLabelIDs;

Используйте классы и метку IDs, чтобы создать pixelLabelDatastore.

labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Загрузите предварительно обученную сеть Семантической Сегментации

Загрузите предварительно обученную сеть семантической сегментации. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться. Этот пример загружает обученный Deeplab v3 + сеть с весами, инициализированными от предварительно обученной сети ResNet-18. Чтобы получить предварительно обученный ResNet-18, установите resnet18 (Deep Learning Toolbox). Для получения дополнительной информации о создании и обучении сети семантической сегментации, смотрите, что Семантическая Сегментация Использует Глубокое обучение.

pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat';
pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat');

if ~exist(pretrainedNetwork,'file')
    mkdir(pretrainedFolder);
    disp('Downloading pretrained network (58 MB)...');
    websave(pretrainedNetwork,pretrainedURL);
end
Downloading pretrained network (58 MB)...
pretrainedNet = load(pretrainedNetwork); 
net = pretrainedNet.net;

Тестирование сети

Обученная сеть семантической сегментации предсказывает метку каждого пикселя в изображении. Можно протестировать сеть путем предсказания пиксельных меток изображения.

Загрузите тестовое изображение.

figure
img = readimage(imds,615);
imshow(img,'InitialMagnification',35)

Используйте semanticseg функция, чтобы предсказать пиксельные метки изображения при помощи обученной сети семантической сегментации.

predLabels = semanticseg(img,net);

Отобразите результаты.

cmap = camvidColorMap;
segImg = labeloverlay(img,predLabels,'Colormap',cmap,'Transparency',0.4);
figure
imshow(segImg,'InitialMagnification',40)

pixelLabelColorbar(cmap,classes)

Вы видите, что сеть помечает части изображения справедливо точно. Сеть действительно неправильно классифицирует некоторые области, например, дорогу слева от пересечения, которое частично неправильно классифицируется как тротуар.

Исследуйте сетевые предсказания

Глубокие сети являются комплексными, так понимают, как сеть решает, что конкретное предсказание затрудняет. Можно использовать CAM градиента, чтобы видеть, какие области тестового изображения сеть семантической сегментации использует, чтобы сделать ее классификации пикселей.

CAM градиента вычисляет градиент дифференцируемого выхода, такого как счет класса, относительно сверточных функций в выбранном слое. CAM градиента обычно используется для задач классификации изображений [2]; однако, это может также быть расширено к проблемам семантической сегментации [3].

В задачах семантической сегментации, softmax слое сетевых выходных параметров счет к каждому классу для каждого пикселя в оригинальном изображении. Это контрастирует со стандартными проблемами классификации изображений, где softmax слой выводит счет к каждому классу для целого изображения. Карта CAM градиента для класса c

Mc=ReLU(kαckAk) где αck=1/Ni,jdycdAi,jk

N количество пикселей, Ak карта функции интереса, и yc соответствует скалярному счету класса. Для простой проблемы классификации изображений, yc счет softmax к классу интереса. Для семантической сегментации можно получитьyc путем сокращения мудрой пикселем музыки класса к классу интереса для скаляра. Например, суммируйте по пространственным размерностям softmax слоя: yc=(i,j)Pyi,jc, где P пиксели в выходном слое сети семантической сегментации [3]. В этом примере выходной слой является softmax слоем перед слоем классификации пикселей. Карта Mc области подсветок, которые влияют на решение для класса c. Более высокие значения указывают на области изображения, которые важны для решения классификации пикселей.

Чтобы использовать CAM градиента, необходимо выбрать слой функции, чтобы извлечь карту функции из и слой сокращения, чтобы извлечь выходные активации из. Используйте analyzeNetwork найти, что слои используют с CAM градиента.

analyzeNetwork(net)

Задайте слой функции. Обычно это - слой ReLU, который берет выход сверточного слоя в конце сети.

featureLayer = 'dec_relu4';

Задайте слой сокращения. gradCAM функционируйте суммирует пространственные размерности слоя сокращения, для заданных классов, чтобы произвести скалярное значение. Это скалярное значение затем дифференцируется относительно каждой функции в слое функции. Для проблем семантической сегментации слой сокращения обычно является softmax слоем.

reductionLayer = 'softmax-out';

Вычислите карту CAM градиента для классов дороги и тротуара.

classes = ["Road" "Pavement"];

gradCAMMap = gradCAM(net,img,classes, ...
    'ReductionLayer',reductionLayer, ...
    'FeatureLayer',featureLayer);

Сравните карту CAM градиента для этих двух классов к карте семантической сегментации.

predLabels = semanticseg(img,net);
segMap = labeloverlay(img,predLabels,'Colormap',cmap,'Transparency',0.4);

figure;
subplot(2,2,1)
imshow(img)
title('Test Image')
subplot(2,2,2)
imshow(segMap)
title('Semantic Segmentation')
subplot(2,2,3)
imshow(img)
hold on
imagesc(gradCAMMap(:,:,1),'AlphaData',0.5)
title('Grad-CAM: ' + classes(1))
colormap jet
subplot(2,2,4)
imshow(img)
hold on
imagesc(gradCAMMap(:,:,2),'AlphaData',0.5)
title('Grad-CAM: ' + classes(2))
colormap jet

Карты CAM градиента и карта семантической сегментации показывают подобное выделение. Ни одна из карт не отличает дорогу слева от пересечения, которое карта семантической сегментации помечает как тротуар. Карта CAM градиента для класса тротуара показывает, что ребро тротуара более важно, чем центр решения классификации о сети. Сеть возможно неправильно классифицирует дорогу слева от пересечения из-за плохой видимости ребра тротуара.

Исследуйте промежуточные слои

Карта CAM градиента напоминает карту семантической сегментации, когда вы используете слой около конца сети для расчета. Можно также использовать CAM градиента, чтобы исследовать промежуточные слои в обучившем сеть. Более ранние слои имеют небольшой восприимчивый размер поля и изучают маленькие, низкоуровневые функции по сравнению со слоями в конце сети.

Вычислите карту CAM градиента для слоев, которые последовательно более глубоки в сети.

layers = ["res5b_relu","catAspp","dec_relu1"];
numLayers = length(layers);

res5b_relu слой около середины сети, тогда как dec_relu1 около конца сети.

Исследуйте сетевые решения классификации для автомобиля, дороги и классов тротуара. Для каждого слоя и класса, вычислите карту CAM градиента.

classes = ["Car" "Road" "Pavement"];
numClasses = length(classes);

gradCAMMaps = [];
for i = 1:numLayers
    gradCAMMaps(:,:,:,i) = gradCAM(net,img,classes, ...
        'ReductionLayer',reductionLayer, ...
        'FeatureLayer',layers(i));
end

Отобразите карты CAM градиента для каждого слоя и каждого класса. Строки представляют карту для каждого слоя со слоями, упорядоченными от тех рано в сети тем в конце сети.

figure;
idx = 1;
for i=1:numLayers
    for j=1:numClasses
        subplot(numLayers,numClasses,idx)
        imshow(img)
        hold on
        imagesc(gradCAMMaps(:,:,j,i),'AlphaData',0.5)
        title(sprintf("%s (%s)",classes(j),layers(i)), ...
            "Interpreter","none")
        colormap jet
        idx = idx + 1;
    end
end

Более поздние слои производят карты, очень похожие на карту сегментации. Однако слои ранее в сети приводят к более абстрактным результатам и обычно более касаются более низких функций уровня как ребра с меньшей осведомленностью о семантических классах. Например, в картах для более ранних слоев, вы видите, что и для автомобиля и для дорожных классов, светофор подсвечен. Это предлагает, чтобы более ранние слои фокусировались на областях изображения, которые связаны с классом, но не обязательно принадлежат ему. Например, светофор, вероятно, появится близко к дороге, таким образом, сетевая сила использовать эту информацию, чтобы предсказать, какие пиксели являются дорогами. Можно также видеть, что для класса тротуара, более ранние слои высоко фокусируются на ребре, предполагая, что эта функция важна для сети при обнаружении, какие пиксели находятся в классе тротуара.

Ссылки

[1] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. “Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости”. Буквы Распознавания образов 30, № 2 (январь 2009): 88–97. https://doi.org/10.1016/j.patrec.2008.04.005.

[2] Selvaraju, R. R. М. Когсвелл, A. Десять кубометров, Р. Ведэнтэм, Д. Пэрих и Д. Бэтра. "CAM градиента: Визуальные Объяснения от Глубоких Сетей через Основанную на градиенте Локализацию". На Международной конференции IEEE по вопросам Компьютерного зрения (ICCV), 2017, стр 618–626. Доступный в Grad-CAM на веб-сайте Открытого доступа Основы Компьютерного зрения.

[3] Виноградова, Кира, Александр Дибров и Джин Майерс. “К Поддающейся толкованию Семантической Сегментации через Взвешенное Градиентом Отображение Активации Класса (Студенческий Краткий обзор)”. Продолжения Конференции AAAI по Искусственному интеллекту 34, № 10 (3 апреля 2020): 13943–44. https://doi.org/10.1609/aaai.v34i10.7244.

Вспомогательные Функции

function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid data set has 32 classes. Group them into 11 classes following
% the original SegNet training methodology [1].
%
% The 11 classes are:
%   "Sky", "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol",
%   "Fence", "Car", "Pedestrian",  and "Bicyclist".
%
% CamVid pixel label IDs are provided as RGB color values. Group them into
% 11 classes and return them as a cell array of M-by-3 matrices. The
% original CamVid class names are listed alongside each RGB value. Note
% that the Other/Void class are excluded below.
labelIDs = { ...
    
    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]
    
    % "Building" 
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    ]
    
    % "Pole"
    [
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    ]
    
    % Road
    [
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    ]
    
    % "Pavement"
    [
    000 000 192; ... % "Sidewalk" 
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    ]
        
    % "Tree"
    [
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    ]
    
    % "SignSymbol"
    [
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    ]
    
    % "Fence"
    [
    064 064 128; ... % "Fence"
    ]
    
    % "Car"
    [
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    ]
    
    % "Pedestrian"
    [
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    ]
    
    % "Bicyclist"
    [
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]
    
    };
end
function pixelLabelColorbar(cmap, classNames)
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add a colorbar to the current figure.
c = colorbar('peer', gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick marks.
c.TickLength = 0;
end

function cmap = camvidColorMap
% Define the colormap used by the CamVid data set.

cmap = [
    128 128 128   % Sky
    128 0 0       % Building
    192 192 192   % Pole
    128 64 128    % Road
    60 40 222     % Pavement
    128 128 0     % Tree
    192 128 128   % SignSymbol
    64 64 128     % Fence
    64 0 128      % Car
    64 64 0       % Pedestrian
    0 128 192     % Bicyclist
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end