Исследуйте сеть семантической сегментации с помощью Grad-CAM

В этом примере показано, как исследовать предсказания семантической сети сегментации с помощью Grad-CAM.

Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Можно использовать Grad-CAM, метод визуализации глубокого обучения, чтобы увидеть, какие области изображения важны для решения классификации пикселей.

Загрузка набора данных

Этот пример использует набор данных CamVid [1] из Кембриджского университета для обучения. Этот набор данных представляет собой набор изображений, содержащих представления уличного уровня, полученные во время вождения. Набор данных обеспечивает пиксельные метки уровня для 32 семантических классов, включая автомобиль, пешеходный и дорожный.

Загрузка набора данных CamVid

Загрузите набор данных CamVid.

rng('default')

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

outputFolder = fullfile(tempdir,'CamVid'); 
labelsZip = fullfile(outputFolder,'labels.zip');
imagesZip = fullfile(outputFolder,'images.zip');

if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file')   
    mkdir(outputFolder)
       
    disp('Downloading 16 MB CamVid data set labels...'); 
    websave(labelsZip, labelURL);
    unzip(labelsZip, fullfile(outputFolder,'labels'));
    
    disp('Downloading 557 MB CamVid data set images...');  
    websave(imagesZip, imageURL);       
    unzip(imagesZip, fullfile(outputFolder,'images'));    
end
Downloading 16 MB CamVid data set labels...
Downloading 557 MB CamVid data set images...

Загрузка изображений CamVid

Использование imageDatastore для загрузки изображений CamVid. The imageDatastore позволяет эффективно загружать на диск большой набор изображений.

imgDir = fullfile(outputFolder,'images','701_StillsRaw_full');
imds = imageDatastore(imgDir);

Набор данных содержит 32 класса. Чтобы упростить обучение, уменьшите количество классов до 11 путем группировки нескольких классов из исходного набора данных вместе. Например, создайте "Car«класс, который объединяет» CarSUVPickupTruckTruck_BusTrain, «и» OtherMoving"классы из исходного набора данных. Верните сгруппированные идентификаторы меток с помощью вспомогательной функции camvidPixelLabelIDs, который перечислен в конце этого примера.

classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];

labelIDs = camvidPixelLabelIDs;

Используйте идентификаторы классов и меток, чтобы создать pixelLabelDatastore.

labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Загрузка предварительно обученной сети семантической сегментации

Загрузка предварительно обученной сети семантической сегментации. Предварительно обученная модель позволяет вам запустить весь пример, не дожидаясь завершения обучения. Этот пример загружает обученную сеть Deeplab v3 + с весами, инициализированными из предварительно обученной Resnet-18 сети. Для получения дополнительной информации о создании и обучении сети семантической сегментации, смотрите Семантическая сегментация с использованием глубокого обучения.

pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat';
pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat');

if ~exist(pretrainedNetwork,'file')
    mkdir(pretrainedFolder);
    disp('Downloading pretrained network (58 MB)...');
    websave(pretrainedNetwork,pretrainedURL);
end
Downloading pretrained network (58 MB)...
pretrainedNet = load(pretrainedNetwork); 
net = pretrainedNet.net;

Тестирование сети

Обученная сеть семантической сегментации предсказывает метку каждого пикселя в изображении. Можно протестировать сеть, предсказав пиксельные метки изображения.

Загрузите тестовое изображение.

figure
img = readimage(imds,615);
imshow(img,'InitialMagnification',35)

Используйте semanticseg функция для предсказания пиксельных меток изображения с помощью обученной семантической сети сегментации.

predLabels = semanticseg(img,net);

Отображение результатов.

cmap = camvidColorMap;
segImg = labeloverlay(img,predLabels,'Colormap',cmap,'Transparency',0.4);
figure
imshow(segImg,'InitialMagnification',40)

pixelLabelColorbar(cmap,classes)

Видно, что сеть довольно точно помечает части изображения. Сеть действительно неправильно классифицирует некоторые участки, например, дорогу налево от перекрестка, которая частично неправильно классифицируется как дорожное покрытие.

Исследуйте сетевые предсказания

Глубокие сети сложны, поэтому понять, как сеть определяет конкретное предсказание, трудно. Можно использовать Grad-CAM, чтобы увидеть, какие области тестового изображения использует семантическая сеть сегментации, чтобы сделать ее классификации пикселей.

Grad-CAM вычисляет градиент дифференцируемого выхода, такой как счет класса, относительно сверточных функций в выбранном слое. Grad-CAM обычно используется для задач классификации изображений [2]; однако он также может быть распространен на семантические задачи сегментации [3].

В семантических задачах сегментации слой softmax сети выводит счет для каждого класса для каждого пикселя в оригинальном изображении. Это контрастирует со стандартными задачами классификации изображений, где слой softmax выводит счет для каждого класса для всего изображения. Карта Grad-CAM для класса c является

Mc=ReLU(kαckAk) где αck=1/Ni,jdycdAi,jk

N - количество пикселей, Ak - функция карта интереса, и yc соответствует скалярному счету класса. Для простой задачи классификации изображений, yc - счет softmax для интересующего класса. Для семантической сегментации можно получитьyc путем уменьшения пиксельных счетов классов для интересующего класса до скаляра. Для примера суммируйте пространственные размерности слоя softmax: yc=(i,j)Pyi,jc, где P - пиксели в выход слое семантической сети сегментации [3]. В этом примере выхода слой является слоем softmax перед слоем классификации пикселей. Карта Mc выделяет области, которые влияют на решение для класса c. Более высокие значения указывают области изображения, которые важны для решения классификации пикселей.

Чтобы использовать Grad-CAM, необходимо выбрать слой функции, из которого будет извлечена карта функций, и слой сокращения, из которого будут извлечены выходные активации. Использование analyzeNetwork для поиска слоев, используемых в Grad-CAM.

analyzeNetwork(net)

Задайте слой функции. Обычно это слой ReLU, который принимает выход сверточного слоя в конце сети.

featureLayer = 'dec_relu4';

Задайте слой редукции. The gradCAM функция суммирует пространственные размерности восстановительного слоя для заданных классов, чтобы получить скалярное значение. Это скалярное значение затем дифференцируется относительно каждой функции в слое функций. Для семантических задач сегментации восстановительный слой обычно является слоем softmax.

reductionLayer = 'softmax-out';

Вычислите карту Град-КЭМ для классов дороги и дорожного покрытия.

classes = ["Road" "Pavement"];

gradCAMMap = gradCAM(net,img,classes, ...
    'ReductionLayer',reductionLayer, ...
    'FeatureLayer',featureLayer);

Сравните карту Grad-CAM для этих двух классов с картой семантической сегментации.

predLabels = semanticseg(img,net);
segMap = labeloverlay(img,predLabels,'Colormap',cmap,'Transparency',0.4);

figure;
subplot(2,2,1)
imshow(img)
title('Test Image')
subplot(2,2,2)
imshow(segMap)
title('Semantic Segmentation')
subplot(2,2,3)
imshow(img)
hold on
imagesc(gradCAMMap(:,:,1),'AlphaData',0.5)
title('Grad-CAM: ' + classes(1))
colormap jet
subplot(2,2,4)
imshow(img)
hold on
imagesc(gradCAMMap(:,:,2),'AlphaData',0.5)
title('Grad-CAM: ' + classes(2))
colormap jet

Карты Град-CAM и карта семантической сегментации показывают аналогичную подсветку. Ни одна из карт не различает дорогу слева от перекрестка, которую семантическая карта сегментации помечает как дорожное покрытие. Карта Grad-CAM для класса дорожного покрытия показывает, что ребро дорожного покрытия важнее центра для решения о классификации сети. Возможно, сеть неправильно классифицирует дорогу слева от перекрестка из-за плохой видимости ребра дорожного покрытия.

Исследуйте промежуточные слои

Карта Grad-CAM напоминает карту семантической сегментации, когда вы используете слой около конца сети для расчетов. Можно также использовать Grad-CAM для исследования промежуточных слоев в обученной сети. Более ранние слои имеют небольшой размер рецептивного поля и изучают небольшие низкоуровневые функции по сравнению со слоями в конце сети.

Вычислите карту Grad-CAM для слоев, которые последовательно глубже в сети.

layers = ["res5b_relu","catAspp","dec_relu1"];
numLayers = length(layers);

The res5b_relu слой находится вблизи середины сети, в то время как dec_relu1 находится около конца сети.

Исследуйте решения о классификации сетей для классов автомобилей, дорог и дорожного покрытия. Для каждого слоя и класса вычислите карту Grad-CAM.

classes = ["Car" "Road" "Pavement"];
numClasses = length(classes);

gradCAMMaps = [];
for i = 1:numLayers
    gradCAMMaps(:,:,:,i) = gradCAM(net,img,classes, ...
        'ReductionLayer',reductionLayer, ...
        'FeatureLayer',layers(i));
end

Отображение карт Grad-CAM для каждого слоя и каждого класса. Строки представляют карту для каждого слоя с слоями, упорядоченными от ранних в сети до слоев в конце сети.

figure;
idx = 1;
for i=1:numLayers
    for j=1:numClasses
        subplot(numLayers,numClasses,idx)
        imshow(img)
        hold on
        imagesc(gradCAMMaps(:,:,j,i),'AlphaData',0.5)
        title(sprintf("%s (%s)",classes(j),layers(i)), ...
            "Interpreter","none")
        colormap jet
        idx = idx + 1;
    end
end

Более поздние слои создают карты, очень похожие на карту сегментации. Однако более ранние слои в сети дают более абстрактные результаты и обычно более обеспокоены функциями более низкого уровня, такие как ребра, с меньшей осведомленностью о семантических классах. Например, в картах для более ранних слоев можно увидеть, что и для автомобильного, и для дорожного классов подсвечивается светофор. Это предполагает, что более ранние слои фокусируются на областях изображения, которые связаны с классом, но не обязательно принадлежат ему. Например, светофор, вероятно, появится рядом с дорогой, поэтому сеть может использовать эту информацию, чтобы предсказать, какие пиксели являются дорогами. Можно также увидеть, что для класса дорожного покрытия более ранние слои сильно фокусируются на ребре, что предполагает, что эта функция важна для сети при обнаружении пикселей, входящих в класс дорожного покрытия.

Ссылки

[1] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Pattern Recognition Letters 30, № 2 (январь 2009): 88-97. https://doi.org/10.1016/j.patrec.2008.04.005.

[2] Selvaraju, R. R., M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra. «Grad-CAM: визуальные объяснения из глубоких сетей через локализацию на основе градиентов». В IEEE International Conference on Компьютерное Зрение (ICCV), 2017, pp. 618-626. Доступно в Grad-CAM на веб-сайте Компьютерного зрения Foundation Open Access.

[3] Виноградова, Кира, Александр Дибров и Джин Майерс. «К интерпретации семантической сегментации через Отображение активации класса с учетом градиента (абстракция студента)». Материалы Конференции AAAI по искусственному интеллекту 34, № 10 (3 апреля 2020 года): 13943-44. https://doi.org/10.1609/aaai.v34i10.7244.

Вспомогательные функции

function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid data set has 32 classes. Group them into 11 classes following
% the original SegNet training methodology [1].
%
% The 11 classes are:
%   "Sky", "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol",
%   "Fence", "Car", "Pedestrian",  and "Bicyclist".
%
% CamVid pixel label IDs are provided as RGB color values. Group them into
% 11 classes and return them as a cell array of M-by-3 matrices. The
% original CamVid class names are listed alongside each RGB value. Note
% that the Other/Void class are excluded below.
labelIDs = { ...
    
    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]
    
    % "Building" 
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    ]
    
    % "Pole"
    [
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    ]
    
    % Road
    [
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    ]
    
    % "Pavement"
    [
    000 000 192; ... % "Sidewalk" 
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    ]
        
    % "Tree"
    [
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    ]
    
    % "SignSymbol"
    [
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    ]
    
    % "Fence"
    [
    064 064 128; ... % "Fence"
    ]
    
    % "Car"
    [
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    ]
    
    % "Pedestrian"
    [
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    ]
    
    % "Bicyclist"
    [
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]
    
    };
end
function pixelLabelColorbar(cmap, classNames)
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add a colorbar to the current figure.
c = colorbar('peer', gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick marks.
c.TickLength = 0;
end

function cmap = camvidColorMap
% Define the colormap used by the CamVid data set.

cmap = [
    128 128 128   % Sky
    128 0 0       % Building
    192 192 192   % Pole
    128 64 128    % Road
    60 40 222     % Pavement
    128 128 0     % Tree
    192 128 128   % SignSymbol
    64 64 128     % Fence
    64 0 128      % Car
    64 64 0       % Pedestrian
    0 128 192     % Bicyclist
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end