Семантическая Сегментация Используя глубокое обучение

В этом примере показано, как обучить сеть семантической сегментации использование глубокого обучения.

Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Приложения для семантической сегментации включают сегментацию дорог для автономного управления автомобилем и сегментацию раковой клетки для медицинского диагностирования. Чтобы узнать больше, смотрите Начало работы с Семантической Сегментацией Используя Глубокое обучение.

Чтобы проиллюстрировать метод обучения, этот пример обучает Deeplab v3 + [1], один тип сверточной нейронной сети (CNN), спроектированной для семантической сегментации изображений. Другие типы сетей для семантической сегментации включают полностью сверточные сети (FCN), SegNet и U-Net. Метод обучения, показанный здесь, может быть применен к тем сетям также.

Этот пример использует набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором изображений, содержащих представления уличного уровня, полученные при управлении. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов включая автомобиль, пешехода и дорогу.

Настройка

Этот пример создает Deeplab v3 + сеть с весами, инициализированными от предварительно обученной сети Resnet-18. ResNet-18 является эффективной сетью, которая хорошо подходит для приложений с ограниченными ресурсами обработки. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-50 могут также использоваться в зависимости от требований к приложению. Для получения дополнительной информации смотрите Предварительно обученные Глубокие нейронные сети (Deep Learning Toolbox).

Чтобы получить предварительно обученный Resnet-18, установите Модель Deep Learning Toolbox™ для Сети Resnet-18. После того, как установка завершена, запустите следующий код, чтобы проверить, что установка правильна.

resnet18();

Кроме того, загрузите предварительно обученную версию DeepLab v3 +. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться.

pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/deeplabv3plusResnet18CamVid.mat';
pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
pretrainedNetwork = fullfile(pretrainedFolder,'deeplabv3plusResnet18CamVid.mat'); 
if ~exist(pretrainedNetwork,'file')
    mkdir(pretrainedFolder);
    disp('Downloading pretrained network (58 MB)...');
    websave(pretrainedNetwork,pretrainedURL);
end

Для выполнения этого примера настоятельно рекомендуется использовать графический процессор NVIDIA с поддержкой CUDA и вычислительными возможностями 3.0 или выше. Для использования GPU требуется Parallel Computing Toolbox.

Загрузите набор данных CamVid

Загрузите набор данных CamVid со следующих URL.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

outputFolder = fullfile(tempdir,'CamVid'); 
labelsZip = fullfile(outputFolder,'labels.zip');
imagesZip = fullfile(outputFolder,'images.zip');

if ~exist(labelsZip, 'file') || ~exist(imagesZip,'file')   
    mkdir(outputFolder)
       
    disp('Downloading 16 MB CamVid dataset labels...'); 
    websave(labelsZip, labelURL);
    unzip(labelsZip, fullfile(outputFolder,'labels'));
    
    disp('Downloading 557 MB CamVid dataset images...');  
    websave(imagesZip, imageURL);       
    unzip(imagesZip, fullfile(outputFolder,'images'));    
end

Примечание: Время загрузки данных зависит от вашего Интернет-соединения. Команды, используемые выше, блокируют MATLAB до завершения загрузки. В качестве альтернативы можно использовать веб-браузер, чтобы сначала загрузить набор данных на локальный диск. Чтобы использовать файл, вы загрузили с сети, измените outputFolder переменная выше к местоположению загруженного файла.

Загрузите изображения CamVid

Используйте imageDatastore загружать изображения CamVid. imageDatastore позволяет вам эффективно загрузить большое количество изображений на диске.

imgDir = fullfile(outputFolder,'images','701_StillsRaw_full');
imds = imageDatastore(imgDir);

Отобразите одно из изображений.

I = readimage(imds,1);
I = histeq(I);
imshow(I)

Загрузка Изображений с Пиксельной Маркировкой CamVid

Используйте pixelLabelDatastore чтобы загрузить пиксель CamVid помечают данные изображения. pixelLabelDatastore инкапсулирует данные о пиксельных метках и метку ID к отображению имени класса.

Мы делаем обучение легче, мы группируем 32 исходных класса в CamVid к 11 классам. Задайте эти классы.

classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];

Чтобы уменьшать 32 класса в 11, несколько классов от исходного набора данных группируются. Например, "Автомобиль" является комбинацией "Автомобиля", "SUVPickupTruck", "Truck_Bus", "Обучаются", и "OtherMoving". Возвратите сгруппированную метку IDs при помощи функции поддержки camvidPixelLabelIDs, который перечислен в конце этого примера.

labelIDs = camvidPixelLabelIDs();

Используйте идентификаторы классов и меток, чтобы создать pixelLabelDatastore.

labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Считайте и отобразите одно из помеченных пикселем изображений путем накладывания его сверху изображения.

C = readimage(pxds,1);
cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);

Области без перекрытия цвета не имеют пиксельных меток и не используются во время обучения.

Анализируйте статистику набора данных

Чтобы видеть распределение меток класса в наборе данных CamVid, используйте countEachLabel. Эта функция считает количество пикселей меткой класса.

tbl = countEachLabel(pxds)
tbl=11×3 table
         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    {'Sky'       }    7.6801e+07      4.8315e+08   
    {'Building'  }    1.1737e+08      4.8315e+08   
    {'Pole'      }    4.7987e+06      4.8315e+08   
    {'Road'      }    1.4054e+08      4.8453e+08   
    {'Pavement'  }    3.3614e+07      4.7209e+08   
    {'Tree'      }    5.4259e+07       4.479e+08   
    {'SignSymbol'}    5.2242e+06      4.6863e+08   
    {'Fence'     }    6.9211e+06       2.516e+08   
    {'Car'       }    2.4437e+07      4.8315e+08   
    {'Pedestrian'}    3.4029e+06      4.4444e+08   
    {'Bicyclist' }    2.5912e+06      2.6196e+08   

Визуализируйте счетчики пикселей по классам.

frequency = tbl.PixelCount/sum(tbl.PixelCount);

bar(1:numel(classes),frequency)
xticks(1:numel(classes)) 
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')

Идеально, все классы имели бы равное количество наблюдений. Однако классы в CamVid являются неустойчивыми, который является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше неба, создания и дорожных пикселей, чем пиксели пешехода и велосипедиста, потому что небо, создания и дороги покрывают больше области в изображении. Если не обработанный правильно, эта неустойчивость может быть вредна для процесса обучения, потому что изучение смещается в пользу доминирующих классов. Позже в этом примере, вы будете использовать взвешивание класса, чтобы обработать эту проблему.

Изображения в наборе данных CamVid имеют размер 720 на 960. Размер изображения выбран таким образом, что достаточно большой пакет изображений может уместиться в памяти во время обучения на Титане NVIDIA™ X с 12 Гбайт памяти. Вы, возможно, должны будете изменить размер изображений на меньшие величины, если ваш графический процессор не имеет достаточной памяти, либо уменьшить размер обучающего пакета.

Подготовка учебных, проверочных и тестовых наборов

Deeplab v3 + обучен с помощью 60% изображений от набора данных. Остальная часть изображений разделена равномерно на части в 20% и 20% для валидации и тестирования соответственно. Следующий код случайным образом разделяет изображение и данные о пиксельных метках на наборы для обучения, валидации и тестирования.

[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds);

Соотношение 60/20/20 разделяет результаты на следующие количества изображений для обучения, валидации и тестирования:

numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numValImages = numel(imdsVal.Files)
numValImages = 140
numTestingImages = numel(imdsTest.Files)
numTestingImages = 140

Создайте сеть

Используйте deeplabv3plusLayers функция, чтобы создать DeepLab v3 + сеть на основе ResNet-18. Выбор лучшей сети для вашего приложения требует эмпирического анализа и является другим уровнем настройки гиперпараметра. Например, можно экспериментировать с различными основными сетями, такими как ResNet-50 или MobileNet v2, или можно попробовать другие архитектуры сети семантической сегментации, такие как SegNet, полностью сверточные сети (FCN) или U-Net.

% Specify the network image size. This is typically the same as the traing image sizes.
imageSize = [720 960 3];

% Specify the number of classes.
numClasses = numel(classes);

% Create DeepLab v3+.
lgraph = deeplabv3plusLayers(imageSize, numClasses, "resnet18");

Сбалансируйте классы Используя взвешивание класса

Как показано ранее классы в CamVid не сбалансированы. Чтобы улучшить обучение, можно использовать взвешивание класса, чтобы сбалансировать классы. Используйте пиксельные количества метки, вычисленные ранее с countEachLabel и вычислите веса класса медианной частоты.

imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq
classWeights = 11×1

    0.3182
    0.2082
    5.0924
    0.1744
    0.7103
    0.4175
    4.5371
    1.8386
    1.0000
    6.6059
      ⋮

Задайте веса класса с помощью pixelClassificationLayer.

pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights);
lgraph = replaceLayer(lgraph,"classification",pxLayer);

Выберите Training Options

Алгоритм оптимизации, используемый в обучении, является стохастическим градиентным спуском с импульсом (SGDM). Используйте trainingOptions задавать гиперпараметры, используемые в SGDM.

% Define validation data.
pximdsVal = pixelLabelImageDatastore(imdsVal,pxdsVal);

% Define training options. 
options = trainingOptions('sgdm', ...
    'LearnRateSchedule','piecewise',...
    'LearnRateDropPeriod',10,...
    'LearnRateDropFactor',0.3,...
    'Momentum',0.9, ...
    'InitialLearnRate',1e-3, ...
    'L2Regularization',0.005, ...
    'ValidationData',pximdsVal,...
    'MaxEpochs',30, ...  
    'MiniBatchSize',8, ...
    'Shuffle','every-epoch', ...
    'CheckpointPath', tempdir, ...
    'VerboseFrequency',2,...
    'Plots','training-progress',...
    'ValidationPatience', 4);

Скорость обучения использует кусочное расписание. Темп обучения уменьшается на множитель 0,3 каждые 10 эпох. Это позволяет сети обучаться быстро с более высоким начальным темпом обучения, и в то же время сохраняется способность найти решение, близкое к локальному оптимуму, когда темп обучения понижается.

Сеть тестируется против данных о валидации каждая эпоха путем установки 'ValidationData' параметр. 'ValidationPatience' собирается в 4 остановить обучение рано, когда точность валидации сходится. Это препятствует избыточному обучению сети на обучающем наборе данных.

Мини-пакет размером 8 используется для уменьшения использования памяти во время обучения. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти GPU, имеющейся в вашей системе.

Кроме того, 'CheckpointPath' установлен во временное местоположение. Эта пара "имя-значение" включает сохранение сетевых контрольных точек в конце каждой учебной эпохи. Если обучение прервано из-за системного отказа или отключения электроэнергии, можно возобновить обучение с сохраненной контрольной точки. Убедитесь что местоположение, заданное 'CheckpointPath' имеет достаточно пробела, чтобы сохранить сетевые контрольные точки. Например, сохраняя 100 Deeplab v3 + контрольные точки требуют ~6 Гбайт дискового пространства, потому что каждая контрольная точка составляет 61 Мбайт.

Увеличение данных

Увеличение количества данных используется во время обучения, чтобы предоставить сети больше примеров, потому что это помогает улучшить точность сети. Здесь для увеличения данных используются случайное отражение "слева/справа" и случайный преобразование X/Y на +/-10 пикселей. Используйте imageDataAugmenter задавать эти параметры увеличения данных.

augmenter = imageDataAugmenter('RandXReflection',true,...
    'RandXTranslation',[-10 10],'RandYTranslation',[-10 10]);

imageDataAugmenter поддержки несколько других типов увеличения данных. Выбор среди них требует эмпирического анализа и является другим уровнем настройки гиперпараметра.

Запустите обучение

Объедините обучающие данные и выборки для увеличения данных с помощью pixelLabelImageDatastore. pixelLabelImageDatastore пакеты чтений обучающих данных, применяет увеличение данных и отправляет увеличенные данные в алгоритм настройки.

pximds = pixelLabelImageDatastore(imdsTrain,pxdsTrain, ...
    'DataAugmentation',augmenter);

Запустите обучение с помощью trainNetwork если doTraining флаг верен. В противном случае загружает предварительно обученную сеть.

Примечание: обучение было проверено на Титане NVIDIA™ X с 12 Гбайт памяти графического процессора. Если ваш графический процессор имеет меньше памяти, у можно закончиться память во время обучения. Если это происходит, попробуйте установку 'MiniBatchSize' к 1 в trainingOptions, или сокращение сетевого входа и изменение размеров обучающих данных с помощью 'OutputSize' параметр pixelLabelImageDatastore. Обучение эта сеть занимает приблизительно 5 часов. В зависимости от вашего оборудования графического процессора это может взять еще дольше.

doTraining = false;
if doTraining    
    [net, info] = trainNetwork(pximds,lgraph,options);
else
    data = load(pretrainedNetwork); 
    net = data.net;
end

Тестовая сеть на одном изображении

Как быстрая проверка работоспособности, запустите обучивший сеть на одном тестовом изображении.

I = readimage(imdsTest,35);
C = semanticseg(I, net);

Отобразите результаты.

B = labeloverlay(I,C,'Colormap',cmap,'Transparency',0.4);
imshow(B)
pixelLabelColorbar(cmap, classes);

Сравните результаты в C с ожидаемой основной истиной, сохраненной в pxdsTest. Зеленые и пурпурные области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.

expectedResult = readimage(pxdsTest,35);
actual = uint8(C);
expected = uint8(expectedResult);
imshowpair(actual, expected)

Визуально, результаты семантической сегментации перекрываются хорошо для классов, таких как дорога, небо и здание. Однако меньшие объекты как пешеходы и автомобили не так точны. Сумма перекрытия для класса может быть измерена с помощью метрики пересечения по объединению (IoU), также известной как индекс Jaccard. Используйте jaccard функционируйте, чтобы измерить IoU.

iou = jaccard(C,expectedResult);
table(classes,iou)
ans=11×2 table
      classes         iou  
    ____________    _______

    "Sky"           0.91837
    "Building"      0.84479
    "Pole"          0.31203
    "Road"          0.93698
    "Pavement"      0.82838
    "Tree"          0.89636
    "SignSymbol"    0.57644
    "Fence"         0.71046
    "Car"           0.66688
    "Pedestrian"    0.48417
    "Bicyclist"     0.68431

Метрика IoU подтверждает визуальные результаты. Классы "Дорога", "Небо" и "Здания" имеют высокие очки IoU, в то время как такие классы, как "Пешеход" и "Автомобиль" имеют низкие баллы. Другие общие метрики сегментации включают dice и bfscore очертите соответствие со счетом.

Оцените обученную сеть

Измерять точность для нескольких тестовых изображений, runsemanticseg на целом наборе тестов. Размер мини-пакета, равный 4, используется, чтобы уменьшать использование памяти при сегментации изображений. Вы можете увеличить или уменьшить это значение в зависимости от объема памяти GPU, имеющейся в вашей системе.

pxdsResults = semanticseg(imdsTest,net, ...
    'MiniBatchSize',4, ...
    'WriteLocation',tempdir, ...
    'Verbose',false);

semanticseg возвращает результаты для набора тестов как pixelLabelDatastore объект. Данные о метке фактического пикселя для каждого теста отображают в imdsTest записан в диск в месте, заданном 'WriteLocation' параметр. Используйте evaluateSemanticSegmentation измерять метрики семантической сегментации на результатах набора тестов.

metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);

evaluateSemanticSegmentation возвращает различные метрики для набора данных в целом, для отдельных классов, и для каждого тестового изображения. Чтобы видеть метрики уровня набора данных, смотрите metrics.DataSetMetrics .

metrics.DataSetMetrics
ans=1×5 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.87695          0.85392       0.6302       0.80851        0.65051  

Метрики набора данных предоставляют общий обзор производительности сети. Чтобы увидеть влияние каждого класса на общую производительности, смотрите метрики по классам с помощью metrics.ClassMetrics.

metrics.ClassMetrics
ans=11×3 table
                  Accuracy      IoU      MeanBFScore
                  ________    _______    ___________

    Sky           0.93112     0.90209       0.8952  
    Building      0.78453     0.76098      0.58511  
    Pole          0.71586     0.21477      0.51439  
    Road          0.93024     0.91465      0.76696  
    Pavement      0.88466     0.70571      0.70919  
    Tree          0.87377     0.76323      0.70875  
    SignSymbol    0.79358     0.39309      0.48302  
    Fence         0.81507     0.46484      0.48564  
    Car           0.90956     0.76799      0.69233  
    Pedestrian    0.87629      0.4366      0.60792  
    Bicyclist     0.87844     0.60829      0.55089  

Несмотря на то, что полная производительность набора данных довольно высока, метрики класса показывают что недостаточно представленные классы, такие как Pedestrian, Bicyclist, и Car не сегментируются, а также классы, такие как Road, Sky, и Building. Дополнительные данные, которые включают больше выборок недостаточно представленных классов, могут помочь улучшить результаты.

Вспомогательные Функции

function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid dataset has 32 classes. Group them into 11 classes following
% the original SegNet training methodology [1].
%
% The 11 classes are:
%   "Sky" "Building", "Pole", "Road", "Pavement", "Tree", "SignSymbol",
%   "Fence", "Car", "Pedestrian",  and "Bicyclist".
%
% CamVid pixel label IDs are provided as RGB color values. Group them into
% 11 classes and return them as a cell array of M-by-3 matrices. The
% original CamVid class names are listed alongside each RGB value. Note
% that the Other/Void class are excluded below.
labelIDs = { ...
    
    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]
    
    % "Building" 
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    ]
    
    % "Pole"
    [
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    ]
    
    % Road
    [
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    ]
    
    % "Pavement"
    [
    000 000 192; ... % "Sidewalk" 
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    ]
        
    % "Tree"
    [
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    ]
    
    % "SignSymbol"
    [
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    ]
    
    % "Fence"
    [
    064 064 128; ... % "Fence"
    ]
    
    % "Car"
    [
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    ]
    
    % "Pedestrian"
    [
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    ]
    
    % "Bicyclist"
    [
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]
    
    };
end
function pixelLabelColorbar(cmap, classNames)
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add colorbar to current figure.
c = colorbar('peer', gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick mark.
c.TickLength = 0;
end
function cmap = camvidColorMap()
% Define the colormap used by CamVid dataset.

cmap = [
    128 128 128   % Sky
    128 0 0       % Building
    192 192 192   % Pole
    128 64 128    % Road
    60 40 222     % Pavement
    128 128 0     % Tree
    192 128 128   % SignSymbol
    64 64 128     % Fence
    64 0 128      % Car
    64 64 0       % Pedestrian
    0 128 192     % Bicyclist
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end
function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionCamVidData(imds,pxds)
% Partition CamVid data by randomly selecting 60% of the data for training. The
% rest is used for testing.
    
% Set initial random state for example reproducibility.
rng(0); 
numFiles = numel(imds.Files);
shuffledIndices = randperm(numFiles);

% Use 60% of the images for training.
numTrain = round(0.60 * numFiles);
trainingIdx = shuffledIndices(1:numTrain);

% Use 20% of the images for validation
numVal = round(0.20 * numFiles);
valIdx = shuffledIndices(numTrain+1:numTrain+numVal);

% Use the rest for testing.
testIdx = shuffledIndices(numTrain+numVal+1:end);

% Create image datastores for training and test.
trainingImages = imds.Files(trainingIdx);
valImages = imds.Files(valIdx);
testImages = imds.Files(testIdx);

imdsTrain = imageDatastore(trainingImages);
imdsVal = imageDatastore(valImages);
imdsTest = imageDatastore(testImages);

% Extract class and label IDs info.
classes = pxds.ClassNames;
labelIDs = camvidPixelLabelIDs();

% Create pixel label datastores for training and test.
trainingLabels = pxds.Files(trainingIdx);
valLabels = pxds.Files(valIdx);
testLabels = pxds.Files(testIdx);

pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs);
pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs);
pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs);
end

Ссылки

[1] Chen, Liang-Chieh et al. “Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation” ECCV (2018).

[2] Brostow, G. J., J. Fauqueur, and R. Cipolla. "Semantic object classes in video: A high-definition ground truth database." Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.

Смотрите также

| | | | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте