Обучите сеть Семантической Сегментации глубокого обучения Используя 3-D данные моделирования

В этом примере показано, как использовать 3-D данные моделирования, чтобы обучить сеть семантической сегментации и подстроить их к реальным данным с помощью порождающих соперничающих сетей (GANs).

Этот пример использует 3-D данные моделирования, сгенерированные Driving Scenario Designer и Нереальным Engine®. Для примера, показывающего, как сгенерировать такие данные моделирования, смотрите Глубину и Визуализацию Семантической Сегментации Используя Нереальную Симуляцию Engine (Automated Driving Toolbox). 3-D среда симуляции генерирует изображения и соответствующие пиксельные метки основной истины. Используя данные моделирования избегает процесса аннотации, который и утомителен и требует большого количества человеческого усилия. Однако доменные модели сдвига, обученные только на данных моделирования, не выполняют хорошо на реальных наборах данных. Чтобы обратиться к этому, можно использовать доменную адаптацию, чтобы подстроить обученную модель, чтобы работать над реальным набором данных.

Этот пример использует AdaptSegNet [1], сеть, которая адаптирует структуру выходных предсказаний сегментации, которые выглядят подобными независимо от входной области. Сеть AdaptSegNet основана на модели GAN и состоит из двух сетей, которые обучены одновременно, чтобы максимизировать эффективность обоих:

  1. Генератор — Сеть, обученная, чтобы сгенерировать высококачественную сегментацию, следует из действительных или симулированных входных изображений

  2. Различитель — Сеть, которая выдерживает сравнение и пытается различать, являются ли предсказания сегментации генератора из действительных или симулированных данных

Чтобы подстроить модель AdaptSegNet для реальных данных, этот пример использует подмножество Данные CamVid 2[] и адаптирует модель, чтобы сгенерировать высококачественные предсказания сегментации на данных CamVid.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться. Если вы хотите обучить сеть, установите doTraining переменная к true.

doTraining = false;
if ~doTraining
    pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedAdaptSegGANNet.mat';
    pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
    pretrainedNetwork = fullfile(pretrainedFolder,'trainedAdaptSegGANNet.mat'); 
    if ~exist(pretrainedNetwork,'file')
        mkdir(pretrainedFolder);
        disp('Downloading pretrained network (57 MB)...');
        websave(pretrainedNetwork,pretrainedURL);
    end
    pretrained = load(pretrainedNetwork);
    dlnetGenerator = pretrained.dlnetGenerator;
end    

Загрузите наборы данных

Загрузите симуляцию и действительные наборы данных при помощи downloadDataset функция, заданная в разделе Supporting Functions этого примера. downloadDataset функционируйте загружает целый набор данных CamVid, и разделите данные в наборы обучающих данных и наборы тестов.

Набор данных моделирования был сгенерирован Driving Scenario Designer. Сгенерированные сценарии, которые состоят из 553 фотореалистических изображений с метками, были представлены Нереальным Engine. Вы используете этот набор данных, чтобы обучить модель.

Действительный набор данных является подмножеством набора данных CamVid из Кембриджского университета. Адаптировать модель к реальным данным, 69 изображениям CamVid. Чтобы оценить обученную модель, вы используете 368 изображений CamVid.

Время загрузки зависит от вашего интернет-соединения.

simulationDataURL = 'https://ssd.mathworks.com/supportfiles/vision/data/SimulationDrivingDataset.zip';
realImageDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
realLabelDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

simulationDataLocation = fullfile(tempdir,'SimulationData');
realDataLocation = fullfile(tempdir,'RealData');
[simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder, ...
    realTestImagesFolder, realTestLabelsFolder] = ... 
    downloadDataset(simulationDataLocation,simulationDataURL,realDataLocation,realImageDataURL,realLabelDataURL);

Загруженные файлы включают пиксельные метки для действительной области, но отмечают, что вы не используете эти пиксельные метки в учебном процессе. Этот пример использует действительные доменные пиксельные метки только, чтобы вычислить среднее значение пересечения по объединению (IoU), чтобы оценить эффективность обученной модели.

Загрузите симуляцию и действительные данные

Используйте imageDatastore загружать симуляцию и действительные наборы данных для обучения. При помощи datastore изображений можно эффективно загрузить большое количество изображений на диске.

simData = imageDatastore(simulationImagesFolder);
realData = imageDatastore(realImagesFolder);

Предварительный просмотр отображает от набора данных моделирования и действительного набора данных.

simImage = preview(simData);
realImage = preview(realData);
montage({simImage,realImage})

Действительные и симулированные изображения выглядят очень отличающимися. Следовательно, модели, обученные на симулированных данных и оцененные на действительных данных, выполняют плохо из-за доменного сдвига.

Загрузите помеченные пикселем изображения для данных моделирования и действительных данных

Загрузите пиксельные данные изображения метки симуляции при помощи pixelLabelDatastore (Computer Vision Toolbox). PixelLabelDatastore инкапсулирует данные метки пикселя и идентификатор метки в сопоставление имен классов.

В данном примере задайте пять классов, полезных для автоматизированного ведущего приложения: дорога, фон, тротуар, небо и автомобиль.

classes = [
    "Road"
    "Background"
    "Pavement"
    "Sky"
    "Car"
    ];
numClasses = numel(classes);

Набор данных моделирования имеет восемь классов. Сократите количество классов от восемь до пять путем группировки создания, дерева, сигнала трафика и легких классов от исходного набора данных в один фоновый класс. Возвратите сгруппированную метку IDs при помощи функции помощника simulationPixelLabelIDs. Эта функция помощника присоединена к примеру как к вспомогательному файлу.

labelIDs = simulationPixelLabelIDs;

Используйте классы и метку IDs, чтобы создать пиксельный datastore метки данных моделирования.

simLabels = pixelLabelDatastore(simulationLabelsFolder,classes,labelIDs);

Инициализируйте палитру для сегментированных изображений с помощью функции помощника domainAdaptationColorMap, заданный в разделе Supporting Functions.

dmap = domainAdaptationColorMap;

Предварительно просмотрите помеченное пикселем изображение путем накладывания метки сверху изображения с помощью labeloverlay (Image Processing Toolbox) функция.

simImageLabel = preview(simLabels);
overlayImageSimulation = labeloverlay(simImage,simImageLabel,'ColorMap',dmap);
figure
imshow(overlayImageSimulation)
labelColorbar(dmap,classes);

Переключите симуляцию и действительные данные, используемые для обучения обнулить центр, сосредоточить данные вокруг источника, при помощи transform функционируйте и preprocessData функция помощника, заданная в разделе Supporting Functions.

preprocessedSimData = transform(simData, @(simdata)preprocessData(simdata));
preprocessedRealData = transform(realData, @(realdata)preprocessData(realdata));

Используйте combine функционируйте, чтобы объединить datastore преобразованного изображения и пиксельные хранилища данных метки области симуляции. Учебный процесс не использует пиксельные метки действительных данных.

combinedSimData = combine(preprocessedSimData,simLabels);

Задайте генератор AdaptSegNet

Этот пример изменяет сеть VGG-16, предварительно обученную на ImageNet к полностью сверточной сети. Чтобы увеличить восприимчивые поля, расширил сверточные слои шагами 2, и 4 добавляются. Это делает выходную одну восьмую разрешения карты функции входного размера. Atrous пространственное объединение пирамиды (ASPP) используется, чтобы предоставить многошкальную информацию и сопровождается resize2dlayer с фактором повышающей дискретизации 8, чтобы изменить размер выхода к размеру входа.

Сеть генератора AdaptSegNet, используемая в этом примере, проиллюстрирована в следующей схеме.

Чтобы получить предварительно обученную сеть VGG-16, установите vgg16. Если пакет поддержки не установлен, то программное обеспечение обеспечивает ссылку на загрузку.

net = vgg16;

Чтобы сделать сеть VGG-16 подходящей для семантической сегментации, удалите все слои VGG после 'relu4_3'.

vggLayers = net.Layers(2:24);

Создайте входной слой изображений размера 1280 720 3 для генератора.

inputSizeGenerator = [1280 720 3];
inputLayer = imageInputLayer(inputSizeGenerator,'Normalization','None','Name','inputLayer');

Создайте полностью сверточные слоя сети. Используйте коэффициенты расширения 2 и 4, чтобы увеличить соответствующие поля.

fcnlayers = [
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_1')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2] ,'Name','conv5_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_2')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_3','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_3')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4],'Name','conv6_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_1')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4] ,'Name','conv6_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_2')
    ];

Объедините слои и создайте график слоев.

layers = [
    inputLayer
    vggLayers
    fcnlayers
    ];
lgraph = layerGraph(layers);

ASPP используется, чтобы предоставить многошкальную информацию. Добавьте модуль ASPP в график слоев с размером фильтра, равным количеству каналов при помощи addASPPToNetwork функция помощника, заданная в разделе Supporting Functions.

lgraph  = addASPPToNetwork(lgraph, numClasses);

Примените resize2dLayer с фактором повышающей дискретизации 8, чтобы заставить выход совпадать с размером входа.

upSampleLayer = resize2dLayer('Scale',8,'Method','bilinear','Name','resizeLayer');
lgraphGenerator = addLayers(lgraph,upSampleLayer);
lgraphGenerator = connectLayers(lgraphGenerator,'additionLayer','resizeLayer');

Визуализируйте сеть генератора в графике.

plot(lgraphGenerator)
title("Generator")

Задайте различитель AdaptSeg

Сеть различителя состоит из пяти сверточных слоев с размером ядра 3 и шагом 2, где количество каналов {64, 128, 256, 512, 1}. Каждый слой сопровождается текучим слоем ReLU, параметрированным шкалой 0,2, за исключением последнего слоя. resize2dLayer используется, чтобы изменить размер выхода различителя. Обратите внимание на то, что этот пример не использует нормализацию партии., когда различитель совместно обучен с сетью сегментации использование небольшого пакетного размера.

Сеть различителя AdaptSegNet в этом примере проиллюстрирована в следующей схеме.

Создайте входной слой изображений размера 1280 720 numClasses это берет в предсказаниях сегментации симуляции и действительных областей.

inputSizeDiscriminator = [1280 720 numClasses];

Создайте полностью сверточные слои и сгенерируйте график слоев различителя.

% Factor for number of channels in convolution layer.
numChannelsFactor = 64;

% Scale factor to resize the output of the discriminator.
resizeScale = 64;

% Scalar multiplier for leaky ReLU layers.
leakyReLUScale = 0.2;

% Create the layers of the discriminator.
layers = [
    imageInputLayer(inputSizeDiscriminator,'Normalization','none','Name','inputLayer')
    convolution2dLayer(3,numChannelsFactor,'Stride',2,'Padding',1,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu1')
    convolution2dLayer(3,numChannelsFactor*2,'Stride',2,'Padding',1,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu2')
    convolution2dLayer(3,numChannelsFactor*4,'Stride',2,'Padding',1,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu3')
    convolution2dLayer(3,numChannelsFactor*8,'Stride',2,'Padding',1,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu4')
    convolution2dLayer(3,1,'Stride',2,'Padding',1,'Name','classifer','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    resize2dLayer('Scale', resizeScale,'Method','bilinear','Name','resizeLayer');
    ];

% Create the layer graph of the discriminator.
lgraphDiscriminator  = layerGraph(layers);

Визуализируйте сеть различителя в графике.

plot(lgraphDiscriminator)
title("Discriminator")

Задайте опции обучения

Задайте эти опции обучения.

  • Определите общий номер итераций к 5000. Путем выполнения так, вы обучаете сеть в течение приблизительно 10 эпох.

  • Установите скорость обучения для генератора к 2.5e-4.

  • Установите скорость обучения для различителя к 1e-4.

  • Установитесь коэффициент регуляризации L2 на 0.0005.

  • Скорость обучения экспоненциально уменьшается на основе формулы learningrate× [iterationtotaliterations]power. Это уменьшение помогает стабилизировать градиенты в более высоких итерациях. Установите степень на 0.9.

  • Установите вес соперничающей потери для 0.001.

  • Инициализируйте скорость градиента как [ ]. Это значение используется SGDM, чтобы сохранить скорость градиентов.

  • Инициализируйте скользящее среднее значение градиентов параметра как [ ]. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов параметра.

  • Инициализируйте скользящее среднее значение градиентов параметра в квадрате как [ ]. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов параметра в квадрате.

  • Установите мини-пакетный размер на 1.

numIterations = 5000;
learnRateGenBase = 2.5e-4;
learnRateDisBase = 1e-4;
l2Regularization = 0.0005;
power = 0.9;
lamdaAdv = 0.001;
vel= [];
averageGrad = [];
averageSqGrad = [];
miniBatchSize = 1;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше. Чтобы автоматически обнаружить, если вы имеете графический процессор в наличии, установите executionEnvironment к "auto". Если вы не имеете графического процессора или не хотите использовать один для обучения, устанавливать executionEnvironment к "cpu". Чтобы гарантировать использование графического процессора для обучения, установите executionEnvironment к "gpu".

executionEnvironment = "auto";

Создайте minibatchqueue объект от объединенного datastore области симуляции.

mbqTrainingDataSimulation =  minibatchqueue(combinedSimData,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Создайте minibatchqueue объект от входа отображает datastore действительной области.

mbqTrainingDataReal = minibatchqueue(preprocessedRealData,"MiniBatchSize",miniBatchSize, ... 
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Функция помощника modelGradients, заданный в разделе Supporting Functions этого примера, вычислите градиенты и потери для генератора и различителя. Создайте график процесса обучения с помощью configureTrainingLossPlotter, присоединенный к этому примеру как вспомогательный файл, aбез обозначения даты обновите процесс обучения с помощью updateTrainingPlots. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации.

Для каждой итерации:

  • Считайте изображение и пометьте информацию от minibatchqueue объект данных моделирования с помощью next функция.

  • Считайте данные изображения из minibatchqueue объект действительных данных с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция помощника, заданная в разделе Supporting Functions. modelGradients возвращает градиенты потери относительно настраиваемых параметров.

  • Обновите параметры сети генератора с помощью sgdmupdate функция.

  • Обновите параметры сети различителя с помощью adamupdate функция.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

if doTraining

    % Create the dlnetwork object of the generator.
    dlnetGenerator = dlnetwork(lgraphGenerator);
    
    % Create the dlnetwork object of the discriminator.
    dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
    
    % Create the subplots for the generator and discriminator loss.
    fig = figure;
    [generatorLossPlotter, discriminatorLossPlotter] = configureTrainingLossPlotter(fig);
    
    % Loop through the data for the specified number of iterations.
    for iter = 1:numIterations
       
        % Reset the minibatchqueue of simulation data.
        if ~hasdata(mbqTrainingDataSimulation)
            reset(mbqTrainingDataSimulation);
        end
        
        % Retrieve the next mini-batch of simulation data and labels.
        [dlX,label] = next(mbqTrainingDataSimulation); 
        
        % Reset the minibatchqueue of real data.
        if ~hasdata(mbqTrainingDataReal)
            reset(mbqTrainingDataReal);
        end
        
        % Retrieve the next mini-batch of real data. 
        dlZ = next(mbqTrainingDataReal);  
        
        % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
        [gradientGenerator,gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = ...
            dlfeval(@modelGradients,dlnetGenerator,dlnetDiscriminator,dlX,dlZ,label,lamdaAdv);
        
        % Apply L2 regularization.
        gradientGenerator  = dlupdate(@(g,w) g + l2Regularization*w, gradientGenerator, dlnetGenerator.Learnables);
        
        % Adjust the learning rate.
        learnRateGen = piecewiseLearningRate(iter,learnRateGenBase,numIterations,power);
        learnRateDis = piecewiseLearningRate(iter,learnRateDisBase,numIterations,power);
        
         % Update the generator network learnable parameters using the SGDM optimizer.
        [dlnetGenerator.Learnables, vel] = ... 
            sgdmupdate(dlnetGenerator.Learnables,gradientGenerator,vel,learnRateGen);
               
         % Update the discriminator network learnable parameters using the Adam optimizer.
        [dlnetDiscriminator.Learnables, averageGrad, averageSqGrad] = ...
            adamupdate(dlnetDiscriminator.Learnables,gradientDiscriminator,averageGrad,averageSqGrad,iter,learnRateDis) ;
        
        % Update the training plot with loss values.
        updateTrainingPlots(generatorLossPlotter,discriminatorLossPlotter,iter, ... 
            double(gather(extractdata(lossSegValue + lamdaAdv * lossAdvValue))),double(gather(extractdata(lossDisValue))));

    end
    
    % Save the trained model.
    save('trainedAdaptSegGANNet.mat','dlnetGenerator');
end 

Различитель может теперь идентифицировать, является ли вход от симуляции или действительной области. В свою очередь генератор может теперь сгенерировать предсказания сегментации, которые подобны через симуляцию и действительные области.

Оцените модель на действительных тестовых данных

Оцените эффективность обученной сети AdaptSegNet путем вычисления среднего IoU для предсказаний тестовых данных.

Загрузите тестовые данные с помощью imageDatastore.

realTestData = imageDatastore(realTestImagesFolder);

Набор данных CamVid имеет 32 класса. Используйте realpixelLabelIDs функция помощника, чтобы сократить количество классов к пять, что касается набора данных моделирования. realpixelLabelIDs функция помощника присоединена к этому примеру как к вспомогательному файлу.

labelIDs = realPixelLabelIDs;

Используйте pixelLabelDatastore (Computer Vision Toolbox), чтобы загрузить основную истину помечает изображения для тестовых данных.

realTestLabels = pixelLabelDatastore(realTestLabelsFolder,classes,labelIDs);

Сдвиньте данные, чтобы обнулить центр, чтобы сосредоточить данные вокруг источника, что касается обучающих данных, при помощи transform функционируйте и preprocessData функция помощника, заданная в разделе Supporting Functions.

preprocessedRealTestData = transform(realTestData, @(realtestdata)preprocessData(realtestdata));

Используйте combine чтобы объединить datastore преобразованного изображения и пиксель помечают хранилища данных действительных тестовых данных.

combinedRealTestData = combine(preprocessedRealTestData,realTestLabels);

Создайте minibatchqueue объект от объединенного datastore теста data. Установите "MiniBatchSize" к 1 для простоты оценки метрик.

mbqimdsTest = minibatchqueue(combinedRealTestData,"MiniBatchSize",1,...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Чтобы сгенерировать массив ячейки матрицы беспорядка, используйте функцию помощника predictSegmentationLabelsOnTestSet на minibatchqueue объект тестовых данных. Функция помощника predictSegmentationLabelsOnTestSet описан ниже в разделе Supporting Functions.

imageSetConfusionMat = predictSegmentationLabelsOnTestSet(dlnetGenerator,mbqimdsTest);

Используйте evaluateSemanticSegmentation (Computer Vision Toolbox), чтобы измерить метрики семантической сегментации на матрице беспорядка набора тестов.

metrics = evaluateSemanticSegmentation(imageSetConfusionMat,classes,'Verbose',false);

Чтобы видеть метрики уровня набора данных, смотрите metrics.DataSetMetrics.

metrics.DataSetMetrics
ans=1×4 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU
    ______________    ____________    _______    ___________

       0.86883           0.769        0.64487      0.78026  

Метрики набора данных предоставляют общий обзор производительности сети. Чтобы увидеть влияние каждого класса на общую производительности, смотрите метрики по классам с помощью metrics.ClassMetrics.

metrics.ClassMetrics
ans=5×2 table
                  Accuracy      IoU  
                  ________    _______

    Road           0.9147     0.81301
    Background    0.93418     0.85518
    Pavement      0.33373     0.27105
    Sky           0.82652     0.81109
    Car           0.83586     0.47399

Эффективность набора данных хороша, но метрики класса показывают, что автомобиль и классы тротуара не сегментируются хорошо. Обучение сеть с помощью дополнительных данных может привести к улучшенным результатам.

Изображение сегмента

Запустите обучивший сеть на одном тестовом изображении, чтобы проверять сегментированное выходное предсказание.

% Read the image from the test data.
data = readimage(realTestData,350);

% Perform the preprocessing step of zero shift on the image.
processeddata = preprocessData(data);

% Convert the data to dlarray.
processeddata = dlarray(processeddata,'SSCB');

% Predict the output of the network.
[genPrediction, ~] = forward(dlnetGenerator,processeddata);

% Get the label, which is the index with the maximum value in the channel dimension.
[~, labels] = max(genPrediction,[],3);

% Overlay the predicted labels on the image.
segmentedImage = labeloverlay(data,uint8(gather(extractdata(labels))),'Colormap',dmap);

Отобразите результаты.

figure
imshow(segmentedImage);
labelColorbar(dmap,classes);

Сравните результаты метки с ожидаемой основной истиной, сохраненной в realTestLabels. Зеленые и пурпурные области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.

expectedResult = readimage(realTestLabels,350);
actual = uint8(gather(extractdata(labels)));
expected = uint8(expectedResult);
figure
imshowpair(actual,expected)

Визуально, результаты семантической сегментации перекрываются хорошо для классов "Дорога", "Небо" и "Здания". Однако результаты не перекрываются хорошо для классов тротуара и автомобиля.

Вспомогательные Функции

Функция градиентов модели

Функция помощника modelGradients вычисляет градиенты и соперничающую потерю для генератора и различителя. Функция также вычисляет потерю сегментации для генератора и потерю перекрестной энтропии для различителя. Когда никакая информация состояния не требуется, чтобы помниться между итерациями и за генератор и за сети различителя, состояния не обновляются.

function [gradientGenerator, gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, label, lamdaAdv)

% Labels for adversarial training.
simulationLabel = 0;
realLabel = 1;

% Extract the predictions of the simulation from the generator.
[genPredictionSimulation, ~] = forward(dlnetGenerator,dlX);

% Compute the generator loss.
lossSegValue = segmentationLoss(genPredictionSimulation,label);

% Extract the predictions of the real data from the generator.
[genPredictionReal, ~] = forward(dlnetGenerator,dlZ);

% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of simulation labels of real prediction size.
Y = simulationLabel * ones(size(disPredictionReal));

% Compute the adversarial loss to make the real distribution close to the simulation label.
lossAdvValue = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the gradients of the generator with regard to loss.
gradientGenerator = dlgradient(lossSegValue + lamdaAdv*lossAdvValue,dlnetGenerator.Learnables);

% Extract the softmax predictions of the simulation from the discriminator.
disPredictionSimulation = forward(dlnetDiscriminator,softmax(genPredictionSimulation));

% Create a matrix of simulation labels of simulation prediction size.
Y = simulationLabel * ones(size(disPredictionSimulation));

% Compute the discriminator loss with regard to simulation class.
lossDisValueSimulation = mse(disPredictionSimulation,Y)/numel(Y(:));
 
% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of real labels of real prediction size.
Y = realLabel * ones(size(disPredictionReal));

% Compute the discriminator loss with regard to real class.
lossDisValueReal = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the total discriminator loss.
lossDisValue = lossDisValueSimulation + lossDisValueReal;

% Compute the gradients of the discriminator with regard to loss.
gradientDiscriminator = dlgradient(lossDisValue,dlnetDiscriminator.Learnables);

end

Функция потерь сегментации

Функция помощника segmentationLoss вычисляет потерю сегментации функции, которая задана как потеря перекрестной энтропии для генератора с помощью данных моделирования и его соответствующей основной истины. Функция помощника вычисляет потерю при помощи crossentropy функция.

function loss = segmentationLoss(predict, target)

% Generate the one-hot encodings of the ground truth.
oneHotTarget = onehotencode(categorical(extractdata(target)),4);

% Convert the one-hot encoded data to dlarray.
oneHotTarget = dlarray(oneHotTarget,'SSBC');

% Compute the softmax output of the predictions.
predictSoftmax = softmax(predict);

% Compute the cross-entropy loss.
loss =  crossentropy(predictSoftmax,oneHotTarget,'TargetCategories','exclusive')/(numel(oneHotTarget)/2);
end

Функция помощника downloadDataset загрузки и симуляция и действительные наборы данных от заданных URL до заданных местоположений папки, если они не существуют. Функция возвращает пути симуляции, действительных обучающих данных и действительных данных о тестировании. Функция загружает целый набор данных CamVid, и разделите данные в наборы обучающих данных и наборы тестов с помощью subsetCamVidDatasetFileNames матовый файл, присоединенный к примеру как вспомогательный файл.

function [simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder,...
    realTestImagesFolder, realTestLabelsFolder] = ...
    downloadDataset(simulationDataLocation, simulationDataURL, realDataLocation, realImageDataURL, realLabelDataURL)
    
% Build the training image and label folder location for simulation data.
simulationDataZip = fullfile(simulationDataLocation,'SimulationDrivingDataset.zip');

% Get the simulation data if it does not exist.
if ~exist(simulationDataZip,'file')
    mkdir(simulationDataLocation)
    
    disp('Downloading the simulation data');
    websave(simulationDataZip,simulationDataURL);
    unzip(simulationDataZip,simulationDataLocation);
end
  
simulationImagesFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','images');
simulationLabelsFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','labels');

camVidLabelsZip = fullfile(realDataLocation,'CamVidLabels.zip');
camVidImagesZip = fullfile(realDataLocation,'CamVidImages.zip');

if ~exist(camVidLabelsZip,'file') || ~exist(camVidImagesZip,'file')   
    mkdir(realDataLocation)
       
    disp('Downloading 16 MB CamVid dataset labels...'); 
    websave(camVidLabelsZip, realLabelDataURL);
    unzip(camVidLabelsZip, fullfile(realDataLocation,'CamVidLabels'));
    
    disp('Downloading 587 MB CamVid dataset images...');  
    websave(camVidImagesZip, realImageDataURL);       
    unzip(camVidImagesZip, fullfile(realDataLocation,'CamVidImages'));    
end

% Build the training image and label folder location for real data.
realImagesFolder = fullfile(realDataLocation,'train','images');
realLabelsFolder = fullfile(realDataLocation,'train','labels');

% Build the testing image and label folder location for real data.
realTestImagesFolder = fullfile(realDataLocation,'test','images');
realTestLabelsFolder = fullfile(realDataLocation,'test','labels');

% Partition the data into training and test sets if they do not exist.
if ~exist(realImagesFolder,'file') || ~exist(realLabelsFolder,'file') || ...
        ~exist(realTestImagesFolder,'file') || ~exist(realTestLabelsFolder,'file')

    
    mkdir(realImagesFolder);
    mkdir(realLabelsFolder);
    mkdir(realTestImagesFolder);
    mkdir(realTestLabelsFolder);
    
    % Load the mat file that has the names for testing and training.
    partitionNames = load('subsetCamVidDatasetFileNames.mat');
    
    % Extract the test images names.
    imageTestNames = partitionNames.imageTestNames;
    
    % Remove the empty cells. 
    imageTestNames = imageTestNames(~cellfun('isempty',imageTestNames));
    
    % Extract the test labels names.
    labelTestNames = partitionNames.labelTestNames;
    
    % Remove the empty cells.
    labelTestNames = labelTestNames(~cellfun('isempty',labelTestNames));
    
    % Copy the test images to the respective folder.
    for i = 1:size(imageTestNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTestNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTestNames(i));
        copyfile(imageSource{1}, realTestImagesFolder);
        copyfile(labelSource{1}, realTestLabelsFolder);
    end
    
    % Extract the train images names.
    imageTrainNames = partitionNames.imageTrainNames;
    
    % Remove the empty cells.
    imageTrainNames = imageTrainNames(~cellfun('isempty',imageTrainNames));
    
    % Extract the train labels names.
    labelTrainNames = partitionNames.labelTrainNames;
    
    % Remove the empty cells.
    labelTrainNames = labelTrainNames(~cellfun('isempty',labelTrainNames));
    
    % Copy the train images to the respective folder.
    for i = 1:size(imageTrainNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTrainNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTrainNames(i));
        copyfile(imageSource{1},realImagesFolder);
        copyfile(labelSource{1},realLabelsFolder);
    end
end
end

Функция помощника addASPPToNetwork создает слои atrous пространственного объединения пирамиды (ASPP) и добавляет их во входной график слоев. Функция возвращает график слоев со слоями ASPP, соединенными с ним.

function lgraph  = addASPPToNetwork(lgraph, numClasses)

% Define the ASPP dilation factors.
asppDilationFactors = [6,12];

% Define the ASPP filter sizes.
asppFilterSizes = [3,3];

% Extract the last layer of the layer graph.
lastLayerName = lgraph.Layers(end).Name;

% Define the addition layer.
addLayer = additionLayer(numel(asppDilationFactors),'Name','additionLayer');

% Add the addition layer to the layer graph.
lgraph = addLayers(lgraph,addLayer);

% Create the ASPP layers connected to the addition layer
% and connect the layer graph.
for i = 1: numel(asppDilationFactors)
    asppConvName = "asppConv_" + string(i);
    branchFilterSize = asppFilterSizes(i);
    branchDilationFactor = asppDilationFactors(i);
    asspLayer  = convolution2dLayer(branchFilterSize, numClasses,'DilationFactor', branchDilationFactor,...
        'Padding','same','Name',asppConvName,'WeightsInitializer','narrow-normal','BiasInitializer','zeros');
    lgraph = addLayers(lgraph,asspLayer);
    lgraph = connectLayers(lgraph,lastLayerName,asppConvName);
    lgraph = connectLayers(lgraph,asppConvName,strcat(addLayer.Name,'/',addLayer.InputNames{i}));
end
end

Функция помощника predictSegmentationLabelsOnTestSet вычисляет матрицу беспорядка предсказанного и меток основной истины с помощью segmentationConfusionMatrix (Computer Vision Toolbox) функция.

function confusionMatrix =  predictSegmentationLabelsOnTestSet(net, minbatchTestData)   
    
confusionMatrix = {};
i = 1;
while hasdata(minbatchTestData)
    
    % Use next to retrieve a mini-batch from the datastore.
    [dlX, gtlabels] = next(minbatchTestData);
    
    % Predict the output of the network.
    [genPrediction, ~] = forward(net,dlX);
    
    % Get the label, which is the index with maximum value in the channel dimension.
    [~, labels] = max(genPrediction,[],3);
    
    % Get the confusion matrix of each image.
    confusionMatrix{i}  = segmentationConfusionMatrix(double(gather(extractdata(labels))),double(gather(extractdata(gtlabels))));
  
    i = i+1;
end

confusionMatrix = confusionMatrix';
    
end

Функция помощника piecewiseLearningRate вычисляет текущую скорость обучения на основе номера итерации.

function lr = piecewiseLearningRate(i, baseLR, numIterations, power)

fraction = i/numIterations;
factor = (1 - fraction)^power * 1e1;
lr = baseLR * factor;

end

Функция помощника preprocessData выполняет нулевой центральный сдвиг путем вычитания количества каналов изображений соответствующим средним значением.

function data = preprocessData(data)

% Extract respective channels.
rc = data(:,:,1);
gc = data(:,:,2);
bc = data(:,:,3);

% Compute the respective channel means.
r = mean(rc(:));
g = mean(gc(:));
b = mean(bc(:));

% Shift the data by the mean of respective channel.
data = single(data) - single(shiftdim([r g b],-1));  
end

Ссылки

[1] Tsai, И-Хсуэн, Вэй-Чи Хун, Сэмюэль Шултер, Киюк Зон, Мин-Хсуань Ян и Манмохан Чандрэкер. “Учась Адаптировать Структурированный Выходной Пробел к Семантической Сегментации”. На 2018 Конференциях IEEE/CVF по Компьютерному зрению и Распознаванию образов, 7472–81. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00780.

[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. “Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости”. Буквы Распознавания образов 30, № 2 (январь 2009): 88–97. https://doi.org/10.1016/j.patrec.2008.04.005.