exponenta event banner

Обучение сети семантической сегментации Deep Learning с использованием данных моделирования 3-D

В этом примере показано, как использовать данные моделирования 3-D для обучения семантической сети сегментации и ее точной настройки на реальные данные с использованием генеративных состязательных сетей (GAN).

В этом примере используются данные моделирования 3-D созданные конструктором сценариев вождения и нереальным механизмом ®. Пример создания таких данных моделирования см. в разделе Визуализация глубины и семантической сегментации с использованием моделирования нереального двигателя (Automated Driving Toolbox). Среда моделирования 3-D генерирует изображения и соответствующие метки пикселей истинности земли. Использование данных моделирования позволяет избежать процесса аннотирования, который является утомительным и требует больших усилий человека. Однако модели смены доменов, обученные только моделированию, плохо работают с реальными наборами данных. Для решения этой проблемы можно использовать адаптацию домена для точной настройки обученной модели для работы с реальным набором данных .

В этом примере используется сеть TouchSegNet [1], которая адаптирует структуру прогнозов сегментации выходных данных, которые выглядят одинаково независимо от области ввода. Сеть GAN основана на модели GAN и состоит из двух сетей, которые обучаются одновременно для максимизации производительности обеих:

  1. Генератор - Сеть, обученная генерировать высококачественные результаты сегментации на основе реальных или смоделированных входных изображений

  2. Дискриминатор - сеть, которая сравнивает и пытается отличить прогнозы сегментации генератора от реальных или моделируемых данных

Для точной настройки модели CamSegNet для реальных данных этот пример использует подмножество Данные CamVid [2] и адаптирует модель для генерации высококачественных прогнозов сегментации данных CamVid.

Загрузить предварительно обученную сеть

Загрузите предварительно обученную сеть. Предварительно обученная модель позволяет выполнять весь пример без необходимости ждать завершения обучения. Если вы хотите обучить сеть, установите doTraining переменная для true.

doTraining = false;
if ~doTraining
    pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedAdaptSegGANNet.mat';
    pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
    pretrainedNetwork = fullfile(pretrainedFolder,'trainedAdaptSegGANNet.mat'); 
    if ~exist(pretrainedNetwork,'file')
        mkdir(pretrainedFolder);
        disp('Downloading pretrained network (57 MB)...');
        websave(pretrainedNetwork,pretrainedURL);
    end
    pretrained = load(pretrainedNetwork);
    dlnetGenerator = pretrained.dlnetGenerator;
end    

Загрузить наборы данных

Загрузите моделирование и реальные наборы данных с помощью downloadDataset , определенной в разделе «Вспомогательные функции» данного примера. downloadDataset загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы.

Набор данных моделирования был создан конструктором сценариев управления. Сгенерированные сценарии, состоящие из 553 фотореалистичных изображений с метками, были визуализированы Unreal Engine. Этот набор данных используется для обучения модели.

Реальный набор данных является подмножеством набора данных CamVid из Кембриджского университета. Для адаптации модели к реальным данным, 69 изображений CamVid. Для оценки обучаемой модели используется 368 изображений CamVid.

Время загрузки зависит от вашего подключения к Интернету.

simulationDataURL = 'https://ssd.mathworks.com/supportfiles/vision/data/SimulationDrivingDataset.zip';
realImageDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
realLabelDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

simulationDataLocation = fullfile(tempdir,'SimulationData');
realDataLocation = fullfile(tempdir,'RealData');
[simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder, ...
    realTestImagesFolder, realTestLabelsFolder] = ... 
    downloadDataset(simulationDataLocation,simulationDataURL,realDataLocation,realImageDataURL,realLabelDataURL);

Загруженные файлы включают пиксельные метки для реального домена, но обратите внимание, что эти пиксельные метки не используются в процессе обучения. В этом примере метки пикселей реальной области используются только для вычисления среднего значения пересечения через объединение (IoU) для оценки эффективности обученной модели.

Моделирование нагрузки и реальные данные

Использовать imageDatastore для загрузки имитационных и реальных наборов данных для обучения. Используя хранилище данных образа, можно эффективно загружать большую коллекцию образов на диск.

simData = imageDatastore(simulationImagesFolder);
realData = imageDatastore(realImagesFolder);

Просмотр изображений из набора данных моделирования и набора реальных данных.

simImage = preview(simData);
realImage = preview(realData);
montage({simImage,realImage})

Реальные и смоделированные изображения выглядят совершенно иначе. Следовательно, модели, обученные моделированию данных и оцениваемые по реальным данным, работают плохо из-за сдвига домена.

Загрузка помеченных пикселями изображений для данных моделирования и реальных данных

Загрузка данных изображения эмуляционной пиксельной метки с помощью pixelLabelDatastore(Панель инструментов компьютерного зрения). Хранилище данных метки пикселя инкапсулирует данные метки пикселя и идентификатор метки в соответствие имени класса.

В этом примере укажите пять классов, полезных для автоматизированного приложения вождения: дорога, фон, дорожное покрытие, небо и автомобиль.

classes = [
    "Road"
    "Background"
    "Pavement"
    "Sky"
    "Car"
    ];
numClasses = numel(classes);

Набор данных моделирования состоит из восьми классов. Сократите число классов с восьми до пяти, сгруппировав классы здания, дерева, сигнала трафика и света из исходного набора данных в один фоновый класс. Возврат сгруппированных идентификаторов меток с помощью функции помощника simulationPixelLabelIDs. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.

labelIDs = simulationPixelLabelIDs;

Используйте классы и идентификаторы меток для создания хранилища данных меток пикселей данных моделирования.

simLabels = pixelLabelDatastore(simulationLabelsFolder,classes,labelIDs);

Инициализация карты цветов для сегментированных изображений с помощью функции помощника domainAdaptationColorMap, определенный в разделе «Вспомогательные функции».

dmap = domainAdaptationColorMap;

Предварительный просмотр изображения с меткой пикселя путем наложения метки поверх изображения с помощью labeloverlay(Панель инструментов обработки изображений).

simImageLabel = preview(simLabels);
overlayImageSimulation = labeloverlay(simImage,simImageLabel,'ColorMap',dmap);
figure
imshow(overlayImageSimulation)
labelColorbar(dmap,classes);

Переместите моделирование и реальные данные, используемые для обучения, в нуль-центр, чтобы центрировать данные вокруг начала координат, используя transform функции и preprocessData вспомогательная функция, определенная в разделе «Вспомогательные функции».

preprocessedSimData = transform(simData, @(simdata)preprocessData(simdata));
preprocessedRealData = transform(realData, @(realdata)preprocessData(realdata));

Используйте combine для объединения хранилища данных преобразованного изображения и хранилищ данных меток пикселей области моделирования. В процессе обучения не используются пиксельные метки реальных данных.

combinedSimData = combine(preprocessedSimData,simLabels);

Определение генератора SegNet

В этом примере VGG-16 сеть, предварительно подготовленная в ImageNet, изменяется на полностью сверточную сеть. Для увеличения воспринимающих полей добавляют расширенные сверточные слои с шагами 2 и 4. Это делает разрешение карты выходных характеристик одной восьмой от размера входных данных. Для предоставления многомасштабной информации используется пространственное объединение пирамид (ASPP), за которым следует resize2dlayer с коэффициентом повышения дискретизации 8 для изменения размера выходного сигнала до размера входного сигнала.

Используемая в этом примере сеть генератора PingleSegNet показана на следующей диаграмме.

Чтобы получить предварительно обученную сеть VGG-16, установите vgg16. Если пакет поддержки не установлен, программное обеспечение предоставляет ссылку для загрузки.

net = vgg16;

Чтобы сделать сеть VGG-16 подходящей для семантической сегментации, удалите все уровни VGG после 'relu4_3'.

vggLayers = net.Layers(2:24);

Создайте слой ввода изображения размером 1280-на-720-на-3 для генератора.

inputSizeGenerator = [1280 720 3];
inputLayer = imageInputLayer(inputSizeGenerator,'Normalization','None','Name','inputLayer');

Создание полностью сверточных сетевых уровней. Используйте коэффициенты расширения 2 и 4 для увеличения соответствующих полей.

fcnlayers = [
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_1')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2] ,'Name','conv5_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_2')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_3','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_3')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4],'Name','conv6_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_1')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4] ,'Name','conv6_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_2')
    ];

Объедините слои и создайте график слоев.

layers = [
    inputLayer
    vggLayers
    fcnlayers
    ];
lgraph = layerGraph(layers);

ASPP используется для предоставления многомасштабной информации. Добавьте модуль ASPP к графу уровней с размером фильтра, равным количеству каналов, используя addASPPToNetwork вспомогательная функция, определенная в разделе «Вспомогательные функции».

lgraph  = addASPPToNetwork(lgraph, numClasses);

Обратиться resize2dLayer с коэффициентом повышающей дискретизации 8, чтобы выходной сигнал соответствовал размеру входного сигнала.

upSampleLayer = resize2dLayer('Scale',8,'Method','bilinear','Name','resizeLayer');
lgraphGenerator = addLayers(lgraph,upSampleLayer);
lgraphGenerator = connectLayers(lgraphGenerator,'additionLayer','resizeLayer');

Визуализация генераторной сети на графике.

plot(lgraphGenerator)
title("Generator")

Определение дискриминатора RegedSeg

Дискриминаторная сеть состоит из пяти сверточных слоев с размером ядра 3 и шагом 2, где количество каналов - {64, 128, 256, 512, 1}. За каждым слоем следует слой ReLU с утечкой, параметризованный шкалой 0,2, за исключением последнего слоя. resize2dLayer используется для изменения размера выходного сигнала дискриминатора. Следует отметить, что в этом примере не используется нормализация партии, поскольку дискриминатор обучается совместно с сетью сегментации с использованием небольшого размера партии.

Дискриминаторная сеть RegingSegNet в этом примере показана на следующей диаграмме.

Создание слоя ввода изображения размером 1280-на-720-на-numClasses что принимает во внимание прогнозы сегментации симуляции и реальных доменов.

inputSizeDiscriminator = [1280 720 numClasses];

Создайте полностью сверточные слои и создайте график слоев дискриминатора.

% Factor for number of channels in convolution layer.
numChannelsFactor = 64;

% Scale factor to resize the output of the discriminator.
resizeScale = 64;

% Scalar multiplier for leaky ReLU layers.
leakyReLUScale = 0.2;

% Create the layers of the discriminator.
layers = [
    imageInputLayer(inputSizeDiscriminator,'Normalization','none','Name','inputLayer')
    convolution2dLayer(3,numChannelsFactor,'Stride',2,'Padding',1,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu1')
    convolution2dLayer(3,numChannelsFactor*2,'Stride',2,'Padding',1,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu2')
    convolution2dLayer(3,numChannelsFactor*4,'Stride',2,'Padding',1,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu3')
    convolution2dLayer(3,numChannelsFactor*8,'Stride',2,'Padding',1,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu4')
    convolution2dLayer(3,1,'Stride',2,'Padding',1,'Name','classifer','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    resize2dLayer('Scale', resizeScale,'Method','bilinear','Name','resizeLayer');
    ];

% Create the layer graph of the discriminator.
lgraphDiscriminator  = layerGraph(layers);

Визуализация дискриминаторной сети на графике.

plot(lgraphDiscriminator)
title("Discriminator")

Укажите параметры обучения

Укажите эти параметры обучения.

  • Задать общее количество итераций как 5000. Таким образом, вы обучаете сеть в течение примерно 10 эпох.

  • Установка коэффициента обучения для генератора на 2.5e-4.

  • Установка коэффициента обучения для дискриминатора равным 1e-4.

  • Установите коэффициент регуляризации L2 в значение 0.0005.

  • Скорость обучения экспоненциально уменьшается на основе формулы learningrate × [iterationtotaliterations] мощности. Это уменьшение помогает стабилизировать градиенты при более высоких итерациях. Установить мощность на0.9.

  • Установите вес состязательной потери равным 0.001.

  • Инициализировать скорость градиента как [ ]. Это значение используется SGDM для хранения скорости градиентов.

  • Инициализировать скользящее среднее градиентов параметров как [ ]. Это значение используется инициализатором Адама для хранения среднего значения градиентов параметров.

  • Инициализировать скользящее среднее квадратичных градиентов параметров как [ ]. Это значение используется инициализатором Адама для хранения среднего значения квадратичных градиентов параметров.

  • Задайте размер мини-партии как 1.

numIterations = 5000;
learnRateGenBase = 2.5e-4;
learnRateDisBase = 1e-4;
l2Regularization = 0.0005;
power = 0.9;
lamdaAdv = 0.001;
vel= [];
averageGrad = [];
averageSqGrad = [];
miniBatchSize = 1;

Обучение на GPU, если он доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Чтобы автоматически определить, доступен ли графический процессор, установитеexecutionEnvironment кому "auto". Если у вас нет графического процессора или вы не хотите использовать его для обучения, установите executionEnvironment кому "cpu". Для обеспечения использования графического процессора для обучения, установить executionEnvironment кому "gpu". Сведения о поддерживаемых вычислительных возможностях см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

executionEnvironment = "auto";

Создать minibatchqueue объект из объединенного хранилища данных области моделирования.

mbqTrainingDataSimulation =  minibatchqueue(combinedSimData,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Создать minibatchqueue объект из хранилища данных входного изображения реального домена.

mbqTrainingDataReal = minibatchqueue(preprocessedRealData,"MiniBatchSize",miniBatchSize, ... 
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. Вспомогательная функция modelGradients, определенные в разделе «Вспомогательные функции» этого примера, вычисляют градиенты и потери для генератора и дискриминатора. Создание графика хода обучения с использованием configureTrainingLossPlotter, прилагается к этому примеру в качестве вспомогательного файла и обновляет ход обучения с помощью updateTrainingPlots. Закольцовывать обучающие данные и обновлять сетевые параметры на каждой итерации.

Для каждой итерации:

  • Прочитайте информацию об изображении и метке из minibatchqueue объект данных моделирования с использованием next функция.

  • Считывание информации об изображении из minibatchqueue объект реальных данных с использованием next функция.

  • Оценка градиентов модели с помощью dlfeval и modelGradients вспомогательная функция, определенная в разделе «Вспомогательные функции». modelGradients возвращает градиенты потерь относительно обучаемых параметров.

  • Обновление параметров сети генератора с помощью sgdmupdate функция.

  • Обновите параметры сети дискриминатора с помощью adamupdate функция.

  • Обновите график хода обучения для каждой итерации и просмотрите различные вычисленные потери.

if doTraining

    % Create the dlnetwork object of the generator.
    dlnetGenerator = dlnetwork(lgraphGenerator);
    
    % Create the dlnetwork object of the discriminator.
    dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
    
    % Create the subplots for the generator and discriminator loss.
    fig = figure;
    [generatorLossPlotter, discriminatorLossPlotter] = configureTrainingLossPlotter(fig);
    
    % Loop through the data for the specified number of iterations.
    for iter = 1:numIterations
       
        % Reset the minibatchqueue of simulation data.
        if ~hasdata(mbqTrainingDataSimulation)
            reset(mbqTrainingDataSimulation);
        end
        
        % Retrieve the next mini-batch of simulation data and labels.
        [dlX,label] = next(mbqTrainingDataSimulation); 
        
        % Reset the minibatchqueue of real data.
        if ~hasdata(mbqTrainingDataReal)
            reset(mbqTrainingDataReal);
        end
        
        % Retrieve the next mini-batch of real data. 
        dlZ = next(mbqTrainingDataReal);  
        
        % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
        [gradientGenerator,gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = ...
            dlfeval(@modelGradients,dlnetGenerator,dlnetDiscriminator,dlX,dlZ,label,lamdaAdv);
        
        % Apply L2 regularization.
        gradientGenerator  = dlupdate(@(g,w) g + l2Regularization*w, gradientGenerator, dlnetGenerator.Learnables);
        
        % Adjust the learning rate.
        learnRateGen = piecewiseLearningRate(iter,learnRateGenBase,numIterations,power);
        learnRateDis = piecewiseLearningRate(iter,learnRateDisBase,numIterations,power);
        
         % Update the generator network learnable parameters using the SGDM optimizer.
        [dlnetGenerator.Learnables, vel] = ... 
            sgdmupdate(dlnetGenerator.Learnables,gradientGenerator,vel,learnRateGen);
               
         % Update the discriminator network learnable parameters using the Adam optimizer.
        [dlnetDiscriminator.Learnables, averageGrad, averageSqGrad] = ...
            adamupdate(dlnetDiscriminator.Learnables,gradientDiscriminator,averageGrad,averageSqGrad,iter,learnRateDis) ;
        
        % Update the training plot with loss values.
        updateTrainingPlots(generatorLossPlotter,discriminatorLossPlotter,iter, ... 
            double(gather(extractdata(lossSegValue + lamdaAdv * lossAdvValue))),double(gather(extractdata(lossDisValue))));

    end
    
    % Save the trained model.
    save('trainedAdaptSegGANNet.mat','dlnetGenerator');
end 

Дискриминатор теперь может идентифицировать, является ли вход от моделирования или реальной области. В свою очередь, генератор теперь может генерировать прогнозы сегментации, которые одинаковы в моделировании и реальных доменах.

Оценка модели на реальных тестовых данных

Оцените производительность обученной сети TrainingSegNet путем вычисления среднего IoU для прогнозов тестовых данных.

Загрузите тестовые данные с помощью imageDatastore.

realTestData = imageDatastore(realTestImagesFolder);

Набор данных CamVid имеет 32 класса. Используйте realpixelLabelIDs вспомогательная функция для уменьшения числа классов до пяти, как для набора данных моделирования. realpixelLabelIDs вспомогательная функция присоединена к этому примеру в качестве вспомогательного файла.

labelIDs = realPixelLabelIDs;

Использовать pixelLabelDatastore (Computer Vision Toolbox) для загрузки изображений метки истинности земли для тестовых данных.

realTestLabels = pixelLabelDatastore(realTestLabelsFolder,classes,labelIDs);

Переместите данные в центр нулевых значений для центрирования данных вокруг исходной точки, как и для учебных данных, с помощью transform функции и preprocessData вспомогательная функция, определенная в разделе «Вспомогательные функции».

preprocessedRealTestData = transform(realTestData, @(realtestdata)preprocessData(realtestdata));

Использовать combine для объединения преобразованного хранилища данных изображения и хранилищ данных меток пикселей реальных тестовых данных.

combinedRealTestData = combine(preprocessedRealTestData,realTestLabels);

Создать minibatchqueue объект из объединенного хранилища данных тестовых данных. Набор "MiniBatchSize" кому 1 для простоты оценки метрик.

mbqimdsTest = minibatchqueue(combinedRealTestData,"MiniBatchSize",1,...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Чтобы создать массив ячеек матрицы путаницы, используйте вспомогательную функцию predictSegmentationLabelsOnTestSet на minibatchqueue объект тестовых данных. Вспомогательная функция predictSegmentationLabelsOnTestSet перечислены ниже в разделе «Вспомогательные функции».

imageSetConfusionMat = predictSegmentationLabelsOnTestSet(dlnetGenerator,mbqimdsTest);

Использовать evaluateSemanticSegmentation (Computer Vision Toolbox) для измерения показателей семантической сегментации в матрице путаницы тестового набора.

metrics = evaluateSemanticSegmentation(imageSetConfusionMat,classes,'Verbose',false);

Чтобы просмотреть метрики уровня набора данных, проверьте metrics.DataSetMetrics.

metrics.DataSetMetrics
ans=1×4 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU
    ______________    ____________    _______    ___________

       0.86883           0.769        0.64487      0.78026  

Метрики набора данных обеспечивают общий обзор производительности сети. Чтобы увидеть влияние каждого класса на общую производительность, проверьте метрики для каждого класса с помощью metrics.ClassMetrics.

metrics.ClassMetrics
ans=5×2 table
                  Accuracy      IoU  
                  ________    _______

    Road           0.9147     0.81301
    Background    0.93418     0.85518
    Pavement      0.33373     0.27105
    Sky           0.82652     0.81109
    Car           0.83586     0.47399

Производительность набора данных хорошая, но метрики классов показывают, что классы автомобилей и дорожного покрытия плохо сегментированы. Обучение сети с использованием дополнительных данных может дать улучшенные результаты.

Изображение сегмента

Запустите обученную сеть на одном тестовом изображении для проверки сегментированного прогноза вывода.

% Read the image from the test data.
data = readimage(realTestData,350);

% Perform the preprocessing step of zero shift on the image.
processeddata = preprocessData(data);

% Convert the data to dlarray.
processeddata = dlarray(processeddata,'SSCB');

% Predict the output of the network.
[genPrediction, ~] = forward(dlnetGenerator,processeddata);

% Get the label, which is the index with the maximum value in the channel dimension.
[~, labels] = max(genPrediction,[],3);

% Overlay the predicted labels on the image.
segmentedImage = labeloverlay(data,uint8(gather(extractdata(labels))),'Colormap',dmap);

Просмотрите результаты.

figure
imshow(segmentedImage);
labelColorbar(dmap,classes);

Сравните результаты метки с ожидаемой истинностью грунта, сохраненной в realTestLabels. Зеленая и пурпурная области выделяют области, где результаты сегментации отличаются от ожидаемой истинности грунта.

expectedResult = readimage(realTestLabels,350);
actual = uint8(gather(extractdata(labels)));
expected = uint8(expectedResult);
figure
imshowpair(actual,expected)

Визуально результаты семантической сегментации хорошо перекрываются для классов дорог, неба и зданий. Однако результаты плохо накладываются на классы автомобилей и дорожного покрытия.

Вспомогательные функции

Функция градиентов модели

Вспомогательная функция modelGradients вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери сегментации для генератора и потери перекрестной энтропии для дискриминатора. Поскольку не требуется запоминать информацию о состоянии между итерациями как для генерирующих, так и для дискриминаторных сетей, состояния не обновляются.

function [gradientGenerator, gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, label, lamdaAdv)

% Labels for adversarial training.
simulationLabel = 0;
realLabel = 1;

% Extract the predictions of the simulation from the generator.
[genPredictionSimulation, ~] = forward(dlnetGenerator,dlX);

% Compute the generator loss.
lossSegValue = segmentationLoss(genPredictionSimulation,label);

% Extract the predictions of the real data from the generator.
[genPredictionReal, ~] = forward(dlnetGenerator,dlZ);

% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of simulation labels of real prediction size.
Y = simulationLabel * ones(size(disPredictionReal));

% Compute the adversarial loss to make the real distribution close to the simulation label.
lossAdvValue = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the gradients of the generator with regard to loss.
gradientGenerator = dlgradient(lossSegValue + lamdaAdv*lossAdvValue,dlnetGenerator.Learnables);

% Extract the softmax predictions of the simulation from the discriminator.
disPredictionSimulation = forward(dlnetDiscriminator,softmax(genPredictionSimulation));

% Create a matrix of simulation labels of simulation prediction size.
Y = simulationLabel * ones(size(disPredictionSimulation));

% Compute the discriminator loss with regard to simulation class.
lossDisValueSimulation = mse(disPredictionSimulation,Y)/numel(Y(:));
 
% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of real labels of real prediction size.
Y = realLabel * ones(size(disPredictionReal));

% Compute the discriminator loss with regard to real class.
lossDisValueReal = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the total discriminator loss.
lossDisValue = lossDisValueSimulation + lossDisValueReal;

% Compute the gradients of the discriminator with regard to loss.
gradientDiscriminator = dlgradient(lossDisValue,dlnetDiscriminator.Learnables);

end

Функция потери сегментации

Вспомогательная функция segmentationLoss вычисляет потерю сегментации элемента, которая определяется как потеря перекрестной энтропии для генератора с использованием данных моделирования и его соответствующей истинности на земле. Вспомогательная функция вычисляет потери с помощью crossentropy функция.

function loss = segmentationLoss(predict, target)

% Generate the one-hot encodings of the ground truth.
oneHotTarget = onehotencode(categorical(extractdata(target)),4);

% Convert the one-hot encoded data to dlarray.
oneHotTarget = dlarray(oneHotTarget,'SSBC');

% Compute the softmax output of the predictions.
predictSoftmax = softmax(predict);

% Compute the cross-entropy loss.
loss =  crossentropy(predictSoftmax,oneHotTarget,'TargetCategories','exclusive')/(numel(oneHotTarget)/2);
end

Вспомогательная функция downloadDataset загружает наборы данных моделирования и реальных данных из указанных URL-адресов в указанные папки, если они не существуют. Функция возвращает пути моделирования, реальные данные обучения и реальные данные тестирования. Функция загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы с помощью subsetCamVidDatasetFileNames файл мата, присоединенный к примеру как вспомогательный файл.

function [simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder,...
    realTestImagesFolder, realTestLabelsFolder] = ...
    downloadDataset(simulationDataLocation, simulationDataURL, realDataLocation, realImageDataURL, realLabelDataURL)
    
% Build the training image and label folder location for simulation data.
simulationDataZip = fullfile(simulationDataLocation,'SimulationDrivingDataset.zip');

% Get the simulation data if it does not exist.
if ~exist(simulationDataZip,'file')
    mkdir(simulationDataLocation)
    
    disp('Downloading the simulation data');
    websave(simulationDataZip,simulationDataURL);
    unzip(simulationDataZip,simulationDataLocation);
end
  
simulationImagesFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','images');
simulationLabelsFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','labels');

camVidLabelsZip = fullfile(realDataLocation,'CamVidLabels.zip');
camVidImagesZip = fullfile(realDataLocation,'CamVidImages.zip');

if ~exist(camVidLabelsZip,'file') || ~exist(camVidImagesZip,'file')   
    mkdir(realDataLocation)
       
    disp('Downloading 16 MB CamVid dataset labels...'); 
    websave(camVidLabelsZip, realLabelDataURL);
    unzip(camVidLabelsZip, fullfile(realDataLocation,'CamVidLabels'));
    
    disp('Downloading 587 MB CamVid dataset images...');  
    websave(camVidImagesZip, realImageDataURL);       
    unzip(camVidImagesZip, fullfile(realDataLocation,'CamVidImages'));    
end

% Build the training image and label folder location for real data.
realImagesFolder = fullfile(realDataLocation,'train','images');
realLabelsFolder = fullfile(realDataLocation,'train','labels');

% Build the testing image and label folder location for real data.
realTestImagesFolder = fullfile(realDataLocation,'test','images');
realTestLabelsFolder = fullfile(realDataLocation,'test','labels');

% Partition the data into training and test sets if they do not exist.
if ~exist(realImagesFolder,'file') || ~exist(realLabelsFolder,'file') || ...
        ~exist(realTestImagesFolder,'file') || ~exist(realTestLabelsFolder,'file')

    
    mkdir(realImagesFolder);
    mkdir(realLabelsFolder);
    mkdir(realTestImagesFolder);
    mkdir(realTestLabelsFolder);
    
    % Load the mat file that has the names for testing and training.
    partitionNames = load('subsetCamVidDatasetFileNames.mat');
    
    % Extract the test images names.
    imageTestNames = partitionNames.imageTestNames;
    
    % Remove the empty cells. 
    imageTestNames = imageTestNames(~cellfun('isempty',imageTestNames));
    
    % Extract the test labels names.
    labelTestNames = partitionNames.labelTestNames;
    
    % Remove the empty cells.
    labelTestNames = labelTestNames(~cellfun('isempty',labelTestNames));
    
    % Copy the test images to the respective folder.
    for i = 1:size(imageTestNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTestNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTestNames(i));
        copyfile(imageSource{1}, realTestImagesFolder);
        copyfile(labelSource{1}, realTestLabelsFolder);
    end
    
    % Extract the train images names.
    imageTrainNames = partitionNames.imageTrainNames;
    
    % Remove the empty cells.
    imageTrainNames = imageTrainNames(~cellfun('isempty',imageTrainNames));
    
    % Extract the train labels names.
    labelTrainNames = partitionNames.labelTrainNames;
    
    % Remove the empty cells.
    labelTrainNames = labelTrainNames(~cellfun('isempty',labelTrainNames));
    
    % Copy the train images to the respective folder.
    for i = 1:size(imageTrainNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTrainNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTrainNames(i));
        copyfile(imageSource{1},realImagesFolder);
        copyfile(labelSource{1},realLabelsFolder);
    end
end
end

Вспомогательная функция addASPPToNetwork создает слои пространственного пула пирамид (ASPP) и добавляет их к графу входного уровня. Функция возвращает граф слоев с подключенными к нему слоями ASPP.

function lgraph  = addASPPToNetwork(lgraph, numClasses)

% Define the ASPP dilation factors.
asppDilationFactors = [6,12];

% Define the ASPP filter sizes.
asppFilterSizes = [3,3];

% Extract the last layer of the layer graph.
lastLayerName = lgraph.Layers(end).Name;

% Define the addition layer.
addLayer = additionLayer(numel(asppDilationFactors),'Name','additionLayer');

% Add the addition layer to the layer graph.
lgraph = addLayers(lgraph,addLayer);

% Create the ASPP layers connected to the addition layer
% and connect the layer graph.
for i = 1: numel(asppDilationFactors)
    asppConvName = "asppConv_" + string(i);
    branchFilterSize = asppFilterSizes(i);
    branchDilationFactor = asppDilationFactors(i);
    asspLayer  = convolution2dLayer(branchFilterSize, numClasses,'DilationFactor', branchDilationFactor,...
        'Padding','same','Name',asppConvName,'WeightsInitializer','narrow-normal','BiasInitializer','zeros');
    lgraph = addLayers(lgraph,asspLayer);
    lgraph = connectLayers(lgraph,lastLayerName,asppConvName);
    lgraph = connectLayers(lgraph,asppConvName,strcat(addLayer.Name,'/',addLayer.InputNames{i}));
end
end

Вспомогательная функция predictSegmentationLabelsOnTestSet вычисляет матрицу путаницы прогнозируемой и нулевой меток истинности с помощью segmentationConfusionMatrix (Панель инструментов компьютерного зрения).

function confusionMatrix =  predictSegmentationLabelsOnTestSet(net, minbatchTestData)   
    
confusionMatrix = {};
i = 1;
while hasdata(minbatchTestData)
    
    % Use next to retrieve a mini-batch from the datastore.
    [dlX, gtlabels] = next(minbatchTestData);
    
    % Predict the output of the network.
    [genPrediction, ~] = forward(net,dlX);
    
    % Get the label, which is the index with maximum value in the channel dimension.
    [~, labels] = max(genPrediction,[],3);
    
    % Get the confusion matrix of each image.
    confusionMatrix{i}  = segmentationConfusionMatrix(double(gather(extractdata(labels))),double(gather(extractdata(gtlabels))));
  
    i = i+1;
end

confusionMatrix = confusionMatrix';
    
end

Вспомогательная функция piecewiseLearningRate вычисляет текущую скорость обучения на основе номера итерации.

function lr = piecewiseLearningRate(i, baseLR, numIterations, power)

fraction = i/numIterations;
factor = (1 - fraction)^power * 1e1;
lr = baseLR * factor;

end

Вспомогательная функция preprocessData выполняет сдвиг центра нуля путем вычитания числа каналов изображения на соответствующее среднее значение.

function data = preprocessData(data)

% Extract respective channels.
rc = data(:,:,1);
gc = data(:,:,2);
bc = data(:,:,3);

% Compute the respective channel means.
r = mean(rc(:));
g = mean(gc(:));
b = mean(bc(:));

% Shift the data by the mean of respective channel.
data = single(data) - single(shiftdim([r g b],-1));  
end

Ссылки

[1] Цай, И-Хсуань, Вэй-Чих Хунг, Самуэль Шултер, Кихюк Сон, Мин-Хсуань Ян и Манмохан Чандракер. «Учимся адаптировать пространство структурированного вывода для семантической сегментации». В 2018 году Конференция IEEE/CVF по компьютерному зрению и распознаванию образов, 7472-81. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00780.

[2] Бростоу, Габриэль Дж., Жюльен Фокер и Роберто Чиполла. «Классы семантических объектов в видео: база данных истинности земли высокой четкости». Письма 30, № 2 (январь 2009 года): 88-97. https://doi.org/10.1016/j.patrec.2008.04.005.