В этом примере показано, как обучить сеть семантической сегментации использование глубокого обучения.
Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Приложения для семантической сегментации включают сегментацию дорог для автономного управления автомобилем и сегментацию раковой клетки для медицинского диагностирования. Чтобы узнать больше, смотрите Начало работы с Семантической Сегментацией Используя Глубокое обучение (Computer Vision Toolbox).
Чтобы проиллюстрировать метод обучения, этот пример обучает Deeplab v3 + [1], один тип сверточной нейронной сети (CNN), спроектированной для семантической сегментации изображений. Другие типы сетей для семантической сегментации включают полностью сверточные сети (FCN), SegNet и U-Net. Метод обучения, показанный здесь, может быть применен к тем сетям также.
Этот пример использует набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором изображений, содержащих представления уличного уровня, полученные при управлении. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов включая автомобиль, пешехода и дорогу.
Этот пример создает Deeplab v3 + сеть с весами, инициализированными от предварительно обученной сети Resnet-18. ResNet-18 является эффективной сетью, которая хорошо подходит для приложений с ограниченными ресурсами обработки. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-50 могут также использоваться в зависимости от требований к приложению. Для получения дополнительной информации смотрите Предварительно обученные Глубокие нейронные сети.
Чтобы получить предварительно обученный Resnet-18, установите resnet18
. После того, как установка завершена, запустите следующий код, чтобы проверить, что установка правильна.
resnet18();
Кроме того, загрузите предварительно обученную версию DeepLab v3 +. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться.
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat'; pretrainedFolder = fullfile(tempdir,'pretrainedNetwork'); pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat'); if ~exist(pretrainedNetwork,'file') mkdir(pretrainedFolder); disp('Downloading pretrained network (58 MB)...'); websave(pretrainedNetwork,pretrainedURL); end
CUDA-способный графический процессор NVIDIA™ настоятельно рекомендован для выполнения этого примера. Для использования GPU требуется Parallel Computing Toolbox. Для получения информации о поддерживаемом вычислите возможности, смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
Загрузите набор данных CamVid со следующих URL.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; outputFolder = fullfile(tempdir,'CamVid'); labelsZip = fullfile(outputFolder,'labels.zip'); imagesZip = fullfile(outputFolder,'images.zip'); if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file') mkdir(outputFolder) disp('Downloading 16 MB CamVid dataset labels...'); websave(labelsZip, labelURL); unzip(labelsZip, fullfile(outputFolder,'labels')); disp('Downloading 557 MB CamVid dataset images...'); websave(imagesZip, imageURL); unzip(imagesZip, fullfile(outputFolder,'images')); end
Примечание: Время загрузки данных зависит от вашего Интернет-соединения. Команды, используемые выше, блокируют MATLAB до завершения загрузки. В качестве альтернативы можно использовать веб-браузер, чтобы сначала загрузить набор данных на локальный диск. Чтобы использовать файл, вы загрузили с сети, измените outputFolder
переменная выше к местоположению загруженного файла.
Используйте imageDatastore
загружать изображения CamVid. imageDatastore
позволяет вам эффективно загрузить большое количество изображений на диске.
imgDir = fullfile(outputFolder,'images','701_StillsRaw_full'); imds = imageDatastore(imgDir);
Отобразите одно из изображений.
I = readimage(imds,559); I = histeq(I); imshow(I)
Используйте pixelLabelDatastore
(Computer Vision Toolbox), чтобы загрузить пиксель CamVid помечает данные изображения. pixelLabelDatastore
инкапсулирует данные о пиксельных метках и метку ID к отображению имени класса.
Мы делаем обучение легче, мы группируем 32 исходных класса в CamVid к 11 классам. Задайте эти классы.
classes = [ "Sky" "Building" "Pole" "Road" "Pavement" "Tree" "SignSymbol" "Fence" "Car" "Pedestrian" "Bicyclist" ];
Чтобы уменьшать 32 класса в 11, несколько классов от исходного набора данных группируются. Например, "Автомобиль" является комбинацией "Автомобиля", "SUVPickupTruck", "Truck_Bus", "Обучаются", и "OtherMoving". Возвратите сгруппированную метку IDs при помощи функции поддержки camvidPixelLabelIDs
, который перечислен в конце этого примера.
labelIDs = camvidPixelLabelIDs();
Используйте идентификаторы классов и меток, чтобы создать pixelLabelDatastore.
labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Считайте и отобразите одно из помеченных пикселем изображений путем накладывания его сверху изображения.
C = readimage(pxds,559);
cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);
Области без перекрытия цвета не имеют пиксельных меток и не используются во время обучения.
Чтобы видеть распределение меток класса в наборе данных CamVid, используйте countEachLabel
(Computer Vision Toolbox). Эта функция считает количество пикселей меткой класса.
tbl = countEachLabel(pxds)
tbl=11×3 table
Name PixelCount ImagePixelCount
______________ __________ _______________
{'Sky' } 7.6801e+07 4.8315e+08
{'Building' } 1.1737e+08 4.8315e+08
{'Pole' } 4.7987e+06 4.8315e+08
{'Road' } 1.4054e+08 4.8453e+08
{'Pavement' } 3.3614e+07 4.7209e+08
{'Tree' } 5.4259e+07 4.479e+08
{'SignSymbol'} 5.2242e+06 4.6863e+08
{'Fence' } 6.9211e+06 2.516e+08
{'Car' } 2.4437e+07 4.8315e+08
{'Pedestrian'} 3.4029e+06 4.4444e+08
{'Bicyclist' } 2.5912e+06 2.6196e+08
Визуализируйте счетчики пикселей по классам.
frequency = tbl.PixelCount/sum(tbl.PixelCount);
bar(1:numel(classes),frequency)
xticks(1:numel(classes))
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')
Идеально, все классы имели бы равное количество наблюдений. Однако классы в CamVid являются неустойчивыми, который является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше неба, создания и дорожных пикселей, чем пиксели пешехода и велосипедиста, потому что небо, создания и дороги покрывают больше области в изображении. Если не обработанный правильно, эта неустойчивость может быть вредна для процесса обучения, потому что изучение смещается в пользу доминирующих классов. Позже в этом примере, вы будете использовать взвешивание класса, чтобы обработать эту проблему.
Изображения в наборе данных CamVid имеют размер 720 на 960. Размер изображения выбран таким образом, что достаточно большой пакет изображений может уместиться в памяти во время обучения на Титане NVIDIA™ X с 12 Гбайт памяти. Вы, возможно, должны будете изменить размер изображений на меньшие величины, если ваш графический процессор не имеет достаточной памяти, либо уменьшить размер обучающего пакета.
Deeplab v3 + обучен с помощью 60% изображений от набора данных. Остальная часть изображений разделена равномерно на части в 20% и 20% для валидации и тестирования соответственно. Следующий код случайным образом разделяет изображение и данные о пиксельных метках на наборы для обучения, валидации и тестирования.
[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds);
Соотношение 60/20/20 разделяет результаты на следующие количества изображений для обучения, валидации и тестирования:
numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numValImages = numel(imdsVal.Files)
numValImages = 140
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
Используйте deeplabv3plusLayers
функция, чтобы создать DeepLab v3 + сеть на основе ResNet-18. Выбор лучшей сети для вашего приложения требует эмпирического анализа и является другим уровнем настройки гиперпараметра. Например, можно экспериментировать с различными основными сетями, такими как ResNet-50 или MobileNet v2, или можно попробовать другие архитектуры сети семантической сегментации, такие как SegNet, полностью сверточные сети (FCN) или U-Net.
% Specify the network image size. This is typically the same as the traing image sizes. imageSize = [720 960 3]; % Specify the number of classes. numClasses = numel(classes); % Create DeepLab v3+. lgraph = deeplabv3plusLayers(imageSize, numClasses, "resnet18");
Как показано ранее классы в CamVid не сбалансированы. Чтобы улучшить обучение, можно использовать взвешивание класса, чтобы сбалансировать классы. Используйте пиксельные количества метки, вычисленные ранее с countEachLabel
(Computer Vision Toolbox) и вычисляет веса класса медианной частоты.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount; classWeights = median(imageFreq) ./ imageFreq
classWeights = 11×1
0.3182
0.2082
5.0924
0.1744
0.7103
0.4175
4.5371
1.8386
1.0000
6.6059
⋮
Задайте веса класса с помощью pixelClassificationLayer
(Computer Vision Toolbox).
pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights); lgraph = replaceLayer(lgraph,"classification",pxLayer);
Алгоритм оптимизации, используемый для обучения, является стохастическим градиентным спуском с импульсом (SGDM). Используйте trainingOptions
задавать гиперпараметры, используемые для SGDM.
% Define validation data. dsVal = combine(imdsVal,pxdsVal); % Define training options. options = trainingOptions('sgdm', ... 'LearnRateSchedule','piecewise',... 'LearnRateDropPeriod',10,... 'LearnRateDropFactor',0.3,... 'Momentum',0.9, ... 'InitialLearnRate',1e-3, ... 'L2Regularization',0.005, ... 'ValidationData',dsVal,... 'MaxEpochs',30, ... 'MiniBatchSize',8, ... 'Shuffle','every-epoch', ... 'CheckpointPath', tempdir, ... 'VerboseFrequency',2,... 'Plots','training-progress',... 'ValidationPatience', 4);
Скорость обучения использует кусочное расписание. Темп обучения уменьшается на множитель 0,3 каждые 10 эпох. Это позволяет сети обучаться быстро с более высоким начальным темпом обучения, и в то же время сохраняется способность найти решение, близкое к локальному оптимуму, когда темп обучения понижается.
Сеть тестируется против данных о валидации каждая эпоха путем установки 'ValidationData'
параметр. 'ValidationPatience'
собирается в 4 остановить обучение рано, когда точность валидации сходится. Это препятствует избыточному обучению сети на обучающем наборе данных.
Мини-пакет размером 8 используется для уменьшения использования памяти во время обучения. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти GPU, имеющейся в вашей системе.
Кроме того, 'CheckpointPath'
установлен во временное местоположение. Эта пара "имя-значение" включает сохранение сетевых контрольных точек в конце каждой учебной эпохи. Если обучение прервано из-за системного отказа или отключения электроэнергии, можно возобновить обучение с сохраненной контрольной точки. Убедитесь что местоположение, заданное 'CheckpointPath'
имеет достаточно пробела, чтобы сохранить сетевые контрольные точки. Например, сохраняя 100 Deeplab v3 + контрольные точки требуют ~6 Гбайт дискового пространства, потому что каждая контрольная точка составляет 61 Мбайт.
Увеличение данных используется, чтобы улучшить сетевую точность путем случайного преобразования исходных данных во время обучения. При помощи увеличения данных можно добавить больше разнообразия в обучающие данные, не увеличивая число помеченных обучающих выборок. Чтобы применить то же случайное преобразование, чтобы и отобразить и данные о пиксельных метках используют datastore combine
и transform
. Во-первых, объедините imdsTrain
и pxdsTrain
.
dsTrain = combine(imdsTrain, pxdsTrain);
Затем используйте datastore transform
чтобы применить желаемое увеличение данных, заданное в поддержке, функционируют augmentImageAndLabel
Здесь для увеличения данных используются случайное отражение "слева/справа" и случайный преобразование X/Y на +/-10 пикселей.
xTrans = [-10 10]; yTrans = [-10 10]; dsTrain = transform(dsTrain, @(data)augmentImageAndLabel(data,xTrans,yTrans));
Обратите внимание на то, что увеличение данных не применяется к данным о валидации и тесту. Идеально, тест и данные о валидации должны быть представительными для исходных данных и оставлены немодифицированными для несмещенной оценки.
Запустите обучение с помощью trainNetwork
если doTraining
флаг верен. В противном случае загружает предварительно обученную сеть.
Примечание: обучение было проверено на Титане NVIDIA™ X с 12 Гбайт памяти графического процессора. Если ваш графический процессор имеет меньше памяти, у можно закончиться память во время обучения. Если это происходит, попробуйте установку 'MiniBatchSize'
к 1 в trainingOptions
, или сокращение сетевого входа и изменение размеров обучающих данных. Обучение эта сеть занимает приблизительно 5 часов. В зависимости от вашего оборудования графического процессора это может взять еще дольше.
doTraining = false; if doTraining [net, info] = trainNetwork(dsTrain,lgraph,options); else data = load(pretrainedNetwork); net = data.net; end
Как быстрая проверка работоспособности, запустите обучивший сеть на одном тестовом изображении.
I = readimage(imdsTest,35); C = semanticseg(I, net);
Отобразите результаты.
B = labeloverlay(I,C,'Colormap',cmap,'Transparency',0.4); imshow(B) pixelLabelColorbar(cmap, classes);
Сравните результаты в C
с ожидаемой основной истиной, сохраненной в pxdsTest
. Зеленые и пурпурные области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.
expectedResult = readimage(pxdsTest,35); actual = uint8(C); expected = uint8(expectedResult); imshowpair(actual, expected)
Визуально, результаты семантической сегментации перекрываются хорошо для классов, таких как дорога, небо и здание. Однако меньшие объекты как пешеходы и автомобили не так точны. Сумма перекрытия для класса может быть измерена с помощью метрики пересечения по объединению (IoU), также известной как индекс Jaccard. Используйте jaccard
(Image Processing Toolbox) функция, чтобы измерить IoU.
iou = jaccard(C,expectedResult); table(classes,iou)
ans=11×2 table
classes iou
____________ _______
"Sky" 0.91837
"Building" 0.84479
"Pole" 0.31203
"Road" 0.93698
"Pavement" 0.82838
"Tree" 0.89636
"SignSymbol" 0.57644
"Fence" 0.71046
"Car" 0.66688
"Pedestrian" 0.48417
"Bicyclist" 0.68431
Метрика IoU подтверждает визуальные результаты. Классы "Дорога", "Небо" и "Здания" имеют высокие очки IoU, в то время как такие классы, как "Пешеход" и "Автомобиль" имеют низкие баллы. Другие общие метрики сегментации включают dice
(Image Processing Toolbox) и bfscore
(Image Processing Toolbox) счет соответствия контура.
Измерять точность для нескольких тестовых изображений, runsemanticseg
(Computer Vision Toolbox) на целом наборе тестов. Размер мини-пакета, равный 4, используется, чтобы уменьшать использование памяти при сегментации изображений. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти GPU, имеющейся в вашей системе.
pxdsResults = semanticseg(imdsTest,net, ... 'MiniBatchSize',4, ... 'WriteLocation',tempdir, ... 'Verbose',false);
semanticseg
возвращает результаты для набора тестов как pixelLabelDatastore
объект. Данные о метке фактического пикселя для каждого теста отображают в imdsTest
записан в диск в месте, заданном 'WriteLocation'
параметр. Используйте evaluateSemanticSegmentation
(Computer Vision Toolbox), чтобы измерить метрики семантической сегментации на результатах набора тестов.
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);
evaluateSemanticSegmentation
возвращает различные метрики для набора данных в целом, для отдельных классов, и для каждого тестового изображения. Чтобы видеть метрики уровня набора данных, смотрите metrics.DataSetMetrics
.
metrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.87695 0.85392 0.6302 0.80851 0.65051
Метрики набора данных предоставляют общий обзор производительности сети. Чтобы увидеть влияние каждого класса на общую производительности, смотрите метрики по классам с помощью metrics.ClassMetrics
.
metrics.ClassMetrics
ans=11×3 table
Accuracy IoU MeanBFScore
________ _______ ___________
Sky 0.93112 0.90209 0.8952
Building 0.78453 0.76098 0.58511
Pole 0.71586 0.21477 0.51439
Road 0.93024 0.91465 0.76696
Pavement 0.88466 0.70571 0.70919
Tree 0.87377 0.76323 0.70875
SignSymbol 0.79358 0.39309 0.48302
Fence 0.81507 0.46484 0.48566
Car 0.90956 0.76799 0.69233
Pedestrian 0.87629 0.4366 0.60792
Bicyclist 0.87844 0.60829 0.55089
Несмотря на то, что полная эффективность набора данных довольно высока, метрики класса показывают что недостаточно представленные классы, такие как Pedestrian
, Bicyclist
, и Car
не сегментируются, а также классы, такие как Road
, Sky
, и Building
. Дополнительные данные, которые включают больше выборок недостаточно представленных классов, могут помочь улучшить результаты.
function labelIDs = camvidPixelLabelIDs() % Return the label IDs corresponding to each class. % % The CamVid dataset has 32 classes. Group them into 11 classes following % the original SegNet training methodology [1]. % % The 11 classes are: % "Sky" "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol", % "Fence", "Car", "Pedestrian", and "Bicyclist". % % CamVid pixel label IDs are provided as RGB color values. Group them into % 11 classes and return them as a cell array of M-by-3 matrices. The % original CamVid class names are listed alongside each RGB value. Note % that the Other/Void class are excluded below. labelIDs = { ... % "Sky" [ 128 128 128; ... % "Sky" ] % "Building" [ 000 128 064; ... % "Bridge" 128 000 000; ... % "Building" 064 192 000; ... % "Wall" 064 000 064; ... % "Tunnel" 192 000 128; ... % "Archway" ] % "Pole" [ 192 192 128; ... % "Column_Pole" 000 000 064; ... % "TrafficCone" ] % Road [ 128 064 128; ... % "Road" 128 000 192; ... % "LaneMkgsDriv" 192 000 064; ... % "LaneMkgsNonDriv" ] % "Pavement" [ 000 000 192; ... % "Sidewalk" 064 192 128; ... % "ParkingBlock" 128 128 192; ... % "RoadShoulder" ] % "Tree" [ 128 128 000; ... % "Tree" 192 192 000; ... % "VegetationMisc" ] % "SignSymbol" [ 192 128 128; ... % "SignSymbol" 128 128 064; ... % "Misc_Text" 000 064 064; ... % "TrafficLight" ] % "Fence" [ 064 064 128; ... % "Fence" ] % "Car" [ 064 000 128; ... % "Car" 064 128 192; ... % "SUVPickupTruck" 192 128 192; ... % "Truck_Bus" 192 064 128; ... % "Train" 128 064 064; ... % "OtherMoving" ] % "Pedestrian" [ 064 064 000; ... % "Pedestrian" 192 128 064; ... % "Child" 064 000 192; ... % "CartLuggagePram" 064 128 064; ... % "Animal" ] % "Bicyclist" [ 000 128 192; ... % "Bicyclist" 192 000 192; ... % "MotorcycleScooter" ] }; end
function pixelLabelColorbar(cmap, classNames) % Add a colorbar to the current axis. The colorbar is formatted % to display the class names with the color. colormap(gca,cmap) % Add colorbar to current figure. c = colorbar('peer', gca); % Use class names for tick marks. c.TickLabels = classNames; numClasses = size(cmap,1); % Center tick labels. c.Ticks = 1/(numClasses*2):1/numClasses:1; % Remove tick mark. c.TickLength = 0; end
function cmap = camvidColorMap() % Define the colormap used by CamVid dataset. cmap = [ 128 128 128 % Sky 128 0 0 % Building 192 192 192 % Pole 128 64 128 % Road 60 40 222 % Pavement 128 128 0 % Tree 192 128 128 % SignSymbol 64 64 128 % Fence 64 0 128 % Car 64 64 0 % Pedestrian 0 128 192 % Bicyclist ]; % Normalize between [0 1]. cmap = cmap ./ 255; end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds) % Partition CamVid data by randomly selecting 60% of the data for training. The % rest is used for testing. % Set initial random state for example reproducibility. rng(0); numFiles = numel(imds.Files); shuffledIndices = randperm(numFiles); % Use 60% of the images for training. numTrain = round(0.60 * numFiles); trainingIdx = shuffledIndices(1:numTrain); % Use 20% of the images for validation numVal = round(0.20 * numFiles); valIdx = shuffledIndices(numTrain+1:numTrain+numVal); % Use the rest for testing. testIdx = shuffledIndices(numTrain+numVal+1:end); % Create image datastores for training and test. trainingImages = imds.Files(trainingIdx); valImages = imds.Files(valIdx); testImages = imds.Files(testIdx); imdsTrain = imageDatastore(trainingImages); imdsVal = imageDatastore(valImages); imdsTest = imageDatastore(testImages); % Extract class and label IDs info. classes = pxds.ClassNames; labelIDs = camvidPixelLabelIDs(); % Create pixel label datastores for training and test. trainingLabels = pxds.Files(trainingIdx); valLabels = pxds.Files(valIdx); testLabels = pxds.Files(testIdx); pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs); pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs); pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs); end
function data = augmentImageAndLabel(data, xTrans, yTrans) % Augment images and pixel label images using random reflection and % translation. for i = 1:size(data,1) tform = randomAffine2d(... 'XReflection',true,... 'XTranslation', xTrans, ... 'YTranslation', yTrans); % Center the view at the center of image in the output space while % allowing translation to move the output image out of view. rout = affineOutputView(size(data{i,1}), tform, 'BoundsStyle', 'centerOutput'); % Warp the image and pixel labels using the same transform. data{i,1} = imwarp(data{i,1}, tform, 'OutputView', rout); data{i,2} = imwarp(data{i,2}, tform, 'OutputView', rout); end end
[1] Chen, Liang-Chieh et al. “Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation” ECCV (2018).
[2] Brostow, G. J., J. Fauqueur, and R. Cipolla. "Semantic object classes in video: A high-definition ground truth database." Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.
imageDataAugmenter
| trainingOptions
| trainNetwork
| countEachLabel
(Computer Vision Toolbox) | evaluateSemanticSegmentation
(Computer Vision Toolbox) | pixelClassificationLayer
(Computer Vision Toolbox) | pixelLabelDatastore
(Computer Vision Toolbox) | pixelLabelImageDatastore
(Computer Vision Toolbox) | segnetLayers
(Computer Vision Toolbox) | semanticseg
(Computer Vision Toolbox) | labeloverlay
(Image Processing Toolbox)