Обучите и разверните полностью Сверточные сети для семантической сегментации

Этот пример показывает, как обучить и развернуть полностью сверточную семантическую сеть сегментации на графическом процессоре NVIDIA® при помощи GPU Coder ™.

Семантическая сеть сегментации классифицирует каждый пиксель на изображение, приводящее к изображению, которое сегментируется классом. Приложения для семантической сегментации включают дорожную сегментацию для автономного управления и сегментацию раковой клетки для медицинского диагноза. Чтобы узнать больше, смотрите Семантические Основы Сегментации (Computer Vision Toolbox).

Чтобы проиллюстрировать метод обучения, этот пример обучает FCN-8s [1], один тип сверточной нейронной сети (CNN), разработанной для семантической сегментации изображений. Другие сети типов для семантической сегментации включают полностью сверточные сети SegNet и U-Net. Метод обучения, показанный здесь, может быть применен к тем сетям также.

Этот пример использует набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором изображений, содержащих представления уличного уровня, полученные при управлении. Набор данных обеспечивает метки пиксельного уровня для 32 семантических классов включая автомобиль, пешехода и дорогу.

Предпосылки

  • CUDA® включил NVIDIA, графический процессор с вычисляет возможность 3.2 или выше.

  • NVIDIA инструментарий CUDA и драйвер.

  • Библиотека NVIDIA cuDNN (v7 и выше).

  • Deep Learning Toolbox™, чтобы использовать Сетевой объект DAG.

  • Parallel Computing Toolbox™

  • Image Processing Toolbox™ для чтения и отображения изображений.

  • Computer Vision Toolbox™ для функции labeloverlay используется в этом примере.

  • GPU Coder для генерации кода CUDA.

  • Интерфейс GPU Coder для Библиотек Глубокого обучения поддерживает пакет. Чтобы установить этот пакет поддержки, используйте Add-On Explorer.

  • Модель Deep Learning Toolbox для пакета Сетевой поддержки VGG-16. Чтобы установить этот пакет поддержки, см. Модель Deep Learning Toolbox™ для Сети VGG-16.

  • Переменные окружения для компиляторов и библиотек. Для получения информации о поддерживаемых версиях компиляторов и библиотек, смотрите Сторонние продукты (GPU Coder). Для подготовки переменных окружения смотрите Подготовку Необходимых как условие продуктов (GPU Coder).

Проверьте среду графического процессора

Используйте coder.checkGpuInstall, функционируют и проверяют, что компиляторы и библиотеки, необходимые для выполнения этого примера, настраиваются правильно.

coder.checkGpuInstall('gpu','codegen','cudnn','quiet');

Настройка

Этот пример создает полностью сверточную семантическую сеть сегментации с весами, инициализированными от сети VGG-16. Функция vgg16 проверяет на существование Модели Deep Learning Toolbox для пакета Сетевой поддержки VGG-16 и возвращает предварительно обученную модель VGG-16.

vgg16();

Кроме того, загрузите предварительно обученную версию FCN. Эта предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться. doTraining отмечают средства управления, использует ли пример обучивший сеть из примера или предварительно обученной сети FCN для генерации кода.

doTraining = false;
if ~doTraining
    pretrainedURL = 'https://www.mathworks.com/supportfiles/gpucoder/cnn_models/fcn/FCN8sCamVid.mat';
    disp('Downloading pretrained FCN (448 MB)...');
    websave('FCN8sCamVid.mat',pretrainedURL);
end
Downloading pretrained FCN (448 MB)...

Загрузите набор данных CamVid

Загрузите набор данных CamVid со следующих URL.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

outputFolder = fullfile(pwd,'CamVid');

if ~exist(outputFolder, 'dir')
   
    mkdir(outputFolder)
    labelsZip = fullfile(outputFolder,'labels.zip');
    imagesZip = fullfile(outputFolder,'images.zip');   
    
    disp('Downloading 16 MB CamVid dataset labels...'); 
    websave(labelsZip, labelURL);
    unzip(labelsZip, fullfile(outputFolder,'labels'));
    
    disp('Downloading 557 MB CamVid dataset images...');  
    websave(imagesZip, imageURL);       
    unzip(imagesZip, fullfile(outputFolder,'images'));    
end

Примечание: время загрузки данных зависит от вашего Интернет-соединения. Выполнение в качестве примера не продолжает, пока операция загрузки не завершена. Также можно использовать веб-браузер, чтобы сначала загрузить набор данных на локальный диск. Затем используйте outputFolder переменную, чтобы указать на местоположение загруженного файла.

Загрузите изображения CamVid

Используйте imageDatastore, чтобы загрузить изображения CamVid. imageDatastore позволяет вам эффективно загрузить большое количество изображений на диске.

imgDir = fullfile(outputFolder,'images','701_StillsRaw_full');
imds = imageDatastore(imgDir);

Отобразите одно из изображений.

I = readimage(imds,25);
I = histeq(I);
imshow(I)

Загрузите CamVid маркированные пикселем изображения

Используйте pixelLabelDatastore, чтобы загрузить пиксельные данные изображения метки CamVid. pixelLabelDatastore инкапсулирует пиксельные данные о метке и метку ID к отображению имени класса.

Выполняя процедуру, используемую в исходной газете SegNet [3], сгруппируйте 32 исходных класса в CamVid к 11 классам. Задайте эти классы.

classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];

Чтобы уменьшать 32 класса в 11, несколько классов от исходного набора данных группируются. Например, "Автомобиль" является комбинацией "Автомобиля", "SUVPickupTruck", "Truck_Bus", "Train" и "OtherMoving". Возвратите сгруппированную метку IDs при помощи camvidPixelLabelIDs поддерживающий функции.

labelIDs = camvidPixelLabelIDs();

Используйте классы и метку IDs, чтобы создать pixelLabelDatastore.

labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Считайте и отобразите одно из маркированных пикселем изображений путем накладывания его сверху изображения.

C = readimage(pxds,25);

cmap = camvidColorMap;

B = labeloverlay(I,C,'ColorMap',cmap);
imshow(B)
pixelLabelColorbar(cmap,classes);

Области без перекрытия цвета не имеют пиксельных меток и не используются во время обучения.

Анализируйте статистику набора данных

Чтобы видеть распределение меток класса в наборе данных CamVid, используйте countEachLabel. Эта функция считает количество пикселей меткой класса.

tbl = countEachLabel(pxds)
tbl=11×3 table
        Name        PixelCount    ImagePixelCount
    ____________    __________    _______________

    'Sky'           7.6801e+07      4.8315e+08   
    'Building'      1.1737e+08      4.8315e+08   
    'Pole'          4.7987e+06      4.8315e+08   
    'Road'          1.4054e+08      4.8453e+08   
    'Pavement'      3.3614e+07      4.7209e+08   
    'Tree'          5.4259e+07       4.479e+08   
    'SignSymbol'    5.2242e+06      4.6863e+08   
    'Fence'         6.9211e+06       2.516e+08   
    'Car'           2.4437e+07      4.8315e+08   
    'Pedestrian'    3.4029e+06      4.4444e+08   
    'Bicyclist'     2.5912e+06      2.6196e+08   

Визуализируйте пиксельные количества классом.

frequency = tbl.PixelCount/sum(tbl.PixelCount);

bar(1:numel(classes),frequency)
xticks(1:numel(classes)) 
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')

Идеально, все классы имели бы равное количество наблюдений. Однако классы в CamVid являются неустойчивыми, который является распространенной проблемой в автомобильных наборах данных уличных сцен. Такие сцены имеют больше неба, создания и дорожных пикселей, чем пиксели пешехода и велосипедиста, потому что небо, создания и дороги покрывают больше области в изображении. Если не обработанный правильно, эта неустойчивость может быть вредна для процесса обучения, потому что изучение смещается в пользу доминирующих классов. Позже в этом примере, вы будете использовать взвешивание класса, чтобы обработать эту проблему.

Измените размер данных CamVid

Изображения в наборе данных CamVid 720 960. Чтобы уменьшать учебное время и использование памяти, измените размер изображений и пиксельных изображений метки к 360 480. Эти операции выполняются при помощи resizeCamVidImages и resizeCamVidPixelLabels поддерживающий функций.

imageFolder = fullfile(outputFolder,'imagesResized',filesep);
imds = resizeCamVidImages(imds,imageFolder);

labelFolder = fullfile(outputFolder,'labelsResized',filesep);
pxds = resizeCamVidPixelLabels(pxds,labelFolder);

Подготовьте наборы обучающих данных и наборы тестов

SegNet обучен с помощью 60% изображений от набора данных. Остальная часть изображений используется для тестирования. Следующий код случайным образом разделяет изображение и пиксельные данные о метке в набор обучающих данных и набор тестов.

[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidData(imds,pxds);

60/40 разделяют результаты в следующем количестве обучения и тестируют изображения:

numTrainingImages = numel(imdsTrain.Files)
numTrainingImages = 421
numTestingImages = numel(imdsTest.Files)
numTestingImages = 280

Создайте сеть

Используйте fcnLayers, чтобы создать инициализированное использование полностью сверточных сетевых слоев веса VGG-16. fcnLayers автоматически выполняет трансформации сетей, должен был передать веса от VGG-16 и добавляет дополнительные слои, требуемые для семантической сегментации. Вывод fcnLayers является объектом LayerGraph, представляющим FCN. Объект LayerGraph инкапсулирует сетевые слои и связи между слоями.

imageSize = [360 480];
numClasses = numel(classes);
lgraph = fcnLayers(imageSize,numClasses);

Размер изображения выбран на основе размера изображений в наборе данных. Количество классов выбрано на основе классов в CamVid.

Сбалансируйте классы Используя взвешивание класса

Классы в CamVid не сбалансированы. Чтобы улучшить обучение, можно использовать пиксельные количества метки, вычисленные ранее с countEachLabel, и вычислить средние веса класса частоты [3].

imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;

Задайте веса класса с помощью pixelClassificationLayer.

pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights)
pxLayer = 
  PixelClassificationLayer with properties:

            Name: 'labels'
         Classes: [11×1 categorical]
    ClassWeights: [11×1 double]
      OutputSize: 'auto'

   Hyperparameters
    LossFunction: 'crossentropyex'

Обновите сеть SegNet с новым pixelClassificationLayer путем удаления текущего pixelClassificationLayer и добавления нового слоя. Текущий pixelClassificationLayer называют 'pixelLabels'. Удалите его с помощью removeLayers, добавьте новый с помощью addLayers и соедините новый слой с остальной частью сети с помощью connectLayers.

lgraph = removeLayers(lgraph,'pixelLabels');
lgraph = addLayers(lgraph, pxLayer);
lgraph = connectLayers(lgraph,'softmax','labels');

Выберите Training Options

Алгоритмом оптимизации, используемым для обучения, является Адам (выведенный от адаптивной оценки момента). Используйте trainingOptions, чтобы задать гиперпараметры, используемые для Адама.

options = trainingOptions('adam', ...
    'InitialLearnRate',1e-3, ...
    'MaxEpochs',100, ...  
    'MiniBatchSize',4, ...
    'Shuffle','every-epoch', ...
    'CheckpointPath', tempdir, ...
    'VerboseFrequency',2);

'MiniBatchSize' 4 используется, чтобы уменьшать использование памяти в то время как обучение. Можно увеличить или уменьшить это значение на основе суммы памяти графического процессора, которую вы имеете в своей системе.

Кроме того, 'CheckpointPath' установлен во временное местоположение. Эта пара "имя-значение" включает сохранение сетевых контрольных точек в конце каждой учебной эпохи. Если обучение прервано из-за системного отказа или отключения электроэнергии, можно возобновить обучение от сохраненной контрольной точки. Убедитесь, что местоположение, заданное 'CheckpointPath', имеет достаточно пробела, чтобы сохранить сетевые контрольные точки.

Увеличение данных

Увеличение данных используется во время обучения предоставить больше примеров сети, потому что это помогает улучшить точность сети. Здесь, случайный слева/справа отражение и случайный перевод X/Y +/-10 пикселей используются для увеличения данных. Используйте imageDataAugmenter, чтобы задать эти параметры увеличения данных.

augmenter = imageDataAugmenter('RandXReflection',true,...
    'RandXTranslation',[-10 10],'RandYTranslation',[-10 10]);

imageDataAugmenter поддерживает несколько других типов увеличения данных. Выбор среди них требует эмпирического анализа и является другим уровнем настройки гиперпараметра.

Запустите обучение

Объедините данные тренировки и выборы увеличения данных с помощью pixelLabelImageDatastore. Пакеты чтений pixelLabelImageDatastore данных тренировки, применяет увеличение данных и отправляет увеличенные данные в учебный алгоритм.

pximds = pixelLabelImageDatastore(imdsTrain,pxdsTrain, ...
    'DataAugmentation',augmenter);

Если флаг doTraining верен, запустите обучение при помощи trainNetwork.

Примечание: обучение было проверено на Титане NVIDIA™ Xp с 12 Гбайт памяти графического процессора. Если ваш графический процессор имеет меньше памяти, у можно закончиться память. Если это происходит, попытайтесь понизить свойство MiniBatchSize в trainingOptions к 1. Обучение эта сеть занимает приблизительно 5 часов. В зависимости от вашего оборудования графического процессора это может взять еще дольше.

if doTraining    
    [net, info] = trainNetwork(pximds,lgraph,options);
    save('FCN8sCamVid.mat','net');
end

Сохраните сетевой объект DAG в файл с именем MAT FCN8sCamVid.mat. Этот файл MAT используется во время генерации кода.

Выполните генерацию кода MEX

Функция fcn_predict.m берет вход изображений и запускает прогноз на изображении с помощью нейронной сети для глубокого обучения, сохраненной в файле FCN8sCamVid.mat. Функция загружает сетевой объект от FCN8sCamVid.mat в персистентную переменную mynet. На последующих вызовах функции постоянный объект снова используется для прогноза.

type('fcn_predict.m')
% Copyright 2019 The MathWorks, Inc.

function out = fcn_predict(in)
%#codegen

% A persistent object mynet is used to load the DAG network object.
% At the first call to this function, the persistent object is constructed and
% setup. When the function is called subsequent times, the same object is reused 
% to call predict on inputs, thus avoiding reconstructing and reloading the
% network object.

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('FCN8sCamVid.mat');
end

% pass in input
out = predict(mynet,in);

Сгенерируйте объект GPU Configuration для цели MEX, устанавливающей выходной язык на C++. Используйте функцию coder.DeepLearningConfig, чтобы создать cuDNN объект настройки глубокого обучения и присвоить ее свойству DeepLearningConfig объекта настройки графического процессора кода. Запустите команду codegen, задающую входной размер [360, 480, 3]. Это соответствует входному слою FCN.

cfg = coder.gpuConfig('mex');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
codegen -config cfg fcn_predict -args {ones(360,480,3,'uint8')} -report
Code generation successful: View report

Запустите сгенерированный MEX

Загрузите и отобразите входное изображение.

im = imread('testImage.png');
imshow(im);

Вызовите fcn_predict на входном изображении.

predict_scores = fcn_predict_mex(im);

Переменная predict_scores является 3 размерными матрицами, имеющими 11 каналов, соответствующих мудрой пикселем музыке прогноза к каждому классу. Вычислите канал с максимальным счетом прогноза, чтобы получить мудрые пикселем метки.

[~,argmax] = max(predict_scores,[],3);

Наложите сегментированные метки по входу, отображают и отображают сегментированную область.

classes = [
    "Sky"
    "Building"
    "Pole"
    "Road"
    "Pavement"
    "Tree"
    "SignSymbol"
    "Fence"
    "Car"
    "Pedestrian"
    "Bicyclist"
    ];

cmap = camvidColorMap();
SegmentedImage = labeloverlay(im,argmax,'ColorMap',cmap);
figure
imshow(SegmentedImage);
pixelLabelColorbar(cmap,classes);

Очистка

Очистите статический сетевой объект, загруженный в памяти.

clear mex;

Ссылки

[1] Долго, J., Э. Шелхэмер и Т. Даррелл. "Полностью Сверточные Сети для Семантической Сегментации". Продолжения Конференции по IEEE по Компьютерному зрению и Распознаванию образов, 2015, стр 3431–3440.

[2] Brostow, G. J. Ж. Фокер и Р. Сиполла. "Семантические классы объектов в видео: наземная база данных истины высокой четкости". Буквы Распознавания образов. Издание 30, Выпуск 2, 2009, стр 88-97.

[3] Badrinarayanan, V., А. Кендалл и Р. Сиполла. "SegNet: Глубокая Сверточная Архитектура Декодера Энкодера для Сегментации Изображений". arXiv предварительно распечатывают arXiv:1511.00561, 2015.

Для просмотра документации необходимо авторизоваться на сайте