Обучите сеть семантической сегментации глубокого обучения с помощью 3-D Данных моделирования

В этом примере показано, как использовать 3-D данные моделирования для обучения семантической сети сегментации и точной настройки ее к реальным данным с помощью генеративных состязательных сетей (GANs).

Этот пример использует данные 3-D симуляции, сгенерированные Driving Scenario Designer и Unreal Engine ®. Для примера, показывающего, как сгенерировать такие данные моделирования, смотрите Depth and Семантическая Сегментация Visualization Using Unreal Engine Simulation (Automated Driving Toolbox). Среда симуляции 3-D генерирует изображения и соответствующие метки наземного пикселя основной истины. Использование данных моделирования избегает процесса аннотации, который является одновременно утомительным и требует большого количества человеческих усилий. Однако модели сдвига в области, обученные только на данных моделирования, плохо работают на реальных наборах данных. Для решения этой проблемы можно использовать адаптацию области, чтобы подстроить обученную модель для работы с реальным набором данных .

Этот пример использует AdaptSegNet [1], сеть, которая адаптирует структуру выходных предсказаний сегментации, которые выглядят одинаково независимо от входной области. Сеть AdaptSegNet основана на модели GAN и состоит из двух сетей, которые обучаются одновременно, чтобы максимизировать эффективность обеих:

  1. Генератор - Сеть обучена генерировать высококачественные результаты сегментации из реальных или моделируемых входных изображений

  2. Дискриминатор - Сеть, которая сравнивает и пытается различить, являются ли предсказания сегментации генератора реальными или моделируемыми данными

Чтобы подстроить модель AdaptSegNet для данных реального мира, этот пример использует подмножество Данные CamVid [2] и адаптирует модель, чтобы сгенерировать высококачественные предсказания сегментации на данных CamVid.

Загрузка предварительно обученной сети

Загрузите предварительно обученную сеть. Предварительно обученная модель позволяет вам запустить весь пример, не дожидаясь завершения обучения. Если вы хотите обучить сеть, установите doTraining переменная в true.

doTraining = false;
if ~doTraining
    pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedAdaptSegGANNet.mat';
    pretrainedFolder = fullfile(tempdir,'pretrainedNetwork');
    pretrainedNetwork = fullfile(pretrainedFolder,'trainedAdaptSegGANNet.mat'); 
    if ~exist(pretrainedNetwork,'file')
        mkdir(pretrainedFolder);
        disp('Downloading pretrained network (57 MB)...');
        websave(pretrainedNetwork,pretrainedURL);
    end
    pretrained = load(pretrainedNetwork);
    dlnetGenerator = pretrained.dlnetGenerator;
end    

Загрузка наборов данных

Загрузите симуляцию и реальные наборы данных с помощью downloadDataset function, заданная в разделе Support Functions этого примера. The downloadDataset функция загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы.

Набор данных моделирования был сгенерирован Driving Scenario Designer. Сгенерированные сценарии, которые состоят из 553 фотореалистичных изображений с метками, были визуализированы Unreal Engine. Вы используете этот набор данных для обучения модели.

Реальный набор данных является подмножеством набора данных CamVid из Кембриджского университета. Чтобы адаптировать модель к реальным данным, 69 изображений CamVid. Чтобы оценить обученную модель, вы используете 368 изображений CamVid.

Время загрузки зависит от вашего подключения к Интернету.

simulationDataURL = 'https://ssd.mathworks.com/supportfiles/vision/data/SimulationDrivingDataset.zip';
realImageDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
realLabelDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

simulationDataLocation = fullfile(tempdir,'SimulationData');
realDataLocation = fullfile(tempdir,'RealData');
[simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder, ...
    realTestImagesFolder, realTestLabelsFolder] = ... 
    downloadDataset(simulationDataLocation,simulationDataURL,realDataLocation,realImageDataURL,realLabelDataURL);

Загруженные файлы включают пиксельные метки для реальной области, но обратите внимание, что вы не используете эти пиксельные метки в процессе обучения. Этот пример использует пиксельные метки действительной области только для вычисления среднего значения пересечения по объединению (IoU), чтобы оценить эффективность обученной модели.

Загрузка симуляции и реальных данных

Использование imageDatastore загрузить симуляцию и реальные наборы данных для обучения. При помощи image datastore можно эффективно загрузить на диск большой набор изображений.

simData = imageDatastore(simulationImagesFolder);
realData = imageDatastore(realImagesFolder);

Предварительный просмотр изображений из набора данных моделирования и реального набора данных.

simImage = preview(simData);
realImage = preview(realData);
montage({simImage,realImage})

Реальные и моделируемые изображения выглядят очень по-другому. Следовательно, модели, обученные на моделируемых данных и оцененные на реальных данных, плохо работают из-за сдвига области.

Загрузка маркированных пикселями изображений для данных моделирования и реальных данных

Загрузите пиксельные данные о метке изображения симуляции при помощи pixelLabelDatastore (Computer Vision Toolbox). Хранилище datastore метки пикселя инкапсулирует данные о пиксельных метках и идентификатор метки в отображение имен классов.

В данном примере задайте пять классов, полезных для приложения для беспилотного вождения: дорога, фон, тротуар, небо и автомобиль.

classes = [
    "Road"
    "Background"
    "Pavement"
    "Sky"
    "Car"
    ];
numClasses = numel(classes);

Набор данных моделирования имеет восемь классов. Уменьшите количество классов с восьми до пяти путем группировки классов создания, дерева, сигнала трафика и света из исходного набора данных в один фоновый класс. Верните сгруппированные идентификаторы меток с помощью функции helper simulationPixelLabelIDs. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.

labelIDs = simulationPixelLabelIDs;

Используйте идентификаторы классов и меток, чтобы создать хранилище datastore о метках пикселей данных моделирования.

simLabels = pixelLabelDatastore(simulationLabelsFolder,classes,labelIDs);

Инициализируйте палитру для сегментированных изображений с помощью функции helper domainAdaptationColorMap, заданный в разделе Вспомогательные функции.

dmap = domainAdaptationColorMap;

Предварительный просмотр изображения с меткой пикселя путем наложения метки поверх изображения с помощью labeloverlay (Image Processing Toolbox) функция.

simImageLabel = preview(simLabels);
overlayImageSimulation = labeloverlay(simImage,simImageLabel,'ColorMap',dmap);
figure
imshow(overlayImageSimulation)
labelColorbar(dmap,classes);

Сдвиньте симуляцию и реальные данные, используемые для обучения, в нулевой центр, чтобы центрировать данные вокруг источника, используя transform функции и preprocessData вспомогательная функция, заданная в разделе Вспомогательные функции.

preprocessedSimData = transform(simData, @(simdata)preprocessData(simdata));
preprocessedRealData = transform(realData, @(realdata)preprocessData(realdata));

Используйте combine функция для объединения преобразованного datastore изображений и хранилищ данных меток пикселей области симуляции. Процесс обучения не использует пиксельные метки реальных данных.

combinedSimData = combine(preprocessedSimData,simLabels);

Задайте генератор AdaptSegNet

Этот пример изменяет VGG-16 сеть, предварительно обученную в ImageNet, на полностью сверточную сеть. Для увеличения рецептивных полей добавляют расширенные сверточные слои с полосами 2 и 4. Это делает выход карты функций одной восьмой от размера входа. Atrous пространственное пирамидальное объединение (ASPP) используется для предоставления многомасштабной информации и сопровождается resize2dlayer с коэффициентом повышающей дискретизации 8, чтобы изменить размер выхода на размер входа.

Сеть генератора AdaptSegNet, используемая в этом примере, проиллюстрирована на следующей схеме.

Чтобы получить предварительно обученную VGG-16 сеть, установите vgg16. Если пакет поддержки не установлен, то программное обеспечение предоставляет ссылку для загрузки.

net = vgg16;

Чтобы сделать VGG-16 сеть подходящей для семантической сегментации, удалите все слои VGG после 'relu4_3'.

vggLayers = net.Layers(2:24);

Создайте входной слой изображения размера 1280 на 720 на 3 для генератора.

inputSizeGenerator = [1280 720 3];
inputLayer = imageInputLayer(inputSizeGenerator,'Normalization','None','Name','inputLayer');

Создайте полностью сверточные слои сети. Используйте коэффициенты расширения 2 и 4, чтобы увеличить соответствующие поля.

fcnlayers = [
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_1')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2] ,'Name','conv5_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_2')
    convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_3','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu5_3')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4],'Name','conv6_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_1')
    convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4] ,'Name','conv6_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros')
    reluLayer('Name','relu6_2')
    ];

Объедините слои и создайте график слоев.

layers = [
    inputLayer
    vggLayers
    fcnlayers
    ];
lgraph = layerGraph(layers);

ASPP используется для предоставления многомасштабной информации. Добавьте модуль ASPP к графику слоев с размером фильтра, равным количеству каналов, при помощи addASPPToNetwork вспомогательная функция, заданная в разделе Вспомогательные функции.

lgraph  = addASPPToNetwork(lgraph, numClasses);

Применить resize2dLayer с коэффициентом повышающей дискретизации 8, чтобы выход совпадал с размером входа.

upSampleLayer = resize2dLayer('Scale',8,'Method','bilinear','Name','resizeLayer');
lgraphGenerator = addLayers(lgraph,upSampleLayer);
lgraphGenerator = connectLayers(lgraphGenerator,'additionLayer','resizeLayer');

Визуализируйте сеть генератора на графике.

plot(lgraphGenerator)
title("Generator")

Задайте дискриминатор AdaptSeg

Сеть дискриминатора состоит из пяти сверточных слоев с размером ядра 3 и полосой 2, где количество каналов {64, 128, 256, 512, 1}. Каждый слой сопровождается утечкой слоя ReLU, параметризованной масштабом 0,2, за исключением последнего слоя. resize2dLayer используется для изменения размера выхода дискриминатора. Обратите внимание, что в этом примере не используется нормализация партии ., поскольку дискриминатор совместно обучается с сетью сегментации с использованием небольшого размера пакета.

Сеть дискриминатора AdaptSegNet в этом примере проиллюстрирована на следующей схеме.

Создайте слой входа изображений размера 1280 на 720 бай- numClasses что принимает в сегментации предсказания симуляции и вещественных областей.

inputSizeDiscriminator = [1280 720 numClasses];

Создайте полностью сверточные слои и сгенерируйте график слоев дискриминатора.

% Factor for number of channels in convolution layer.
numChannelsFactor = 64;

% Scale factor to resize the output of the discriminator.
resizeScale = 64;

% Scalar multiplier for leaky ReLU layers.
leakyReLUScale = 0.2;

% Create the layers of the discriminator.
layers = [
    imageInputLayer(inputSizeDiscriminator,'Normalization','none','Name','inputLayer')
    convolution2dLayer(3,numChannelsFactor,'Stride',2,'Padding',1,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu1')
    convolution2dLayer(3,numChannelsFactor*2,'Stride',2,'Padding',1,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu2')
    convolution2dLayer(3,numChannelsFactor*4,'Stride',2,'Padding',1,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu3')
    convolution2dLayer(3,numChannelsFactor*8,'Stride',2,'Padding',1,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    leakyReluLayer(leakyReLUScale,'Name','lrelu4')
    convolution2dLayer(3,1,'Stride',2,'Padding',1,'Name','classifer','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    resize2dLayer('Scale', resizeScale,'Method','bilinear','Name','resizeLayer');
    ];

% Create the layer graph of the discriminator.
lgraphDiscriminator  = layerGraph(layers);

Визуализируйте сеть дискриминатора на графике.

plot(lgraphDiscriminator)
title("Discriminator")

Настройка опций обучения

Задайте эти опции обучения.

  • Установите общее количество итераций равным 5000. Тем самым вы обучаете сеть около 10 эпох.

  • Установите скорость обучения для генератора равной 2.5e-4.

  • Установите скорость обучения для дискриминатора равной 1e-4.

  • Установите коэффициент регуляризации L2 равным 0.0005.

  • Экспоненциально скорость обучения уменьшается на основе формулы learningrate× [iterationtotaliterations]power. Это уменьшение помогает стабилизировать градиенты при более высоких итерациях. Установите степень 0.9.

  • Установите вес состязательных потерь равным 0.001.

  • Инициализируйте скорость градиента следующим [ ]. Это значение используется SGDM, чтобы сохранить скорость градиентов.

  • Инициализируйте скользящее среднее значение для градиентов параметра следующим [ ]. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов параметра.

  • Инициализируйте движущееся среднее значение градиентов квадратов параметров следующим [ ]. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов квадратов параметра.

  • Установите размер мини-пакета равным 1.

numIterations = 5000;
learnRateGenBase = 2.5e-4;
learnRateDisBase = 1e-4;
l2Regularization = 0.0005;
power = 0.9;
lamdaAdv = 0.001;
vel= [];
averageGrad = [];
averageSqGrad = [];
miniBatchSize = 1;

Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Чтобы автоматически обнаружить, доступен ли вам графический процессор, установите executionEnvironment на "auto". Если у вас нет графический процессор или вы не хотите использовать его для обучения, установите executionEnvironment на "cpu". Чтобы гарантировать использование графический процессор для обучения, установите executionEnvironment на "gpu". Для получения информации о поддерживаемых вычислительных возможностях смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Создайте minibatchqueue объект из объединенного datastore области симуляции.

mbqTrainingDataSimulation =  minibatchqueue(combinedSimData,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Создайте minibatchqueue объект из входа входных изображений действительной области.

mbqTrainingDataReal = minibatchqueue(preprocessedRealData,"MiniBatchSize",miniBatchSize, ... 
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. Функция помощника modelGradients, заданный в разделе Вспомогательные функции этого примера, вычислите градиенты и потери для генератора и дискриминатора. Создайте график процесса обучения с помощью configureTrainingLossPlotter, приложенный к этому примеру как вспомогательный файл, и обновление процесса обучения с помощью updateTrainingPlots. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации.

Для каждой итерации:

  • Чтение изображения и информации о метке из minibatchqueue объект данных моделирования с помощью next функция.

  • Считайте информацию об изображении из minibatchqueue объект реальных данных с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval и modelGradients вспомогательная функция, заданная в разделе Вспомогательные функции. modelGradients возвращает градиенты потерь относительно настраиваемых параметров.

  • Обновите параметры сети генератора с помощью sgdmupdate функция.

  • Обновите параметры сети дискриминатора, используя adamupdate функция.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

if doTraining

    % Create the dlnetwork object of the generator.
    dlnetGenerator = dlnetwork(lgraphGenerator);
    
    % Create the dlnetwork object of the discriminator.
    dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
    
    % Create the subplots for the generator and discriminator loss.
    fig = figure;
    [generatorLossPlotter, discriminatorLossPlotter] = configureTrainingLossPlotter(fig);
    
    % Loop through the data for the specified number of iterations.
    for iter = 1:numIterations
       
        % Reset the minibatchqueue of simulation data.
        if ~hasdata(mbqTrainingDataSimulation)
            reset(mbqTrainingDataSimulation);
        end
        
        % Retrieve the next mini-batch of simulation data and labels.
        [dlX,label] = next(mbqTrainingDataSimulation); 
        
        % Reset the minibatchqueue of real data.
        if ~hasdata(mbqTrainingDataReal)
            reset(mbqTrainingDataReal);
        end
        
        % Retrieve the next mini-batch of real data. 
        dlZ = next(mbqTrainingDataReal);  
        
        % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
        [gradientGenerator,gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = ...
            dlfeval(@modelGradients,dlnetGenerator,dlnetDiscriminator,dlX,dlZ,label,lamdaAdv);
        
        % Apply L2 regularization.
        gradientGenerator  = dlupdate(@(g,w) g + l2Regularization*w, gradientGenerator, dlnetGenerator.Learnables);
        
        % Adjust the learning rate.
        learnRateGen = piecewiseLearningRate(iter,learnRateGenBase,numIterations,power);
        learnRateDis = piecewiseLearningRate(iter,learnRateDisBase,numIterations,power);
        
         % Update the generator network learnable parameters using the SGDM optimizer.
        [dlnetGenerator.Learnables, vel] = ... 
            sgdmupdate(dlnetGenerator.Learnables,gradientGenerator,vel,learnRateGen);
               
         % Update the discriminator network learnable parameters using the Adam optimizer.
        [dlnetDiscriminator.Learnables, averageGrad, averageSqGrad] = ...
            adamupdate(dlnetDiscriminator.Learnables,gradientDiscriminator,averageGrad,averageSqGrad,iter,learnRateDis) ;
        
        % Update the training plot with loss values.
        updateTrainingPlots(generatorLossPlotter,discriminatorLossPlotter,iter, ... 
            double(gather(extractdata(lossSegValue + lamdaAdv * lossAdvValue))),double(gather(extractdata(lossDisValue))));

    end
    
    % Save the trained model.
    save('trainedAdaptSegGANNet.mat','dlnetGenerator');
end 

Теперь дискриминатор может идентифицировать, является ли вход из симуляции или действительной области. В свою очередь, генератор теперь может генерировать предсказания сегментации, которые аналогичны в симуляции и реальных областях.

Оцените модель на реальных тестовых данных

Оцените эффективность обученной сети AdaptSegNet, вычислив среднее значение IoU для предсказаний тестовых данных.

Загрузите тестовые данные с помощью imageDatastore.

realTestData = imageDatastore(realTestImagesFolder);

Набор данных CamVid имеет 32 класса. Используйте realpixelLabelIDs вспомогательная функция для уменьшения количества классов до пяти, как для набора данных моделирования. The realpixelLabelIDs Функция helper присоединена к этому примеру как вспомогательный файл.

labelIDs = realPixelLabelIDs;

Использование pixelLabelDatastore (Computer Vision Toolbox), чтобы загрузить изображения наземной метки истинности для тестовых данных.

realTestLabels = pixelLabelDatastore(realTestLabelsFolder,classes,labelIDs);

Переместите данные в нулевой центр, чтобы центрировать данные вокруг источника, как и для обучающих данных, при помощи transform функции и preprocessData вспомогательная функция, заданная в разделе Вспомогательные функции.

preprocessedRealTestData = transform(realTestData, @(realtestdata)preprocessData(realtestdata));

Использование combine чтобы объединить преобразованный datastore изображений и хранилища данных меток пикселей реальных тестовых данных.

combinedRealTestData = combine(preprocessedRealTestData,realTestLabels);

Создайте minibatchqueue объект из объединённого datastore тестовых данных .Задайте "MiniBatchSize" на 1 для простоты оценки метрик.

mbqimdsTest = minibatchqueue(combinedRealTestData,"MiniBatchSize",1,...
    "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);

Чтобы сгенерировать массив матричных ячеек неточностей, используйте функцию helper predictSegmentationLabelsOnTestSet на minibatchqueue объект тестовых данных. Функция помощника predictSegmentationLabelsOnTestSet приведен ниже в разделе Вспомогательные функции.

imageSetConfusionMat = predictSegmentationLabelsOnTestSet(dlnetGenerator,mbqimdsTest);

Использование evaluateSemanticSegmentation (Computer Vision Toolbox), чтобы измерить метрики семантической сегментации в матрице неточностей тестового набора.

metrics = evaluateSemanticSegmentation(imageSetConfusionMat,classes,'Verbose',false);

Чтобы увидеть метрики уровня набора данных, смотрите metrics.DataSetMetrics.

metrics.DataSetMetrics
ans=1×4 table
    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU
    ______________    ____________    _______    ___________

       0.86883           0.769        0.64487      0.78026  

Метрики набора данных обеспечивают высокоуровневый обзор эффективности сети. Чтобы увидеть влияние каждого класса на общую эффективность, смотрите метрики по классам с помощью metrics.ClassMetrics.

metrics.ClassMetrics
ans=5×2 table
                  Accuracy      IoU  
                  ________    _______

    Road           0.9147     0.81301
    Background    0.93418     0.85518
    Pavement      0.33373     0.27105
    Sky           0.82652     0.81109
    Car           0.83586     0.47399

Эффективность набора данных хороша, но метрики классов показывают, что классы автомобиля и тротуара не сегментированы хорошо. Обучение сети с помощью дополнительных данных может привести к улучшенным результатам.

Изображение сегмента

Запустите обученную сеть на одном тестовом изображении, чтобы проверить сегментированное выходное предсказание.

% Read the image from the test data.
data = readimage(realTestData,350);

% Perform the preprocessing step of zero shift on the image.
processeddata = preprocessData(data);

% Convert the data to dlarray.
processeddata = dlarray(processeddata,'SSCB');

% Predict the output of the network.
[genPrediction, ~] = forward(dlnetGenerator,processeddata);

% Get the label, which is the index with the maximum value in the channel dimension.
[~, labels] = max(genPrediction,[],3);

% Overlay the predicted labels on the image.
segmentedImage = labeloverlay(data,uint8(gather(extractdata(labels))),'Colormap',dmap);

Отображение результатов.

figure
imshow(segmentedImage);
labelColorbar(dmap,classes);

Сравните результаты метки с ожидаемыми основной истины, хранящимися в realTestLabels. Зеленая и пурпурная области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.

expectedResult = readimage(realTestLabels,350);
actual = uint8(gather(extractdata(labels)));
expected = uint8(expectedResult);
figure
imshowpair(actual,expected)

Визуально результаты семантической сегментации хорошо перекрываются для классов "Дорога" , "Небо" и "Здания". Однако результаты плохо перекрываются для классов автомобиля и тротуара.

Вспомогательные функции

Функция градиентов модели

Функция помощника modelGradients вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери сегментации для генератора и потери перекрестной энтропии для дискриминатора. Поскольку никакая информация о состоянии не требуется запоминать между итерациями для сетей генератора и дискриминатора, состояния не обновляются.

function [gradientGenerator, gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, label, lamdaAdv)

% Labels for adversarial training.
simulationLabel = 0;
realLabel = 1;

% Extract the predictions of the simulation from the generator.
[genPredictionSimulation, ~] = forward(dlnetGenerator,dlX);

% Compute the generator loss.
lossSegValue = segmentationLoss(genPredictionSimulation,label);

% Extract the predictions of the real data from the generator.
[genPredictionReal, ~] = forward(dlnetGenerator,dlZ);

% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of simulation labels of real prediction size.
Y = simulationLabel * ones(size(disPredictionReal));

% Compute the adversarial loss to make the real distribution close to the simulation label.
lossAdvValue = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the gradients of the generator with regard to loss.
gradientGenerator = dlgradient(lossSegValue + lamdaAdv*lossAdvValue,dlnetGenerator.Learnables);

% Extract the softmax predictions of the simulation from the discriminator.
disPredictionSimulation = forward(dlnetDiscriminator,softmax(genPredictionSimulation));

% Create a matrix of simulation labels of simulation prediction size.
Y = simulationLabel * ones(size(disPredictionSimulation));

% Compute the discriminator loss with regard to simulation class.
lossDisValueSimulation = mse(disPredictionSimulation,Y)/numel(Y(:));
 
% Extract the softmax predictions of the real data from the discriminator.
disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal));

% Create a matrix of real labels of real prediction size.
Y = realLabel * ones(size(disPredictionReal));

% Compute the discriminator loss with regard to real class.
lossDisValueReal = mse(disPredictionReal,Y)/numel(Y(:));

% Compute the total discriminator loss.
lossDisValue = lossDisValueSimulation + lossDisValueReal;

% Compute the gradients of the discriminator with regard to loss.
gradientDiscriminator = dlgradient(lossDisValue,dlnetDiscriminator.Learnables);

end

Функция потерь сегментации

Функция помощника segmentationLoss вычисляет потери сегментации функции, которые заданы как потери перекрестной энтропии для генератора, используя данные моделирования и его соответствующую основную истину. Функция helper вычисляет потери при помощи crossentropy функция.

function loss = segmentationLoss(predict, target)

% Generate the one-hot encodings of the ground truth.
oneHotTarget = onehotencode(categorical(extractdata(target)),4);

% Convert the one-hot encoded data to dlarray.
oneHotTarget = dlarray(oneHotTarget,'SSBC');

% Compute the softmax output of the predictions.
predictSoftmax = softmax(predict);

% Compute the cross-entropy loss.
loss =  crossentropy(predictSoftmax,oneHotTarget,'TargetCategories','exclusive')/(numel(oneHotTarget)/2);
end

Функция помощника downloadDataset загружает как симуляцию, так и реальные наборы данных с указанных URL-адресов в указанные папки, если они не существуют. Функция возвращает пути симуляции, реальных обучающих данных и данных о проверке. Функция загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы с помощью subsetCamVidDatasetFileNames файл mat, присоединенный к примеру как вспомогательный файл.

function [simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder,...
    realTestImagesFolder, realTestLabelsFolder] = ...
    downloadDataset(simulationDataLocation, simulationDataURL, realDataLocation, realImageDataURL, realLabelDataURL)
    
% Build the training image and label folder location for simulation data.
simulationDataZip = fullfile(simulationDataLocation,'SimulationDrivingDataset.zip');

% Get the simulation data if it does not exist.
if ~exist(simulationDataZip,'file')
    mkdir(simulationDataLocation)
    
    disp('Downloading the simulation data');
    websave(simulationDataZip,simulationDataURL);
    unzip(simulationDataZip,simulationDataLocation);
end
  
simulationImagesFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','images');
simulationLabelsFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','labels');

camVidLabelsZip = fullfile(realDataLocation,'CamVidLabels.zip');
camVidImagesZip = fullfile(realDataLocation,'CamVidImages.zip');

if ~exist(camVidLabelsZip,'file') || ~exist(camVidImagesZip,'file')   
    mkdir(realDataLocation)
       
    disp('Downloading 16 MB CamVid dataset labels...'); 
    websave(camVidLabelsZip, realLabelDataURL);
    unzip(camVidLabelsZip, fullfile(realDataLocation,'CamVidLabels'));
    
    disp('Downloading 587 MB CamVid dataset images...');  
    websave(camVidImagesZip, realImageDataURL);       
    unzip(camVidImagesZip, fullfile(realDataLocation,'CamVidImages'));    
end

% Build the training image and label folder location for real data.
realImagesFolder = fullfile(realDataLocation,'train','images');
realLabelsFolder = fullfile(realDataLocation,'train','labels');

% Build the testing image and label folder location for real data.
realTestImagesFolder = fullfile(realDataLocation,'test','images');
realTestLabelsFolder = fullfile(realDataLocation,'test','labels');

% Partition the data into training and test sets if they do not exist.
if ~exist(realImagesFolder,'file') || ~exist(realLabelsFolder,'file') || ...
        ~exist(realTestImagesFolder,'file') || ~exist(realTestLabelsFolder,'file')

    
    mkdir(realImagesFolder);
    mkdir(realLabelsFolder);
    mkdir(realTestImagesFolder);
    mkdir(realTestLabelsFolder);
    
    % Load the mat file that has the names for testing and training.
    partitionNames = load('subsetCamVidDatasetFileNames.mat');
    
    % Extract the test images names.
    imageTestNames = partitionNames.imageTestNames;
    
    % Remove the empty cells. 
    imageTestNames = imageTestNames(~cellfun('isempty',imageTestNames));
    
    % Extract the test labels names.
    labelTestNames = partitionNames.labelTestNames;
    
    % Remove the empty cells.
    labelTestNames = labelTestNames(~cellfun('isempty',labelTestNames));
    
    % Copy the test images to the respective folder.
    for i = 1:size(imageTestNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTestNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTestNames(i));
        copyfile(imageSource{1}, realTestImagesFolder);
        copyfile(labelSource{1}, realTestLabelsFolder);
    end
    
    % Extract the train images names.
    imageTrainNames = partitionNames.imageTrainNames;
    
    % Remove the empty cells.
    imageTrainNames = imageTrainNames(~cellfun('isempty',imageTrainNames));
    
    % Extract the train labels names.
    labelTrainNames = partitionNames.labelTrainNames;
    
    % Remove the empty cells.
    labelTrainNames = labelTrainNames(~cellfun('isempty',labelTrainNames));
    
    % Copy the train images to the respective folder.
    for i = 1:size(imageTrainNames,1)
        labelSource = fullfile(realDataLocation,'CamVidLabels',labelTrainNames(i));
        imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTrainNames(i));
        copyfile(imageSource{1},realImagesFolder);
        copyfile(labelSource{1},realLabelsFolder);
    end
end
end

Функция помощника addASPPToNetwork создает слои atrous spatial pyramid uling (ASPP) и добавляет их к входу графика слоев. Функция возвращает график слоев с соединенными с ним слоями ASPP.

function lgraph  = addASPPToNetwork(lgraph, numClasses)

% Define the ASPP dilation factors.
asppDilationFactors = [6,12];

% Define the ASPP filter sizes.
asppFilterSizes = [3,3];

% Extract the last layer of the layer graph.
lastLayerName = lgraph.Layers(end).Name;

% Define the addition layer.
addLayer = additionLayer(numel(asppDilationFactors),'Name','additionLayer');

% Add the addition layer to the layer graph.
lgraph = addLayers(lgraph,addLayer);

% Create the ASPP layers connected to the addition layer
% and connect the layer graph.
for i = 1: numel(asppDilationFactors)
    asppConvName = "asppConv_" + string(i);
    branchFilterSize = asppFilterSizes(i);
    branchDilationFactor = asppDilationFactors(i);
    asspLayer  = convolution2dLayer(branchFilterSize, numClasses,'DilationFactor', branchDilationFactor,...
        'Padding','same','Name',asppConvName,'WeightsInitializer','narrow-normal','BiasInitializer','zeros');
    lgraph = addLayers(lgraph,asspLayer);
    lgraph = connectLayers(lgraph,lastLayerName,asppConvName);
    lgraph = connectLayers(lgraph,asppConvName,strcat(addLayer.Name,'/',addLayer.InputNames{i}));
end
end

Функция помощника predictSegmentationLabelsOnTestSet вычисляет матрицу неточностей предсказанных и основную истину меток с помощью segmentationConfusionMatrix (Computer Vision Toolbox) функция.

function confusionMatrix =  predictSegmentationLabelsOnTestSet(net, minbatchTestData)   
    
confusionMatrix = {};
i = 1;
while hasdata(minbatchTestData)
    
    % Use next to retrieve a mini-batch from the datastore.
    [dlX, gtlabels] = next(minbatchTestData);
    
    % Predict the output of the network.
    [genPrediction, ~] = forward(net,dlX);
    
    % Get the label, which is the index with maximum value in the channel dimension.
    [~, labels] = max(genPrediction,[],3);
    
    % Get the confusion matrix of each image.
    confusionMatrix{i}  = segmentationConfusionMatrix(double(gather(extractdata(labels))),double(gather(extractdata(gtlabels))));
  
    i = i+1;
end

confusionMatrix = confusionMatrix';
    
end

Функция помощника piecewiseLearningRate вычисляет текущую скорость обучения на основе числа итераций.

function lr = piecewiseLearningRate(i, baseLR, numIterations, power)

fraction = i/numIterations;
factor = (1 - fraction)^power * 1e1;
lr = baseLR * factor;

end

Функция помощника preprocessData выполняет сдвиг нулевого центра путем вычитания количества каналов изображения соответствующим средним значением.

function data = preprocessData(data)

% Extract respective channels.
rc = data(:,:,1);
gc = data(:,:,2);
bc = data(:,:,3);

% Compute the respective channel means.
r = mean(rc(:));
g = mean(gc(:));
b = mean(bc(:));

% Shift the data by the mean of respective channel.
data = single(data) - single(shiftdim([r g b],-1));  
end

Ссылки

[1] Цай, И-Хсуань, Вэй-Чих Хунг, Самуэль Шультер, Кихюк Сон, Мин-Хсуань Ян и Манмохан Чандракер. «Обучение адаптации структурированного выходного пространства для семантической сегментации». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 7472-81. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00780.

[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Pattern Recognition Letters 30, № 2 (январь 2009): 88-97. https://doi.org/10.1016/j.patrec.2008.04.005.