В этом примере показано, как обучить сеть семантической сегментации с помощью глубокого обучения.
Семантическая сеть сегментации классифицирует каждый пиксель в изображении, в результате чего изображение сегментируется по классу. Приложения для семантической сегментации включают сегментацию дорог для автономного вождения и сегментацию раковых клеток для медицинской диагностики. Дополнительные сведения см. в разделе Начало работы с семантической сегментацией с помощью глубокого обучения (панель инструментов компьютерного зрения).
Чтобы проиллюстрировать процедуру обучения, в этом примере обучается Deeplab v3 + [1], один из типов сверточной нейронной сети (CNN), предназначенной для сегментации семантического изображения. Другие типы сетей для семантической сегментации включают полностью сверточные сети (FCN), SegNet и U-Net. Показанная здесь процедура обучения также может быть применена к этим сетям.
В этом примере для обучения используется набор данных CamVid [2] Кембриджского университета. Этот набор данных представляет собой коллекцию изображений, содержащих виды на уровне улиц, полученные во время вождения. Набор данных содержит пиксельные метки для 32 семантических классов, включая автомобильный, пешеходный и дорожный.
В этом примере создается сеть Deeplab v3 + с весами, инициализированными из предварительно обученной сети Resnet-18. ResNet-18 является эффективной сетью, которая хорошо подходит для приложений с ограниченными ресурсами обработки. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-50 могут также использоваться в зависимости от основных эксплуатационных характеристик. Дополнительные сведения см. в разделе Предварительно обученные нейронные сети.
Для получения предварительно подготовленного Resnet-18 установите resnet18. После завершения установки выполните следующий код для проверки правильности установки.
resnet18();
Кроме того, загрузите предварительно подготовленную версию DeepLab v3 +. Предварительно обученная модель позволяет выполнять весь пример без необходимости ждать завершения обучения.
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat'; pretrainedFolder = fullfile(tempdir,'pretrainedNetwork'); pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat'); if ~exist(pretrainedNetwork,'file') mkdir(pretrainedFolder); disp('Downloading pretrained network (58 MB)...'); websave(pretrainedNetwork,pretrainedURL); end
Для выполнения этого примера настоятельно рекомендуется использовать графический процессор NVIDIA™ с поддержкой CUDA. Для использования графического процессора требуется Toolbox™ параллельных вычислений. Сведения о поддерживаемых вычислительных возможностях см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
Загрузите набор данных CamVid со следующих URL-адресов.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; outputFolder = fullfile(tempdir,'CamVid'); labelsZip = fullfile(outputFolder,'labels.zip'); imagesZip = fullfile(outputFolder,'images.zip'); if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file') mkdir(outputFolder) disp('Downloading 16 MB CamVid dataset labels...'); websave(labelsZip, labelURL); unzip(labelsZip, fullfile(outputFolder,'labels')); disp('Downloading 557 MB CamVid dataset images...'); websave(imagesZip, imageURL); unzip(imagesZip, fullfile(outputFolder,'images')); end
Примечание.Время загрузки данных зависит от вашего подключения к Интернету. Команды, использованные выше, блокируют MATLAB до завершения загрузки. Кроме того, с помощью веб-браузера можно сначала загрузить набор данных на локальный диск. Чтобы использовать файл, загруженный из Интернета, измените outputFolder переменная, указанная выше, в соответствии с расположением загруженного файла.
Использовать imageDatastore для загрузки изображений CamVid. imageDatastore позволяет эффективно загружать большую коллекцию образов на диск.
imgDir = fullfile(outputFolder,'images','701_StillsRaw_full'); imds = imageDatastore(imgDir);
Отображение одного из изображений.
I = readimage(imds,559); I = histeq(I); imshow(I)

Использовать pixelLabelDatastore (Computer Vision Toolbox) для загрузки данных изображения пиксельной метки CamVid. A pixelLabelDatastore инкапсулирует данные пиксельной метки и идентификатор метки в соответствие имени класса.
Мы облегчаем обучение, мы группируем 32 оригинальных класса в CamVid до 11 классов. Укажите эти классы.
classes = [
"Sky"
"Building"
"Pole"
"Road"
"Pavement"
"Tree"
"SignSymbol"
"Fence"
"Car"
"Pedestrian"
"Bicyclist"
];Чтобы свести 32 класса к 11, несколько классов из исходного набора данных группируются вместе. Например, «Вагон» - это сочетание «Вагон», «СУВПиккупГрузовик», «Truck_Bus,» «Поезд» и «OtherMoving». Возврат сгруппированных идентификаторов меток с помощью функции поддержки camvidPixelLabelIDs, который приведен в конце этого примера.
labelIDs = camvidPixelLabelIDs();
Используйте классы и идентификаторы меток для создания pixelLabelDatastore.
labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);Считывание и отображение одного из помеченных пикселями изображений путем наложения его поверх изображения.
C = readimage(pxds,559);
cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);
Области без наложения цветов не имеют пиксельных меток и не используются во время обучения.
Чтобы просмотреть распределение меток классов в наборе данных CamVid, используйте countEachLabel(Панель инструментов компьютерного зрения). Эта функция подсчитывает количество пикселей по метке класса.
tbl = countEachLabel(pxds)
tbl=11×3 table
Name PixelCount ImagePixelCount
______________ __________ _______________
{'Sky' } 7.6801e+07 4.8315e+08
{'Building' } 1.1737e+08 4.8315e+08
{'Pole' } 4.7987e+06 4.8315e+08
{'Road' } 1.4054e+08 4.8453e+08
{'Pavement' } 3.3614e+07 4.7209e+08
{'Tree' } 5.4259e+07 4.479e+08
{'SignSymbol'} 5.2242e+06 4.6863e+08
{'Fence' } 6.9211e+06 2.516e+08
{'Car' } 2.4437e+07 4.8315e+08
{'Pedestrian'} 3.4029e+06 4.4444e+08
{'Bicyclist' } 2.5912e+06 2.6196e+08
Визуализация количества пикселей по классам.
frequency = tbl.PixelCount/sum(tbl.PixelCount);
bar(1:numel(classes),frequency)
xticks(1:numel(classes))
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')
В идеале все классы имели бы равное количество наблюдений. Однако классы в CamVid несбалансированы, что является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше пикселей неба, здания и дороги, чем пиксели пешеходов и велосипедистов, потому что небо, здания и дороги покрывают большую площадь изображения. В случае неправильной обработки этот дисбаланс может нанести ущерб процессу обучения, поскольку обучение смещено в пользу доминирующих классов. Далее в этом примере для решения этой проблемы будет использоваться взвешивание класса.
Изображения в наборе данных CamVid имеют размер 720 на 960. Размер изображения выбирается таким образом, чтобы достаточно большая партия изображений могла поместиться в память во время обучения на NVIDIA™ Titan X с 12 ГБ памяти. Возможно, потребуется изменить размер изображений до меньшего размера, если графический процессор не имеет достаточного объема памяти или уменьшен размер учебного пакета.
Deeplab v3 + обучается с использованием 60% изображений из набора данных. Остальные изображения равномерно разделены на 20% и 20% для проверки и тестирования соответственно. Следующий код случайным образом разбивает данные изображения и метки пикселя на обучающий, проверочный и тестовый набор.
[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds);
Разделение 60/20/20 приводит к следующему количеству изображений обучения, проверки и тестирования:
numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numValImages = numel(imdsVal.Files)
numValImages = 140
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
Используйте deeplabv3plusLayers для создания сети DeepLab v3 + на основе ResNet-18. Выбор лучшей сети для вашего приложения требует эмпирического анализа и является другим уровнем настройки гиперпараметров. Например, можно поэкспериментировать с различными базовыми сетями, такими как ResNet-50 или StartNet v2, или можно попробовать другие архитектуры сети семантической сегментации, такие как SegNet, полностью сверточные сети (FCN) или U-Net.
% Specify the network image size. This is typically the same as the traing image sizes. imageSize = [720 960 3]; % Specify the number of classes. numClasses = numel(classes); % Create DeepLab v3+. lgraph = deeplabv3plusLayers(imageSize, numClasses, "resnet18");
Как показано ранее, классы в CamVid не сбалансированы. Для улучшения обучения можно использовать взвешивание классов, чтобы сбалансировать классы. Использовать количество меток пикселей, вычисленное ранее с помощью countEachLabel (Computer Vision Toolbox) и вычислите средние веса класса частот.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount; classWeights = median(imageFreq) ./ imageFreq
classWeights = 11×1
0.3182
0.2082
5.0924
0.1744
0.7103
0.4175
4.5371
1.8386
1.0000
6.6059
⋮
Укажите веса классов с помощью pixelClassificationLayer(Панель инструментов компьютерного зрения).
pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights); lgraph = replaceLayer(lgraph,"classification",pxLayer);
Алгоритм оптимизации, используемый для обучения, - стохастический градиентный спуск с импульсом (SGDM). Использовать trainingOptions для указания гиперпараметров, используемых для SGDM.
% Define validation data. dsVal = combine(imdsVal,pxdsVal); % Define training options. options = trainingOptions('sgdm', ... 'LearnRateSchedule','piecewise',... 'LearnRateDropPeriod',10,... 'LearnRateDropFactor',0.3,... 'Momentum',0.9, ... 'InitialLearnRate',1e-3, ... 'L2Regularization',0.005, ... 'ValidationData',dsVal,... 'MaxEpochs',30, ... 'MiniBatchSize',8, ... 'Shuffle','every-epoch', ... 'CheckpointPath', tempdir, ... 'VerboseFrequency',2,... 'Plots','training-progress',... 'ValidationPatience', 4);
Коэффициент обучения использует кусочный график. Уровень обучения снижается в 0,3 раза каждые 10 эпох. Это позволяет сети быстро учиться с более высокой начальной скоростью обучения, а также находить решение, близкое к локальному оптимуму, как только скорость обучения падает.
Сеть тестируется по данным проверки каждый период путем установки 'ValidationData' параметр. 'ValidationPatience' имеет значение 4, чтобы остановить обучение на ранней стадии, когда точность проверки сходится. Это предотвращает переопределение сети в наборе данных обучения.
Мини-пакет размером 8 используется для уменьшения использования памяти во время обучения. Это значение можно увеличить или уменьшить в зависимости от объема памяти графического процессора в системе.
Кроме того, 'CheckpointPath' имеет временное расположение. Эта пара имя-значение позволяет сохранять контрольные точки сети в конце каждого периода обучения. Если обучение прервано из-за сбоя системы или отключения питания, вы можете возобновить обучение с сохраненной контрольной точки. Убедитесь, что расположение указано 'CheckpointPath' имеет достаточно места для хранения контрольных точек сети. Например, для сохранения 100 контрольных точек Deeplab v3 + требуется ~ 6 ГБ дискового пространства, поскольку размер каждой контрольной точки составляет 61 МБ.
Увеличение данных используется для повышения точности сети путем случайного преобразования исходных данных во время обучения. С помощью увеличения данных можно добавлять большее разнообразие к обучающим данным без увеличения количества маркированных обучающих образцов. Чтобы применить одно и то же случайное преобразование к данным изображения и метки пикселя, используйте хранилище данных combine и transform. Во-первых, комбинировать imdsTrain и pxdsTrain.
dsTrain = combine(imdsTrain, pxdsTrain);
Далее используйте хранилище данных transform для применения требуемого увеличения данных, определенного в поддерживающей функции augmentImageAndLabel. Здесь для увеличения данных используют случайное левое/правое отражение и случайное X/Y-преобразование +/- 10 пикселей.
xTrans = [-10 10]; yTrans = [-10 10]; dsTrain = transform(dsTrain, @(data)augmentImageAndLabel(data,xTrans,yTrans));
Обратите внимание, что увеличение данных не применяется к данным тестирования и проверки. В идеале, данные тестирования и проверки должны быть репрезентативными для исходных данных и оставлены неизмененными для объективной оценки.
Начать обучение с использованием trainNetwork если doTraining флаг имеет значение true. В противном случае загрузите предварительно подготовленную сеть.
Примечание: Обучение было проверено на NVIDIA™ Titan X с 12 ГБ памяти GPU. Если в графическом процессоре меньше памяти, во время обучения может не хватить памяти. В этом случае попробуйте установить 'MiniBatchSize' до 1 в trainingOptionsили уменьшение сетевого ввода и изменение размера обучающих данных. Обучение этой сети занимает около 5 часов. В зависимости от оборудования графического процессора это может занять еще больше времени.
doTraining = false; if doTraining [net, info] = trainNetwork(dsTrain,lgraph,options); else data = load(pretrainedNetwork); net = data.net; end
В качестве быстрой проверки работоспособности запустите обученную сеть на одном тестовом образе.
I = readimage(imdsTest,35); C = semanticseg(I, net);
Просмотрите результаты.
B = labeloverlay(I,C,'Colormap',cmap,'Transparency',0.4); imshow(B) pixelLabelColorbar(cmap, classes);

Сравнение результатов в C с ожидаемой истинностью основания, сохраненной в pxdsTest. Зеленая и пурпурная области выделяют области, где результаты сегментации отличаются от ожидаемой истинности грунта.
expectedResult = readimage(pxdsTest,35); actual = uint8(C); expected = uint8(expectedResult); imshowpair(actual, expected)

Визуально результаты семантической сегментации хорошо перекрываются для таких классов, как дорога, небо и строительство. Однако более мелкие объекты вроде пешеходов и автомобилей не так точны. Величина перекрытия на класс может быть измерена с помощью метрики пересечение-объединение (IoU), также известной как индекс Jaccard. Используйте jaccard(Панель инструментов обработки изображений) для измерения IoU.
iou = jaccard(C,expectedResult); table(classes,iou)
ans=11×2 table
classes iou
____________ _______
"Sky" 0.91837
"Building" 0.84479
"Pole" 0.31203
"Road" 0.93698
"Pavement" 0.82838
"Tree" 0.89636
"SignSymbol" 0.57644
"Fence" 0.71046
"Car" 0.66688
"Pedestrian" 0.48417
"Bicyclist" 0.68431
Метрика IoU подтверждает визуальные результаты. Классы дорог, неба и зданий имеют высокие оценки IoU, в то время как такие классы, как пешеходы и автомобили, имеют низкие оценки. Другие общие метрики сегментации включают dice(Панель инструментов обработки изображений) и bfscore(Панель инструментов обработки изображений) оценка соответствия контуров.
Чтобы измерить точность для нескольких тестовых изображений, выполните командуsemanticseg (Computer Vision Toolbox) для всего набора тестов. Мини-пакет размером 4 используется для уменьшения использования памяти при сегментировании изображений. Это значение можно увеличить или уменьшить в зависимости от объема памяти графического процессора в системе.
pxdsResults = semanticseg(imdsTest,net, ... 'MiniBatchSize',4, ... 'WriteLocation',tempdir, ... 'Verbose',false);
semanticseg возвращает результаты для тестового набора в виде pixelLabelDatastore объект. Фактические данные метки пикселя для каждого тестового изображения в imdsTest записывается на диск в расположении, указанном 'WriteLocation' параметр. Использовать evaluateSemanticSegmentation (Computer Vision Toolbox) для измерения показателей семантической сегментации в результатах тестового набора.
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);evaluateSemanticSegmentation возвращает различные метрики для всего набора данных, для отдельных классов и для каждого тестового образа. Чтобы просмотреть метрики уровня набора данных, проверьте metrics.DataSetMetrics .
metrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.87695 0.85392 0.6302 0.80851 0.65051
Метрики набора данных обеспечивают общий обзор производительности сети. Чтобы увидеть влияние каждого класса на общую производительность, проверьте метрики для каждого класса с помощью metrics.ClassMetrics.
metrics.ClassMetrics
ans=11×3 table
Accuracy IoU MeanBFScore
________ _______ ___________
Sky 0.93112 0.90209 0.8952
Building 0.78453 0.76098 0.58511
Pole 0.71586 0.21477 0.51439
Road 0.93024 0.91465 0.76696
Pavement 0.88466 0.70571 0.70919
Tree 0.87377 0.76323 0.70875
SignSymbol 0.79358 0.39309 0.48302
Fence 0.81507 0.46484 0.48566
Car 0.90956 0.76799 0.69233
Pedestrian 0.87629 0.4366 0.60792
Bicyclist 0.87844 0.60829 0.55089
Хотя общая производительность набора данных довольно высока, метрики классов показывают, что недопредставленные классы, такие как Pedestrian, Bicyclist, и Car не сегментированы, а также классы, такие как Road, Sky, и Building. Дополнительные данные, включающие больше образцов недопредставленных классов, могут помочь улучшить результаты.
function labelIDs = camvidPixelLabelIDs() % Return the label IDs corresponding to each class. % % The CamVid dataset has 32 classes. Group them into 11 classes following % the original SegNet training methodology [1]. % % The 11 classes are: % "Sky" "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol", % "Fence", "Car", "Pedestrian", and "Bicyclist". % % CamVid pixel label IDs are provided as RGB color values. Group them into % 11 classes and return them as a cell array of M-by-3 matrices. The % original CamVid class names are listed alongside each RGB value. Note % that the Other/Void class are excluded below. labelIDs = { ... % "Sky" [ 128 128 128; ... % "Sky" ] % "Building" [ 000 128 064; ... % "Bridge" 128 000 000; ... % "Building" 064 192 000; ... % "Wall" 064 000 064; ... % "Tunnel" 192 000 128; ... % "Archway" ] % "Pole" [ 192 192 128; ... % "Column_Pole" 000 000 064; ... % "TrafficCone" ] % Road [ 128 064 128; ... % "Road" 128 000 192; ... % "LaneMkgsDriv" 192 000 064; ... % "LaneMkgsNonDriv" ] % "Pavement" [ 000 000 192; ... % "Sidewalk" 064 192 128; ... % "ParkingBlock" 128 128 192; ... % "RoadShoulder" ] % "Tree" [ 128 128 000; ... % "Tree" 192 192 000; ... % "VegetationMisc" ] % "SignSymbol" [ 192 128 128; ... % "SignSymbol" 128 128 064; ... % "Misc_Text" 000 064 064; ... % "TrafficLight" ] % "Fence" [ 064 064 128; ... % "Fence" ] % "Car" [ 064 000 128; ... % "Car" 064 128 192; ... % "SUVPickupTruck" 192 128 192; ... % "Truck_Bus" 192 064 128; ... % "Train" 128 064 064; ... % "OtherMoving" ] % "Pedestrian" [ 064 064 000; ... % "Pedestrian" 192 128 064; ... % "Child" 064 000 192; ... % "CartLuggagePram" 064 128 064; ... % "Animal" ] % "Bicyclist" [ 000 128 192; ... % "Bicyclist" 192 000 192; ... % "MotorcycleScooter" ] }; end
function pixelLabelColorbar(cmap, classNames) % Add a colorbar to the current axis. The colorbar is formatted % to display the class names with the color. colormap(gca,cmap) % Add colorbar to current figure. c = colorbar('peer', gca); % Use class names for tick marks. c.TickLabels = classNames; numClasses = size(cmap,1); % Center tick labels. c.Ticks = 1/(numClasses*2):1/numClasses:1; % Remove tick mark. c.TickLength = 0; end
function cmap = camvidColorMap() % Define the colormap used by CamVid dataset. cmap = [ 128 128 128 % Sky 128 0 0 % Building 192 192 192 % Pole 128 64 128 % Road 60 40 222 % Pavement 128 128 0 % Tree 192 128 128 % SignSymbol 64 64 128 % Fence 64 0 128 % Car 64 64 0 % Pedestrian 0 128 192 % Bicyclist ]; % Normalize between [0 1]. cmap = cmap ./ 255; end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds) % Partition CamVid data by randomly selecting 60% of the data for training. The % rest is used for testing. % Set initial random state for example reproducibility. rng(0); numFiles = numel(imds.Files); shuffledIndices = randperm(numFiles); % Use 60% of the images for training. numTrain = round(0.60 * numFiles); trainingIdx = shuffledIndices(1:numTrain); % Use 20% of the images for validation numVal = round(0.20 * numFiles); valIdx = shuffledIndices(numTrain+1:numTrain+numVal); % Use the rest for testing. testIdx = shuffledIndices(numTrain+numVal+1:end); % Create image datastores for training and test. trainingImages = imds.Files(trainingIdx); valImages = imds.Files(valIdx); testImages = imds.Files(testIdx); imdsTrain = imageDatastore(trainingImages); imdsVal = imageDatastore(valImages); imdsTest = imageDatastore(testImages); % Extract class and label IDs info. classes = pxds.ClassNames; labelIDs = camvidPixelLabelIDs(); % Create pixel label datastores for training and test. trainingLabels = pxds.Files(trainingIdx); valLabels = pxds.Files(valIdx); testLabels = pxds.Files(testIdx); pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs); pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs); pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs); end
function data = augmentImageAndLabel(data, xTrans, yTrans) % Augment images and pixel label images using random reflection and % translation. for i = 1:size(data,1) tform = randomAffine2d(... 'XReflection',true,... 'XTranslation', xTrans, ... 'YTranslation', yTrans); % Center the view at the center of image in the output space while % allowing translation to move the output image out of view. rout = affineOutputView(size(data{i,1}), tform, 'BoundsStyle', 'centerOutput'); % Warp the image and pixel labels using the same transform. data{i,1} = imwarp(data{i,1}, tform, 'OutputView', rout); data{i,2} = imwarp(data{i,2}, tform, 'OutputView', rout); end end
[1] Чен, Лян-Чие и др. «Кодер-декодер с Atrous Separable сверткой для сегментации семантического изображения». ECCV (2018).
[2] Бростоу, G. J., J. Fauqueur и R. Cipolla. «Классы семантических объектов в видео: база данных истинности земли высокой четкости». Буквы распознавания образов. Том 30, выпуск 2, 2009, стр. 88-97.
imageDataAugmenter | trainingOptions | trainNetwork | countEachLabel (Панель инструментов компьютерного зрения) | evaluateSemanticSegmentation (Панель инструментов компьютерного зрения) | pixelClassificationLayer (Панель инструментов компьютерного зрения) | pixelLabelDatastore (Панель инструментов компьютерного зрения) | pixelLabelImageDatastore (Панель инструментов компьютерного зрения) | segnetLayers (Панель инструментов компьютерного зрения) | semanticseg (Панель инструментов компьютерного зрения) | labeloverlay(Панель инструментов обработки изображений)