Семантическая сегментация с использованием расширенных сверток

Обучите сеть семантической сегментации с помощью расширенных сверток.

Семантическая сеть сегментации классифицирует каждый пиксель в изображении, получая к изображение, которое сегментировано по классам. Приложения для семантической сегментации включают сегментацию дорог для автономного управления автомобилем и сегментацию раковых камер для медицинского диагностирования. Дополнительные сведения см. в разделе Начало работы с семантической сегментацией с использованием глубокого обучения (Computer Vision Toolbox).

Семантические сети сегментации, такие как DeepLab [1], широко используют расширенные свертки (также известные как атронные свертки), потому что они могут увеличить восприимчивое поле слоя (площадь входа, который могут видеть слои), не увеличивая количество параметров или расчетов.

Загрузка обучающих данных

Пример использует простой набор данных из изображений треугольника 32 на 32 в целях рисунка. Набор данных включает в себя сопровождающие пиксельные достоверные данные. Загрузите обучающие данные с помощью imageDatastore и a pixelLabelDatastore.

dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageFolderTrain = fullfile(dataFolder,'trainingImages');
labelFolderTrain = fullfile(dataFolder,'trainingLabels');

Создайте imageDatastore для изображений.

imdsTrain = imageDatastore(imageFolderTrain);

Создайте pixelLabelDatastore для меток основной истины пикселей.

classNames = ["triangle" "background"];
labels = [255 0];
pxdsTrain = pixelLabelDatastore(labelFolderTrain,classNames,labels)
pxdsTrain = 
  PixelLabelDatastore with properties:

                       Files: {200x1 cell}
                  ClassNames: {2x1 cell}
                    ReadSize: 1
                     ReadFcn: @readDatastoreImage
    AlternateFileSystemRoots: {}

Создайте сеть семантической сегментации

Этот пример использует простую сеть семантической сегментации, основанную на расширенных сверточках.

Создайте источник данных для обучающих данных и получите количество пикселей для каждой метки.

pximdsTrain = pixelLabelImageDatastore(imdsTrain,pxdsTrain);
tbl = countEachLabel(pxdsTrain)
tbl=2×3 table
         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    {'triangle'  }         10326       2.048e+05   
    {'background'}    1.9447e+05       2.048e+05   

Большинство меток пикселей предназначены для фона. Этот классовый дисбаланс смещает процесс обучения в пользу доминирующего класса. Чтобы исправить это, используйте взвешивание классов для балансировки классов. Для вычисления весов классов можно использовать несколько методов. Одним из распространенных методов является обратное взвешивание частот, где веса классов являются обратными частотам классов. Этот метод увеличивает вес, придаваемый недостаточно представленным классам. Вычислите веса классов, используя обратное взвешивание частот.

numberPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / numberPixels;
classWeights = 1 ./ frequency;

Создайте сеть для классификации пикселей с помощью входного слоя изображения с размером входа, соответствующим размеру входных изображений. Затем задайте три блока свертки, нормализации партии . и слоев ReLU. Для каждого сверточного слоя задайте 32 фильтра 3 на 3 с увеличением коэффициентов расширения и заполните входы так, чтобы они были такими же размерами, как выходы путем установки 'Padding' опция для 'same'. Чтобы классифицировать пиксели, включите сверточный слой со свертками K 1 на 1, где K - количество классов, далее слой softmax и pixelClassificationLayer с обратными весами классов.

inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = numel(classNames);

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',2,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',4,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(1,numClasses)
    softmaxLayer
    pixelClassificationLayer('Classes',classNames,'ClassWeights',classWeights)];

Обучите сеть

Задайте опции обучения.

options = trainingOptions('sgdm', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 64, ... 
    'InitialLearnRate', 1e-3);

Обучите сеть с помощью trainNetwork.

net = trainNetwork(pximdsTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |       91.62% |       1.6825 |          0.0010 |
|      17 |          50 |       00:00:08 |       88.56% |       0.2393 |          0.0010 |
|      34 |         100 |       00:00:15 |       92.08% |       0.1672 |          0.0010 |
|      50 |         150 |       00:00:23 |       93.17% |       0.1472 |          0.0010 |
|      67 |         200 |       00:00:31 |       94.15% |       0.1313 |          0.0010 |
|      84 |         250 |       00:00:38 |       94.47% |       0.1167 |          0.0010 |
|     100 |         300 |       00:00:46 |       95.04% |       0.1100 |          0.0010 |
|========================================================================================|

Тестирование сети

Загрузите тестовые данные. Создайте imageDatastore для изображений. Создайте pixelLabelDatastore для меток основной истины пикселей.

imageFolderTest = fullfile(dataFolder,'testImages');
imdsTest = imageDatastore(imageFolderTest);
labelFolderTest = fullfile(dataFolder,'testLabels');
pxdsTest = pixelLabelDatastore(labelFolderTest,classNames,labels);

Делайте предсказания с помощью тестовых данных и обученной сети.

pxdsPred = semanticseg(imdsTest,net,'MiniBatchSize',32,'WriteLocation',tempdir);
Running semantic segmentation network
-------------------------------------
* Processed 100 images.

Оцените точность предсказания с помощью evaluateSemanticSegmentation.

metrics = evaluateSemanticSegmentation(pxdsPred,pxdsTest);
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 100 images.
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.95237          0.97352       0.72081      0.92889        0.46416  

Для получения дополнительной информации об оценке семантических сетей сегментации смотрите evaluateSemanticSegmentation (Computer Vision Toolbox).

Сегментация нового изображения

Чтение и отображение тестового изображения triangleTest.jpg.

imgTest = imread('triangleTest.jpg');
figure
imshow(imgTest)

Figure contains an axes. The axes contains an object of type image.

Сегментируйте тестовое изображение с помощью semanticseg и отобразите результаты с помощью labeloverlay.

C = semanticseg(imgTest,net);
B = labeloverlay(imgTest,C);
figure
imshow(B)

Figure contains an axes. The axes contains an object of type image.

См. также

| | | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Набор Image Processing Toolbox)

Похожие темы