Этот пример показывает, как обучить сеть семантической сегментации с помощью глубокого обучения.
Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Приложения для семантической сегментации включают сегментацию дорог для автономного управления автомобилем и сегментацию раковых камер для медицинского диагностирования. Дополнительные сведения см. в разделе Начало работы с семантической сегментацией с использованием глубокого обучения (Computer Vision Toolbox).
Чтобы проиллюстрировать процедуру обучения, этот пример обучает Deeplab v3 + [1], один тип сверточной нейронной сети (CNN), предназначенной для семантической сегментации изображений. Другие типы сетей для семантической сегментации включают полностью сверточные сети (FCN), SegNet и U-Net. Процедура обучения, показанная здесь, может также применяться к этим сетям.
Этот пример использует набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является коллекцией изображений, содержащих представления уличного уровня, полученные во время вождения. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов, включая автомобиль, пешехода и дорогу.
Этот пример создает сеть Deeplab v3 + с весами, инициализированными предварительно обученной Resnet-18 сетью. ResNet-18 является эффективной сетью, которая хорошо подходит для приложений с ограниченными ресурсами обработки. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-50, также могут использоваться в зависимости от требований приложения. Для получения дополнительной информации см. Pretrained Deep Neural Networks.
Чтобы получить предварительно обученную Resnet-18, установите resnet18
. После завершения установки запустите следующий код, чтобы убедиться в правильности установки.
resnet18();
Кроме того, скачать предварительно обученную версию DeepLab v3 +. Предварительно обученная модель позволяет вам запустить весь пример, не дожидаясь завершения обучения.
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat'; pretrainedFolder = fullfile(tempdir,'pretrainedNetwork'); pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat'); if ~exist(pretrainedNetwork,'file') mkdir(pretrainedFolder); disp('Downloading pretrained network (58 MB)...'); websave(pretrainedNetwork,pretrainedURL); end
Для выполнения этого примера настоятельно рекомендуется использовать NVIDIA™ графический процессор с поддержкой CUDA. Для использования графический процессор требуется Parallel Computing Toolbox™. Для получения информации о поддерживаемых вычислительных возможностях смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
Загрузите набор данных CamVid со следующих URL-адресов.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; outputFolder = fullfile(tempdir,'CamVid'); labelsZip = fullfile(outputFolder,'labels.zip'); imagesZip = fullfile(outputFolder,'images.zip'); if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file') mkdir(outputFolder) disp('Downloading 16 MB CamVid dataset labels...'); websave(labelsZip, labelURL); unzip(labelsZip, fullfile(outputFolder,'labels')); disp('Downloading 557 MB CamVid dataset images...'); websave(imagesZip, imageURL); unzip(imagesZip, fullfile(outputFolder,'images')); end
Примечание: Время загрузки данных зависит от вашего подключения к Интернету. Команды, используемые выше, блокируют MATLAB до завершения загрузки. Также можно использовать веб-браузер для первой загрузки набора данных на локальный диск. Чтобы использовать загруженный из Интернета файл, измените outputFolder
переменная выше в местоположении загруженного файла.
Использование imageDatastore
для загрузки изображений CamVid. The imageDatastore
позволяет эффективно загружать на диск большой набор изображений.
imgDir = fullfile(outputFolder,'images','701_StillsRaw_full'); imds = imageDatastore(imgDir);
Отобразите одно из изображений.
I = readimage(imds,559); I = histeq(I); imshow(I)
Использование pixelLabelDatastore
(Computer Vision Toolbox) для загрузки данных о пиксельных метках изображение. A pixelLabelDatastore
инкапсулирует данные о пиксельных метках и идентификатор метки в отображение имен классов.
Мы облегчаем обучение, группируем 32 оригинальных класса в CamVid до 11 классов. Задайте эти классы.
classes = [ "Sky" "Building" "Pole" "Road" "Pavement" "Tree" "SignSymbol" "Fence" "Car" "Pedestrian" "Bicyclist" ];
Чтобы уменьшить 32 класса до 11, несколько классов из исходного набора данных сгруппированы вместе. Например, «Car» является комбинацией «Car», «SUVPickupTruck», «Truck_Bus,» «Train» и «OtherMoving». Верните сгруппированные идентификаторы меток с помощью вспомогательной функции camvidPixelLabelIDs
, который перечислен в конце этого примера.
labelIDs = camvidPixelLabelIDs();
Используйте идентификаторы классов и меток, чтобы создать pixelLabelDatastore.
labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Чтение и отображение одного из маркированных пикселями изображений путем наложения его поверх изображения.
C = readimage(pxds,559);
cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);
Области без наложения цветов не имеют пиксельных меток и не используются во время обучения.
Чтобы увидеть распределение меток классов в наборе данных CamVid, используйте countEachLabel
(Computer Vision Toolbox). Эта функция отсчитывает количество пикселей по меткам классов.
tbl = countEachLabel(pxds)
tbl=11×3 table
Name PixelCount ImagePixelCount
______________ __________ _______________
{'Sky' } 7.6801e+07 4.8315e+08
{'Building' } 1.1737e+08 4.8315e+08
{'Pole' } 4.7987e+06 4.8315e+08
{'Road' } 1.4054e+08 4.8453e+08
{'Pavement' } 3.3614e+07 4.7209e+08
{'Tree' } 5.4259e+07 4.479e+08
{'SignSymbol'} 5.2242e+06 4.6863e+08
{'Fence' } 6.9211e+06 2.516e+08
{'Car' } 2.4437e+07 4.8315e+08
{'Pedestrian'} 3.4029e+06 4.4444e+08
{'Bicyclist' } 2.5912e+06 2.6196e+08
Визуализируйте счетчики пикселей по классам.
frequency = tbl.PixelCount/sum(tbl.PixelCount);
bar(1:numel(classes),frequency)
xticks(1:numel(classes))
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')
В идеале все классы имели бы равное количество наблюдений. Однако классы в CamVid несбалансированны, что является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше пикселей неба, создания и дороги, чем пиксели пешехода и велосипедиста, потому что небо, создания и дороги покрывают больше площади на изображении. Если не обработать правильно, этот дисбаланс может нанести ущерб процессу обучения, потому что обучение предвзято в пользу доминирующих классов. Позже в этом примере для решения этой проблемы будет использоваться взвешивание классов.
Изображения в наборе данных CamVid имеют размер 720 на 960. Размер изображения выбран таким, чтобы достаточно большой пакет изображений мог помещаться в памяти во время обучения на NVIDIA™ Titan X с 12 ГБ памяти. Вы, возможно, должны будете изменить размер изображений на меньшие величины, если ваш графический процессор не имеет достаточной памяти, или уменьшить размер обучающего пакета.
Deeplab v3 + обучается с использованием 60% изображений из набора данных. Остальная часть изображений разделена равномерно на части в 20% и 20% для валидации и тестирования соответственно. Следующий код случайным образом разделяет изображение и данные о пиксельных метках на наборы для обучения, валидации и тестирования.
[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds);
Разделение 60/20/20 приводит к следующему количеству изображений для обучения, валидации и тестирования:
numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numValImages = numel(imdsVal.Files)
numValImages = 140
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140
Используйте deeplabv3plusLayers
функция для создания сети DeepLab v3 + на основе ResNet-18. Выбор наилучшей сети для вашего приложения требует эмпирического анализа и является еще одним уровнем настройки гиперпараметра. Например, можно экспериментировать с различными базовыми сетями, такими как ResNet-50 или MobileNet v2, или можно попробовать другие семантические сетевые архитектуры сегментации, такие как SegNet, полностью сверточные сети (FCN) или U-Net.
% Specify the network image size. This is typically the same as the traing image sizes. imageSize = [720 960 3]; % Specify the number of classes. numClasses = numel(classes); % Create DeepLab v3+. lgraph = deeplabv3plusLayers(imageSize, numClasses, "resnet18");
Как показано ранее, классы в CamVid не сбалансированы. Для улучшения обучения можно использовать взвешивание классов для балансировки классов. Используйте счетчики пиксельных меток, вычисленные ранее с countEachLabel
(Computer Vision Toolbox) и вычислите средние веса классов частот.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount; classWeights = median(imageFreq) ./ imageFreq
classWeights = 11×1
0.3182
0.2082
5.0924
0.1744
0.7103
0.4175
4.5371
1.8386
1.0000
6.6059
⋮
Задайте веса классов с помощью pixelClassificationLayer
(Computer Vision Toolbox).
pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights); lgraph = replaceLayer(lgraph,"classification",pxLayer);
Алгоритм оптимизации, используемый для обучения, является стохастическим градиентным спуском с импульсом (SGDM). Использование trainingOptions
для определения гиперпараметров, используемой для SGDM.
% Define validation data. dsVal = combine(imdsVal,pxdsVal); % Define training options. options = trainingOptions('sgdm', ... 'LearnRateSchedule','piecewise',... 'LearnRateDropPeriod',10,... 'LearnRateDropFactor',0.3,... 'Momentum',0.9, ... 'InitialLearnRate',1e-3, ... 'L2Regularization',0.005, ... 'ValidationData',dsVal,... 'MaxEpochs',30, ... 'MiniBatchSize',8, ... 'Shuffle','every-epoch', ... 'CheckpointPath', tempdir, ... 'VerboseFrequency',2,... 'Plots','training-progress',... 'ValidationPatience', 4);
В скорость обучения используется кусочное расписание. Скорость обучения уменьшается на множитель 0,3 каждые 10 эпох. Это позволяет сети обучаться быстро с более высокой начальной скоростью обучения, и в то же время может найти решение, близкое к локальному оптимуму, когда скорость обучения падает.
Сеть проверяется на соответствие данным валидации каждую эпоху путем установки 'ValidationData'
параметр. The 'ValidationPatience'
устанавливается равным 4, чтобы остановить обучение раньше, когда точность валидации сходится. Это препятствует сверхподбору кривой сети на обучающем наборе данных.
Мини-пакет размером 8 используется для уменьшения использования памяти во время обучения. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти графический процессор, имеющейся в вашей системе.
В сложение, 'CheckpointPath'
устанавливается во временное место. Эта пара "имя-значение" включает сохранение сетевых контрольных точек в конце каждой эпохи обучения. Если обучение прервано из-за отказа системы или отключения степени, можно возобновить обучение с сохраненной контрольной точки. Убедитесь, что местоположение задано 'CheckpointPath'
имеет достаточно пространство для хранения сетевых контрольных точек. Например, для сохранения 100 контрольных точек Deeplab v3 + требуется ~ 6 ГБ дискового пространства, поскольку каждая контрольная точка составляет 61 МБ.
Увеличение количества данных используется для повышения точности сети путем случайного преобразования исходных данных во время обучения. При помощи увеличения данных можно добавить больше разнообразия в обучающие данные, не увеличивая количество маркированных обучающих выборок. Чтобы применить одно и то же случайное преобразование как к изображению, так и к данным о пиксельных метках, используйте datastore combine
и transform
. Во-первых, объедините imdsTrain
и pxdsTrain
.
dsTrain = combine(imdsTrain, pxdsTrain);
Далее используйте datastore transform
применить необходимое увеличение данных, заданное в вспомогательной функции augmentImageAndLabel
. Здесь для увеличения данных используются случайное отражение «слева/справа» и случайный преобразование X/Y на +/-10 пикселей.
xTrans = [-10 10]; yTrans = [-10 10]; dsTrain = transform(dsTrain, @(data)augmentImageAndLabel(data,xTrans,yTrans));
Обратите внимание, что увеличение количества данных не применяется к тестовым и валидационным данным. В идеале данные испытаний и валидации должны быть показательными по сравнению с исходными данными и не должны быть изменены для объективной оценки.
Начните обучение с помощью trainNetwork
если doTraining
флаг равен true. В противном случае загружает предварительно обученную сеть
Примечание: Обучение было проверено на NVIDIA™ Titan X с 12 ГБ памяти графический процессор. Если ваш графический процессор имеет меньше памяти, у вас может закончиться память во время обучения. Если это произойдет, попробуйте установить 'MiniBatchSize'
до 1 в trainingOptions
или уменьшение входа сети и изменение размера обучающих данных. Обучение этой сети занимает около 5 часов. В зависимости от вашего оборудования для графического процессора это может занять еще больше времени.
doTraining = false; if doTraining [net, info] = trainNetwork(dsTrain,lgraph,options); else data = load(pretrainedNetwork); net = data.net; end
В качестве быстрой проверки работоспособности запустите обученную сеть на одном тестовом изображении.
I = readimage(imdsTest,35); C = semanticseg(I, net);
Отображение результатов.
B = labeloverlay(I,C,'Colormap',cmap,'Transparency',0.4); imshow(B) pixelLabelColorbar(cmap, classes);
Сравните результаты в C
с ожидаемой основной истиной, сохраненной в pxdsTest
. Зеленая и пурпурная области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.
expectedResult = readimage(pxdsTest,35); actual = uint8(C); expected = uint8(expectedResult); imshowpair(actual, expected)
Визуально результаты семантической сегментации хорошо перекрываются для таких классов, как дорога , небо и здание. Однако меньшие объекты вроде пешеходов и автомобилей не так точны. Сумма перекрытия для класса может быть измерено с помощью метрики пересечения по соединению (IoU), также известной как индекс Жаккара. Используйте jaccard
(Image Processing Toolbox) функция для измерения IoU.
iou = jaccard(C,expectedResult); table(classes,iou)
ans=11×2 table
classes iou
____________ _______
"Sky" 0.91837
"Building" 0.84479
"Pole" 0.31203
"Road" 0.93698
"Pavement" 0.82838
"Tree" 0.89636
"SignSymbol" 0.57644
"Fence" 0.71046
"Car" 0.66688
"Pedestrian" 0.48417
"Bicyclist" 0.68431
Метрика IoU подтверждает визуальные результаты. Классы "Дорога" , "Небо" и "Здания" имеют высокие счета IoU, в то время как классы, такие как пешеходный и автомобиль, имеют низкие счета. Другие общие метрики сегментации включают dice
(Image Processing Toolbox) и bfscore
(Image Processing Toolbox) счет соответствия контура.
Чтобы измерить точность для нескольких тестовых изображений, запустите semanticseg
(Computer Vision Toolbox) на целом тестовом наборе. Размер мини-пакета, равный 4, используется для уменьшения использования памяти при сегментации изображений. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти графический процессор, имеющейся в вашей системе.
pxdsResults = semanticseg(imdsTest,net, ... 'MiniBatchSize',4, ... 'WriteLocation',tempdir, ... 'Verbose',false);
semanticseg
возвращает результаты для тестового набора как pixelLabelDatastore
объект. Фактические данные о пиксельных метках для каждого тестового изображения в imdsTest
записывается на диск в расположении, заданном 'WriteLocation'
параметр. Использование evaluateSemanticSegmentation
(Computer Vision Toolbox), чтобы измерить метрики семантической сегментации на результатах тестового набора.
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);
evaluateSemanticSegmentation
возвращает различные метрики для набора данных в целом, для отдельных классов и для каждого тестового изображения. Чтобы увидеть метрики уровня набора данных, смотрите metrics.DataSetMetrics
.
metrics.DataSetMetrics
ans=1×5 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU MeanBFScore
______________ ____________ _______ ___________ ___________
0.87695 0.85392 0.6302 0.80851 0.65051
Метрики набора данных обеспечивают высокоуровневый обзор эффективности сети. Чтобы увидеть влияние каждого класса на общую эффективность, смотрите метрики по классам с помощью metrics.ClassMetrics
.
metrics.ClassMetrics
ans=11×3 table
Accuracy IoU MeanBFScore
________ _______ ___________
Sky 0.93112 0.90209 0.8952
Building 0.78453 0.76098 0.58511
Pole 0.71586 0.21477 0.51439
Road 0.93024 0.91465 0.76696
Pavement 0.88466 0.70571 0.70919
Tree 0.87377 0.76323 0.70875
SignSymbol 0.79358 0.39309 0.48302
Fence 0.81507 0.46484 0.48566
Car 0.90956 0.76799 0.69233
Pedestrian 0.87629 0.4366 0.60792
Bicyclist 0.87844 0.60829 0.55089
Несмотря на то, что общая эффективность набора данных довольно высока, метрики классов показывают, что недостаточно представленные классы, такие как Pedestrian
, Bicyclist
, и Car
не сегментированы, также как и классы, такие как Road
, Sky
, и Building
. Дополнительные данные, которые включают больше выборок недостаточно представленных классов, могут помочь улучшить результаты.
function labelIDs = camvidPixelLabelIDs() % Return the label IDs corresponding to each class. % % The CamVid dataset has 32 classes. Group them into 11 classes following % the original SegNet training methodology [1]. % % The 11 classes are: % "Sky" "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol", % "Fence", "Car", "Pedestrian", and "Bicyclist". % % CamVid pixel label IDs are provided as RGB color values. Group them into % 11 classes and return them as a cell array of M-by-3 matrices. The % original CamVid class names are listed alongside each RGB value. Note % that the Other/Void class are excluded below. labelIDs = { ... % "Sky" [ 128 128 128; ... % "Sky" ] % "Building" [ 000 128 064; ... % "Bridge" 128 000 000; ... % "Building" 064 192 000; ... % "Wall" 064 000 064; ... % "Tunnel" 192 000 128; ... % "Archway" ] % "Pole" [ 192 192 128; ... % "Column_Pole" 000 000 064; ... % "TrafficCone" ] % Road [ 128 064 128; ... % "Road" 128 000 192; ... % "LaneMkgsDriv" 192 000 064; ... % "LaneMkgsNonDriv" ] % "Pavement" [ 000 000 192; ... % "Sidewalk" 064 192 128; ... % "ParkingBlock" 128 128 192; ... % "RoadShoulder" ] % "Tree" [ 128 128 000; ... % "Tree" 192 192 000; ... % "VegetationMisc" ] % "SignSymbol" [ 192 128 128; ... % "SignSymbol" 128 128 064; ... % "Misc_Text" 000 064 064; ... % "TrafficLight" ] % "Fence" [ 064 064 128; ... % "Fence" ] % "Car" [ 064 000 128; ... % "Car" 064 128 192; ... % "SUVPickupTruck" 192 128 192; ... % "Truck_Bus" 192 064 128; ... % "Train" 128 064 064; ... % "OtherMoving" ] % "Pedestrian" [ 064 064 000; ... % "Pedestrian" 192 128 064; ... % "Child" 064 000 192; ... % "CartLuggagePram" 064 128 064; ... % "Animal" ] % "Bicyclist" [ 000 128 192; ... % "Bicyclist" 192 000 192; ... % "MotorcycleScooter" ] }; end
function pixelLabelColorbar(cmap, classNames) % Add a colorbar to the current axis. The colorbar is formatted % to display the class names with the color. colormap(gca,cmap) % Add colorbar to current figure. c = colorbar('peer', gca); % Use class names for tick marks. c.TickLabels = classNames; numClasses = size(cmap,1); % Center tick labels. c.Ticks = 1/(numClasses*2):1/numClasses:1; % Remove tick mark. c.TickLength = 0; end
function cmap = camvidColorMap() % Define the colormap used by CamVid dataset. cmap = [ 128 128 128 % Sky 128 0 0 % Building 192 192 192 % Pole 128 64 128 % Road 60 40 222 % Pavement 128 128 0 % Tree 192 128 128 % SignSymbol 64 64 128 % Fence 64 0 128 % Car 64 64 0 % Pedestrian 0 128 192 % Bicyclist ]; % Normalize between [0 1]. cmap = cmap ./ 255; end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds) % Partition CamVid data by randomly selecting 60% of the data for training. The % rest is used for testing. % Set initial random state for example reproducibility. rng(0); numFiles = numel(imds.Files); shuffledIndices = randperm(numFiles); % Use 60% of the images for training. numTrain = round(0.60 * numFiles); trainingIdx = shuffledIndices(1:numTrain); % Use 20% of the images for validation numVal = round(0.20 * numFiles); valIdx = shuffledIndices(numTrain+1:numTrain+numVal); % Use the rest for testing. testIdx = shuffledIndices(numTrain+numVal+1:end); % Create image datastores for training and test. trainingImages = imds.Files(trainingIdx); valImages = imds.Files(valIdx); testImages = imds.Files(testIdx); imdsTrain = imageDatastore(trainingImages); imdsVal = imageDatastore(valImages); imdsTest = imageDatastore(testImages); % Extract class and label IDs info. classes = pxds.ClassNames; labelIDs = camvidPixelLabelIDs(); % Create pixel label datastores for training and test. trainingLabels = pxds.Files(trainingIdx); valLabels = pxds.Files(valIdx); testLabels = pxds.Files(testIdx); pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs); pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs); pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs); end
function data = augmentImageAndLabel(data, xTrans, yTrans) % Augment images and pixel label images using random reflection and % translation. for i = 1:size(data,1) tform = randomAffine2d(... 'XReflection',true,... 'XTranslation', xTrans, ... 'YTranslation', yTrans); % Center the view at the center of image in the output space while % allowing translation to move the output image out of view. rout = affineOutputView(size(data{i,1}), tform, 'BoundsStyle', 'centerOutput'); % Warp the image and pixel labels using the same transform. data{i,1} = imwarp(data{i,1}, tform, 'OutputView', rout); data{i,2} = imwarp(data{i,2}, tform, 'OutputView', rout); end end
[1] Chen, Liang-Chieh et al. «Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation (неопр.) (недоступная ссылка)». ECCV (2018).
[2] Brostow, G. J., J. Fauqueur, and R. Cipolla. Semantic object classes in video: A high-definition основная истина database (неопр.) (недоступная ссылка). Распознавание Букв. Том 30, Выпуск 2, 2009, стр. 88-97.
imageDataAugmenter
| trainingOptions
| trainNetwork
| countEachLabel
(Computer Vision Toolbox) | evaluateSemanticSegmentation
(Computer Vision Toolbox) | pixelClassificationLayer
(Computer Vision Toolbox) | pixelLabelDatastore
(Computer Vision Toolbox) | pixelLabelImageDatastore
(Computer Vision Toolbox) | segnetLayers
(Computer Vision Toolbox) | semanticseg
(Computer Vision Toolbox) | labeloverlay
(Набор Image Processing Toolbox)