Этот пример показывает, как обучить семантическую сеть сегментации использование глубокого обучения.
Семантическая сеть сегментации классифицирует каждый пиксель на изображение, приводящее к изображению, которое сегментируется классом. Приложения для семантической сегментации включают дорожную сегментацию для автономного управления и сегментацию раковой клетки для медицинского диагноза. Чтобы узнать больше, смотрите Семантические Основы Сегментации (Computer Vision Toolbox).
Чтобы проиллюстрировать метод обучения, этот пример обучает Deeplab v3 + [1], один тип сверточной нейронной сети (CNN), разработанной для семантической сегментации изображений. Другие типы сетей для семантической сегментации включают полностью сверточные сети (FCN), SegNet и U-Net. Метод обучения, показанный здесь, может быть применен к тем сетям также.
Этот пример использует набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором изображений, содержащих представления уличного уровня, полученные при управлении. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов включая автомобиль, пешехода и дорогу.
Этот пример создает Deeplab v3 + сеть с весами, инициализированными от предварительно обученной сети Resnet-18. ResNet-18 является эффективной сетью, которая хорошо подходит для приложений с ограниченными ресурсами обработки. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-50 могут также использоваться в зависимости от требований к приложению. Для получения дополнительной информации смотрите Предварительно обученные Глубокие нейронные сети (Deep Learning Toolbox).
Чтобы получить предварительно обученный Resnet-18, установите Модель Deep Learning Toolbox™ для Сети Resnet-18. После того, как установка завершена, запустите следующий код, чтобы проверить, что установка правильна.
resnet18();
Кроме того, загрузите предварительно обученную версию DeepLab v3 +. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться.
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat'; pretrainedFolder = fullfile(tempdir,'pretrainedNetwork'); pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat'); if ~exist(pretrainedFolder,'dir') mkdir(pretrainedFolder); disp('Downloading pretrained network (58 MB)...'); websave(pretrainedNetwork,pretrainedURL); end
CUDA-способный графический процессор NVIDIA™ с вычисляет возможность 3.0, или выше настоятельно рекомендован для выполнения этого примера. Использование графического процессора требует Parallel Computing Toolbox™.
Загрузите набор данных CamVid со следующих URL.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; outputFolder = fullfile(tempdir,'CamVid'); if ~exist(outputFolder, 'dir') mkdir(outputFolder) labelsZip = fullfile(outputFolder,'labels.zip'); imagesZip = fullfile(outputFolder,'images.zip'); disp('Downloading 16 MB CamVid dataset labels...'); websave(labelsZip, labelURL); unzip(labelsZip, fullfile(outputFolder,'labels')); disp('Downloading 557 MB CamVid dataset images...'); websave(imagesZip, imageURL); unzip(imagesZip, fullfile(outputFolder,'images')); end
Примечание: время Загрузки данных зависит от вашего Интернет-соединения. Команды, используемые выше блока MATLAB до загрузки, завершены. Также можно использовать веб-браузер, чтобы сначала загрузить набор данных на локальный диск. Чтобы использовать файл, вы загрузили с сети, замените переменную outputFolder
выше к местоположению загруженного файла.
Используйте imageDatastore
, чтобы загрузить изображения CamVid. imageDatastore
позволяет вам эффективно загрузить большое количество изображений на диске.
imgDir = fullfile(outputFolder,'images','701_StillsRaw_full'); imds = imageDatastore(imgDir);
Отобразите одно из изображений.
I = readimage(imds,1); I = histeq(I); imshow(I)
Используйте pixelLabelDatastore
, чтобы загрузить пиксельные данные изображения метки CamVid. pixelLabelDatastore
инкапсулирует пиксельные данные о метке и метку ID к отображению имени класса.
Мы делаем обучение легче, мы группируем 32 исходных класса в CamVid к 11 классам. Задайте эти классы.
classes = [ "Sky" "Building" "Pole" "Road" "Pavement" "Tree" "SignSymbol" "Fence" "Car" "Pedestrian" "Bicyclist" ];
Чтобы уменьшать 32 класса в 11, несколько классов от исходного набора данных группируются. Например, "Автомобиль" является комбинацией "Автомобиля", "SUVPickupTruck", "Truck_Bus", "Train" и "OtherMoving". Возвратите сгруппированную метку IDs при помощи функции поддержки camvidPixelLabelIDs
, который перечислен в конце этого примера.
labelIDs = camvidPixelLabelIDs();
Используйте классы и метку IDs, чтобы создать pixelLabelDatastore.
labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Считайте и отобразите одно из маркированных пикселем изображений путем накладывания его сверху изображения.
C = readimage(pxds,1);
cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);
Области без перекрытия цвета не имеют пиксельных меток и не используются во время обучения.
Чтобы видеть распределение меток класса в наборе данных CamVid, используйте countEachLabel
. Эта функция считает количество пикселей меткой класса.
tbl = countEachLabel(pxds)
tbl=11×3 table
Name PixelCount ImagePixelCount
____________ __________ _______________
'Sky' 7.6801e+07 4.8315e+08
'Building' 1.1737e+08 4.8315e+08
'Pole' 4.7987e+06 4.8315e+08
'Road' 1.4054e+08 4.8453e+08
'Pavement' 3.3614e+07 4.7209e+08
'Tree' 5.4259e+07 4.479e+08
'SignSymbol' 5.2242e+06 4.6863e+08
'Fence' 6.9211e+06 2.516e+08
'Car' 2.4437e+07 4.8315e+08
'Pedestrian' 3.4029e+06 4.4444e+08
'Bicyclist' 2.5912e+06 2.6196e+08
Визуализируйте пиксельные количества классом.
frequency = tbl.PixelCount/sum(tbl.PixelCount);
bar(1:numel(classes),frequency)
xticks(1:numel(classes))
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')
Идеально, все классы имели бы равное количество наблюдений. Однако классы в CamVid являются неустойчивыми, который является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше неба, создания и дорожных пикселей, чем пиксели пешехода и велосипедиста, потому что небо, создания и дороги покрывают больше области в изображении. Если не обработанный правильно, эта неустойчивость может быть вредна для процесса обучения, потому что изучение смещается в пользу доминирующих классов. Позже в этом примере, вы будете использовать взвешивание класса, чтобы обработать эту проблему.
Изображения в наборе данных CamVid 720 960 в размере. Размер изображения выбран таким образом, что достаточно большой пакет изображений может уместиться в памяти во время обучения на Титане NVIDIA™ X с 12 Гбайт памяти. Вы, возможно, должны изменить размер изображений к меньшим размерам, если ваш графический процессор не имеет достаточной памяти или уменьшает учебный пакетный размер.
Deeplab v3 + обучен с помощью 60% изображений от набора данных. Остальная часть изображений разделена равномерно в 20% и 20% для валидации и тестирующий соответственно. Следующий код случайным образом разделяет изображение и пиксельные данные о метке в обучение, валидацию и набор тестов.
[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds);
60/20/20 разделяют результаты в следующем количестве обучения, валидации и тестируют изображения:
numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numValImages = numel(imdsVal.Files)
numValImages = 140
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
Используйте функцию helperDeeplabv3PlusResnet18
, которая присоединена к этому примеру как к вспомогательному файлу, чтобы создать DeepLab v3 + сеть на основе ResNet-18. Выбор лучшей сети для вашего приложения требует эмпирического анализа и является другим уровнем настройки гиперпараметра. Например, можно экспериментировать с различными основными сетями, такими как ResNet-50 или Начало v3, или можно попробовать другую семантическую архитектуру сети сегментации, такую как SegNet, полностью сверточные сети (FCN) или U-Net.
% Specify the network image size. This is typically the same as the traing image sizes. imageSize = [720 960 3]; % Specify the number of classes. numClasses = numel(classes); % Create DeepLab v3+. lgraph = helperDeeplabv3PlusResnet18(imageSize, numClasses);
Как показано ранее классы в CamVid не сбалансированы. Чтобы улучшить обучение, можно использовать взвешивание класса, чтобы сбалансировать классы. Используйте пиксельные количества метки, вычисленные ранее с countEachLabel
, и вычислите средние веса класса частоты.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount; classWeights = median(imageFreq) ./ imageFreq
classWeights = 11×1
0.3182
0.2082
5.0924
0.1744
0.7103
0.4175
4.5371
1.8386
1.0000
6.6059
⋮
Задайте веса класса с помощью pixelClassificationLayer
.
pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights); lgraph = replaceLayer(lgraph,"classification",pxLayer);
Алгоритм оптимизации, используемый для обучения, является стохастическим спуском градиента с импульсом (SGDM). Используйте trainingOptions
, чтобы задать гиперпараметры, используемые для SGDM.
% Define validation data. pximdsVal = pixelLabelImageDatastore(imdsVal,pxdsVal); % Define training options. options = trainingOptions('sgdm', ... 'LearnRateSchedule','piecewise',... 'LearnRateDropPeriod',10,... 'LearnRateDropFactor',0.3,... 'Momentum',0.9, ... 'InitialLearnRate',1e-3, ... 'L2Regularization',0.005, ... 'ValidationData',pximdsVal,... 'MaxEpochs',30, ... 'MiniBatchSize',8, ... 'Shuffle','every-epoch', ... 'CheckpointPath', tempdir, ... 'VerboseFrequency',2,... 'Plots','training-progress',... 'ValidationPatience', 4); ...
Темп обучения использует кусочное расписание. Темп обучения уменьшается фактором 0,3 каждых 10 эпох. Это позволяет сети учиться быстро с более высоким начальным темпом обучения, в то время как способность найти решение близко к локальному оптимуму однажды темп обучения понижается.
Сеть тестируется против данных о валидации каждая эпоха путем установки параметра 'ValidationData'
. 'ValidationPatience'
собирается в 4 остановить обучение рано, когда точность валидации сходится. Это препятствует тому, чтобы сеть сверхсоответствовала на обучающем наборе данных.
Мини-пакетный размер 8 используется, чтобы уменьшать использование памяти в то время как обучение. Можно увеличить или уменьшить это значение на основе суммы памяти графического процессора, которую вы имеете в своей системе.
Кроме того, 'CheckpointPath'
установлен во временное местоположение. Эта пара "имя-значение" включает сохранение сетевых контрольных точек в конце каждой учебной эпохи. Если обучение прервано из-за системного отказа или отключения электроэнергии, можно возобновить обучение от сохраненной контрольной точки. Убедитесь, что местоположение, заданное 'CheckpointPath'
, имеет достаточно пробела, чтобы сохранить сетевые контрольные точки. Например, сохраняя 100 Deeplab v3 + контрольные точки требуют ~6 Гбайт дискового пространства, потому что каждая контрольная точка составляет 61 Мбайт.
Увеличение данных используется во время обучения предоставить больше примеров сети, потому что это помогает улучшить точность сети. Здесь, случайный слева/справа отражение и случайный перевод X/Y +/-10 пикселей используются для увеличения данных. Используйте imageDataAugmenter
, чтобы задать эти параметры увеличения данных.
augmenter = imageDataAugmenter('RandXReflection',true,... 'RandXTranslation',[-10 10],'RandYTranslation',[-10 10]);
imageDataAugmenter
поддерживает несколько других типов увеличения данных. Выбор среди них требует эмпирического анализа и является другим уровнем настройки гиперпараметра.
Объедините данные тренировки и выборы увеличения данных с помощью pixelLabelImageDatastore
. Пакеты чтений pixelLabelImageDatastore
данных тренировки, применяет увеличение данных и отправляет увеличенные данные в учебный алгоритм.
pximds = pixelLabelImageDatastore(imdsTrain,pxdsTrain, ... 'DataAugmentation',augmenter);
Запустите обучение с помощью trainNetwork
, если флаг doTraining
верен. В противном случае загрузите предварительно обученную сеть.
Примечание: обучение было проверено на Титане NVIDIA™ X с 12 Гбайт памяти графического процессора. Если ваш графический процессор имеет меньше памяти, у можно закончиться память. Если это происходит, попробуйте установку 'MiniBatchSize'
к 1 в trainingOptions
. Обучение эта сеть занимает приблизительно 5 часов. В зависимости от вашего оборудования графического процессора это может взять еще дольше.
doTraining = false; if doTraining [net, info] = trainNetwork(pximds,lgraph,options); else data = load(pretrainedNetwork); net = data.net; end
Как быстрая проверка работоспособности, запустите обучивший сеть на одном тестовом изображении.
I = readimage(imdsTest,35); C = semanticseg(I, net);
Отобразите результаты.
B = labeloverlay(I,C,'Colormap',cmap,'Transparency',0.4); imshow(B) pixelLabelColorbar(cmap, classes);
Сравните результаты в C
с ожидаемой наземной истиной, сохраненной в pxdsTest
. Зеленые и пурпурные области подсвечивают области, где результаты сегментации отличаются от ожидаемой наземной истины.
expectedResult = readimage(pxdsTest,35); actual = uint8(C); expected = uint8(expectedResult); imshowpair(actual, expected)
Визуально, семантические результаты сегментации накладываются хорошо для классов, таких как дорога, небо и создание. Однако меньшие объекты как пешеходы и автомобили не так точны. Сумма перекрытия в классе может быть измерена с помощью метрики пересечения по объединению (IoU), также известной как индекс Jaccard. Используйте функцию jaccard
, чтобы измерить IoU.
iou = jaccard(C,expectedResult); table(classes,iou)
ans=11×2 table
classes iou
____________ _______
"Sky" 0.91837
"Building" 0.84479
"Pole" 0.31203
"Road" 0.93698
"Pavement" 0.82838
"Tree" 0.89636
"SignSymbol" 0.57644
"Fence" 0.71046
"Car" 0.66688
"Pedestrian" 0.48417
"Bicyclist" 0.68431
Метрика IoU подтверждает визуальные результаты. Дорога, небо и классы создания имеют высокие очки IoU, в то время как классы, такие как пешеход и автомобиль имеют низкие баллы. Другие общие метрики сегментации включают dice
и счет соответствия контура bfscore
.
Измерять точность для нескольких тестовых изображений, runsemanticseg
на целом наборе тестов. Мини-пакетный размер 4 используется, чтобы уменьшать использование памяти при сегментации изображений. Можно увеличить или уменьшить это значение на основе суммы памяти графического процессора, которую вы имеете в своей системе.
pxdsResults = semanticseg(imdsTest,net, ... 'MiniBatchSize',4, ... 'WriteLocation',tempdir, ... 'Verbose',false);
semanticseg
возвращает результаты для набора тестов как объект pixelLabelDatastore
. Данные о метке фактического пикселя для каждого тестового изображения в imdsTest
записаны в диск в месте, заданном параметром 'WriteLocation'
. Используйте evaluateSemanticSegmentation
, чтобы измерить семантические метрики сегментации на результатах набора тестов.
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);
evaluateSemanticSegmentation
возвращает различные метрики для целого набора данных для отдельных классов, и для каждого тестового изображения. Чтобы видеть метрики уровня набора данных, осмотрите metrics.DataSetMetrics
.
metrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.87695 0.85392 0.6302 0.80851 0.65051
Метрики набора данных предоставляют общий обзор производительности сети. Чтобы видеть влияние, каждый класс имеет на общей производительности, осмотрите метрики на класс с помощью metrics.ClassMetrics
.
metrics.ClassMetrics
ans=11×3 table
Accuracy IoU MeanBFScore
________ _______ ___________
Sky 0.93111 0.90209 0.8952
Building 0.78453 0.76098 0.58511
Pole 0.71586 0.21477 0.5144
Road 0.93024 0.91465 0.76696
Pavement 0.88466 0.70571 0.70919
Tree 0.87377 0.76323 0.70875
SignSymbol 0.79358 0.39309 0.48302
Fence 0.81507 0.46484 0.48565
Car 0.90956 0.76799 0.69233
Pedestrian 0.87629 0.4366 0.60792
Bicyclist 0.87844 0.60829 0.55089
Несмотря на то, что полная производительность набора данных довольно высока, метрики класса показывают, что недостаточно представленные классы, такие как Pedestrian
, Bicyclist
и Car
не сегментируются, а также классы, такие как Road
, Sky
и Building
. Дополнительные данные, которые включают больше выборок недостаточно представленных классов, могут помочь улучшить результаты.
helperDeeplabv3PlusResnet18.m
присоединен к этому примеру как к вспомогательному файлу.
% lgraph = helperDeeplabv3PlusResnet18(imageSize, numClasses) creates a % DeepLab v3+ layer graph object using a pre-trained ResNet-18 configured % using the following inputs: % % Inputs % ------ % imageSize - size of the network input image specified as a vector % [H W] or [H W C], where H and W are the image height and % width, and C is the number of image channels. % % numClasses - number of classes the network should be configured to % classify. % % The output lgraph is a LayerGraph object.
function labelIDs = camvidPixelLabelIDs() % Return the label IDs corresponding to each class. % % The CamVid dataset has 32 classes. Group them into 11 classes following % the original SegNet training methodology [1]. % % The 11 classes are: % "Sky" "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol", % "Fence", "Car", "Pedestrian", and "Bicyclist". % % CamVid pixel label IDs are provided as RGB color values. Group them into % 11 classes and return them as a cell array of M-by-3 matrices. The % original CamVid class names are listed alongside each RGB value. Note % that the Other/Void class are excluded below. labelIDs = { ... % "Sky" [ 128 128 128; ... % "Sky" ] % "Building" [ 000 128 064; ... % "Bridge" 128 000 000; ... % "Building" 064 192 000; ... % "Wall" 064 000 064; ... % "Tunnel" 192 000 128; ... % "Archway" ] % "Pole" [ 192 192 128; ... % "Column_Pole" 000 000 064; ... % "TrafficCone" ] % Road [ 128 064 128; ... % "Road" 128 000 192; ... % "LaneMkgsDriv" 192 000 064; ... % "LaneMkgsNonDriv" ] % "Pavement" [ 000 000 192; ... % "Sidewalk" 064 192 128; ... % "ParkingBlock" 128 128 192; ... % "RoadShoulder" ] % "Tree" [ 128 128 000; ... % "Tree" 192 192 000; ... % "VegetationMisc" ] % "SignSymbol" [ 192 128 128; ... % "SignSymbol" 128 128 064; ... % "Misc_Text" 000 064 064; ... % "TrafficLight" ] % "Fence" [ 064 064 128; ... % "Fence" ] % "Car" [ 064 000 128; ... % "Car" 064 128 192; ... % "SUVPickupTruck" 192 128 192; ... % "Truck_Bus" 192 064 128; ... % "Train" 128 064 064; ... % "OtherMoving" ] % "Pedestrian" [ 064 064 000; ... % "Pedestrian" 192 128 064; ... % "Child" 064 000 192; ... % "CartLuggagePram" 064 128 064; ... % "Animal" ] % "Bicyclist" [ 000 128 192; ... % "Bicyclist" 192 000 192; ... % "MotorcycleScooter" ] }; end
function pixelLabelColorbar(cmap, classNames) % Add a colorbar to the current axis. The colorbar is formatted % to display the class names with the color. colormap(gca,cmap) % Add colorbar to current figure. c = colorbar('peer', gca); % Use class names for tick marks. c.TickLabels = classNames; numClasses = size(cmap,1); % Center tick labels. c.Ticks = 1/(numClasses*2):1/numClasses:1; % Remove tick mark. c.TickLength = 0; end function cmap = camvidColorMap() % Define the colormap used by CamVid dataset. cmap = [ 128 128 128 % Sky 128 0 0 % Building 192 192 192 % Pole 128 64 128 % Road 60 40 222 % Pavement 128 128 0 % Tree 192 128 128 % SignSymbol 64 64 128 % Fence 64 0 128 % Car 64 64 0 % Pedestrian 0 128 192 % Bicyclist ]; % Normalize between [0 1]. cmap = cmap ./ 255; end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds) % Partition CamVid data by randomly selecting 60% of the data for training. The % rest is used for testing. % Set initial random state for example reproducibility. rng(0); numFiles = numel(imds.Files); shuffledIndices = randperm(numFiles); % Use 60% of the images for training. numTrain = round(0.60 * numFiles); trainingIdx = shuffledIndices(1:numTrain); % Use 20% of the images for validation numVal = round(0.20 * numFiles); valIdx = shuffledIndices(numTrain+1:numTrain+numVal); % Use the rest for testing. testIdx = shuffledIndices(numTrain+numVal+1:end); % Create image datastores for training and test. trainingImages = imds.Files(trainingIdx); valImages = imds.Files(valIdx); testImages = imds.Files(testIdx); imdsTrain = imageDatastore(trainingImages); imdsVal = imageDatastore(valImages); imdsTest = imageDatastore(testImages); % Extract class and label IDs info. classes = pxds.ClassNames; labelIDs = camvidPixelLabelIDs(); % Create pixel label datastores for training and test. trainingLabels = pxds.Files(trainingIdx); valLabels = pxds.Files(valIdx); testLabels = pxds.Files(testIdx); pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs); pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs); pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs); end
[1] Чен, Лян-Чие и др. “Декодер энкодера с Отделимой Сверткой Atrous для Семантической Сегментации Изображений”. ECCV (2018).
[2] Brostow, G. J. Ж. Фокер и Р. Сиполла. "Семантические классы объектов в видео: наземная база данных истины высокой четкости". Буквы Распознавания образов. Издание 30, Выпуск 2, 2009, стр 88-97.
countEachLabel
| evaluateSemanticSegmentation
| imageDataAugmenter
| labeloverlay
| pixelClassificationLayer
| pixelLabelDatastore
| pixelLabelImageDatastore
| segnetLayers
| semanticseg
| trainNetwork
| trainingOptions