Сгенерируйте изображение из карты сегментации используя глубокое обучение

В этом примере показано, как сгенерировать синтетическое изображение сцены из семантической карты сегментации с помощью pix2pixHD условной генеративной состязательной сети (CGAN).

Pix2pixHD [1] состоит из двух сетей, которые обучаются одновременно, чтобы максимизировать эффективность обеих.

  1. Генератор является нейронной сетью в стиле энкодер-декодер, которая генерирует изображение сцены из семантической карты сегментации. Сеть CGAN обучает генератор генерировать изображение сцены, которое дискриминатор неправильно классифицирует как реальное.

  2. Дискриминатор является полностью сверточной нейронной сетью, которая сравнивает сгенерированное изображение сцены и соответствующее реальное изображение и пытается классифицировать их как поддельные и реальные, соответственно. Сеть CGAN обучает дискриминатор правильно различать сгенерированное и реальное изображение.

Сети генератора и дискриминатора конкурируют друг с другом во время обучения. Обучение сходится, когда ни одна из сетей не может улучшиться дальше.

Загрузка набора данных CamVid

Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных представляет собой набор 701 изображений, содержащих представления уличного уровня, полученные во время вождения. Набор данных обеспечивает пиксельные метки для 32 семантических классов, включая автомобиль, пешехода и дорогу.

Загрузите набор данных CamVid с этих URL-адресов. Время загрузки зависит от вашего подключения к Интернету.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

dataDir = fullfile(tempdir,'CamVid'); 
downloadCamVidData(dataDir,imageURL,labelURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
labelDir = fullfile(dataDir,'labels');

Предварительная обработка обучающих данных

Создайте imageDatastore для хранения изображений в наборе данных CamVid.

imds = imageDatastore(imgDir);
imageSize = [576 768];

Определите имена классов и идентификаторы меток пикселей 32 классов в наборе данных CamVid с помощью функции helper defineCamVid32ClassesAndPixelLabelIDs. Получите стандартную карту цветов для набора данных CamVid с помощью функции helper camvid32ColorMap. Вспомогательные функции присоединены к примеру как вспомогательные файлы.

numClasses = 32;
[classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs;
cmap = camvid32ColorMap;

Создайте pixelLabelDatastore (Computer Vision Toolbox) для хранения изображений меток пикселей.

pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Предварительный просмотр изображения метки пикселя и соответствующего изображения основной истины сцены. Преобразуйте метки из категорийных меток в цвета RGB с помощью label2rgb (Image Processing Toolbox), затем отобразите пиксельное изображение метки и наземное изображение истинности в монтаже.

im = preview(imds);
px = preview(pxds);
px = label2rgb(px,cmap);
montage({px,im})

Разделите данные на обучающие и тестовые наборы с помощью функции helper partitionCamVidForPix2PixHD. Эта функция присоединена к примеру как вспомогательный файл. Функция helper разделяет данные на 648 обучающих файлов и 32 тестовых файлов.

[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);

Используйте combine функция для объединения изображений меток пикселей и основной истины изображений сцены в один datastore.

dsTrain = combine(pxdsTrain,imdsTrain);

Увеличение обучающих данных при помощи transform функция с пользовательскими операциями предварительной обработки, заданными функцией helper preprocessCamVidForPix2PixHD. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.

The preprocessCamVidForPix2PixHD функция выполняет следующие операции:

  1. Масштабируйте достоверные данные по области значений [-1, 1]. Эта область значений соответствует области значений конечных tanhLayer в сети генератора.

  2. Измените размер изображения и меток на выходной размер сети, 576 на 768 пикселей, с помощью бикубической и ближайшей соседней соседней понижающей дискретизации, соответственно.

  3. Преобразуйте одноканальную карту сегментации в 32-канальную однокодированную карту сегментации с горячим кодированием с помощью onehotencode функция.

  4. Случайным образом разверните изображение и пиксельные пары меток в горизонтальном направлении.

dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));

Предварительный просмотр каналов однокодированной закодированной карты сегментации в монтаже. Каждый канал представляет 1-горячую карту, соответствующую пикселям уникального класса.

map = preview(dsTrain);
montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')

Создайте сеть генератора

Задайте сеть генератора pix2pixHD, которая генерирует изображение сцены из глубинной однокодированной карты сегментации. Этот вход имеет ту же высоту и ширину, что и исходная карта сегментации, и то же количество каналов, что и классы.

generatorInputSize = [imageSize numClasses];

Создайте сеть генератора pix2pixHD с помощью pix2pixHDGlobalGenerator (Image Processing Toolbox) функция.

dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);

Отображение сетевой архитектуры.

analyzeNetwork(dlnetGenerator)

Обратите внимание, что этот пример показывает использование глобального генератора pix2pixHD для генерации изображений размером 576 на 768 пикселей. Чтобы создать локальные сети энхансеров, которые генерируют изображения с более высоким разрешением, таким как 1152 на 1536 пикселей или даже выше, можно использовать addPix2PixHDLocalEnhancer (Image Processing Toolbox) функция. Локальные сети энхансеров помогают генерировать мелкие детали уровня при очень высоких разрешениях.

Создайте сеть дискриминатора

Задайте закрашенную фигуру сети дискриминатора GAN, которые классифицируют вход изображение как реальное (1) или поддельное (0). Этот пример использует две сети дискриминаторов в разных входных масштабах, также известных как многомасштабные дискриминаторы шкалы. Первая шкала совпадает с размером изображения, а вторая шкала равна половине размера изображения.

Вход дискриминатора является глубинной конкатенацией однокодированных закодированных карт сегментации и изображения сцены, которое будет классифицировано. Укажите количество каналов, вводимых в дискриминатор, как общее количество маркированных классов и цветовых каналов изображений.

numImageChannels = 3;
numChannelsDiscriminator = numClasses + numImageChannels;

Задайте размер входа первого дискриминатора. Создайте закрашенную фигуру GAN с нормализацией образца с помощью patchGANDiscriminator (Image Processing Toolbox) функция.

discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator];
dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,"NormalizationLayer","instance");

Укажите размер входа второго дискриминатора как половину размера изображения, затем создайте вторую закрашенную фигуру GAN.

discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator];
dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,"NormalizationLayer","instance");

Визуализация сетей.

analyzeNetwork(dlnetDiscriminatorScale1);
analyzeNetwork(dlnetDiscriminatorScale2);

Задайте градиенты модели и функции потерь

Функция помощника modelGradients вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери соответствия признаков и потери VGG для генератора. Эта функция определяется в разделе Вспомогательные функции этого примера.

Потери генератора

Цель генератора состоит в том, чтобы сгенерировать изображения, которые дискриминатор классифицирует как действительные (1). Потеря генератора состоит из трех потерь.

  • Состязательные потери вычисляются как квадратное различие между вектором таковых и предсказаниями дискриминатора на сгенерированном изображении. Yˆgenerated являются предсказаниями дискриминатора на изображении, сгенерированном генератором. Эта потеря реализована с использованием части pix2pixhdAdversarialLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.

lossAdversarialGenerator=(1-Yˆgenerated)2

  • Функция, совпадающая с потерей, штрафует L1 расстояние между действительной и сгенерированной картами функций, полученное как предсказания от сети дискриминатора. T - общее количество слоев функций дискриминатора. Yreal и Yˆgenerated являются основные истины изображениями и сгенерированными изображениями, соответственно. Эта потеря реализована с помощью pix2pixhdFeatureMatchingLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера

lossFeatureMatching=i=1T||Yreal-Yˆgenerated||1

  • Восприятие потери наказывает L1 расстояние между реальной и сгенерированной картами функций, полученное как предсказания от сети редукции данных. T - общее количество слоев функций. YVggReal и YˆVggGenerated являются сетевыми прогнозами для основной истины изображений и сгенерированных изображений, соответственно. Эта потеря реализована с помощью pix2pixhdVggLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера. Сеть редукции данных создается в окне «Загрузка сети редукции данных».

lossVgg=i=1T||YVggReal-YˆVggGenerated||1

Общие потери генератора являются взвешенной суммой всех трех потерь. λ1, λ2, и λ3 являются коэффициентами веса для состязательной потери, потери соответствия функций и потери восприятия, соответственно.

lossGenerator=λ1*lossAdversarialGenerator+λ2*lossFeatureMatching+λ3*lossPerceptual

Обратите внимание, что состязательные потери и потери соответствия функций для генератора вычисляются для двух различных шкал.

Потеря дискриминатора

Цель дискриминатора состоит в том, чтобы правильно различать основную истину изображения и сгенерированные изображения. Потеря дискриминатора - это сумма двух компонентов:

  • Квадратное различие между вектором таковых и предсказаниями дискриминатора на вещественных изображениях

  • Квадратное различие между нулевым вектором и предсказаниями дискриминатора на сгенерированных изображениях

lossDiscriminator=(1-Yreal)2+(0-Yˆgenerated)2

Потеря дискриминатора реализована с помощью части pix2pixhdAdversarialLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера. Обратите внимание, что состязательные потери для дискриминатора вычисляются для двух различных шкал дискриминатора.

Загрузка сети редукции данных

Этот пример модифицирует предварительно обученную VGG-19 глубокую нейронную сеть, чтобы извлечь функции реальных и сгенерированных изображений в различных слоях. Эти многослойные функции используются, чтобы вычислить восприятие потери генератора.

Чтобы получить предварительно обученную VGG-19 сеть, установите vgg19. Если у вас нет установленных необходимых пакетов поддержки, то программное обеспечение предоставляет ссылку на загрузку.

netVGG = vgg19;

Визуализируйте сетевую архитектуру с помощью приложения Deep Network Designer.

deepNetworkDesigner(netVGG)

Чтобы сделать VGG-19 сеть подходящей для редукции данных, сохраните слои до 'pool5' и удалите все полносвязные слои из сети. Получившаяся сеть является полностью сверточной сетью.

netVGG = layerGraph(netVGG.Layers(1:38));

Создайте новый входной слой изображения без нормализации. Замените оригинальное изображение входа слой новым слоем.

inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input");
netVGG = replaceLayer(netVGG,"input",inp);
netVGG = dlnetwork(netVGG);

Настройка опций обучения

Задайте опции для оптимизации Adam. Обучайте на 60 эпох. Задайте одинаковые опции для сетей генератора и дискриминатора.

  • Задайте равную скорость обучения 0,0002.

  • Инициализируйте конечный средний градиент и конечный средний градиент-квадратные скорости распада с [].

  • Используйте коэффициент градиентного распада 0,5 и квадратный коэффициент градиентного распада 0,999.

  • Используйте мини-пакет размером 1 для обучения.

numEpochs = 60;
learningRate = 0.0002;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminatorScale1 = [];
trailingAvgSqDiscriminatorScale1 = [];
trailingAvgDiscriminatorScale2 = [];
trailingAvgSqDiscriminatorScale2 = [];
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
miniBatchSize = 1;

Создайте minibatchqueue объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue объект также переводит данные в dlarray объект, который позволяет проводить автоматическую дифференциацию в применениях глубокого обучения.

Задайте формат извлечения данных пакета следующим SSCB (пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground Аргумент пары "имя-значение" как логическое значение, возвращаемое canUseGPU. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ...
   "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);

Обучите сеть

По умолчанию пример загружает предварительно обученную версию сети генератора pix2pixHD для набора данных CamVid с помощью функции helper downloadTrainedPix2PixHDNet. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.

Чтобы обучить сеть, установите doTraining переменная в следующем коде, для true. Обучите модель в пользовательском цикле обучения. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval функции и modelGradients вспомогательная функция.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

Обучение занимает около 22 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от вашего графического процессора оборудования. Если у ваш графический процессор меньше памяти, попробуйте уменьшить размер входных изображений, задав imageSize переменная как [480 640] в разделе «Предварительная обработка обучающих данных» примера.

doTraining = false;
if doTraining
    fig = figure;    
    
    lossPlotter = configureTrainingProgressPlotter(fig);
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Reset and shuffle the data
        reset(mbqTrain);
        shuffle(mbqTrain);
 
        % Loop over each image
        while hasdata(mbqTrain)
            iteration = iteration + 1;
            
            % Read data from current mini-batch
            [dlInputSegMap,dlRealImage] = next(mbqTrain);
            
            % Evaluate the model gradients and the generator state using
            % dlfeval and the GANLoss function listed at the end of the
            % example
            [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ...
                @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG);
            
            % Update the generator parameters
            [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ...
                dlnetGenerator,gradParamsG, ...
                trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale1 parameters
            [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ...
                dlnetDiscriminatorScale1,gradParamsDScale1, ...
                trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale2 parameters
            [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ...
                dlnetDiscriminatorScale2,gradParamsDScale2, ...
                trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Plot and display various losses
            lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ...
                epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG);
        end
    end
    save('trainedPix2PixHDNet.mat','dlnetGenerator');
    
else    
    trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip';
    netDir = fullfile(tempdir,'CamVid');
    downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir);
    load(fullfile(netDir,'trainedPix2PixHDv2.mat'));
end

Оценка сгенерированных изображений из тестовых данных

Эффективность этой обученной Pix2PixHD сети ограничена, потому что количество обучающих изображений CamVid относительно мало. Кроме того, некоторые изображения относятся к последовательности изображений и, следовательно, коррелируют с другими изображениями в набор обучающих данных. Чтобы улучшить эффективность Pix2PixHD сети, обучите сеть с помощью другого набора данных, который имеет большее количество обучающих изображений без корреляции.

Из-за ограничений эта Pix2PixHD сеть генерирует более реалистичные изображения для одних тестовых изображений, чем для других. Чтобы продемонстрировать различие в результатах, сравните сгенерированные изображения для первого и третьего тестового изображения. Угол камеры первого тестового изображения имеет необычную точку расположения, которая обращена более перпендикулярно дороге, чем типовое обучающее изображение. В противоположность этому угол камеры третьего тестового изображения имеет типовую точку обзора, которая обращена вдоль дороги и показывает две полосы с маркерами маршрута. Сеть имеет значительно лучшую эффективность, генерируя реалистичное изображение для третьего тестового изображения, чем для первого тестового изображения.

Получите первую основную истину изображение сцены из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.

idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

Получите соответствующее изображение метки пикселя из тестовых данных. Измените размер изображения метки пикселя с помощью самой близкой соседней интерполяции.

segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");

Преобразуйте изображение метки пикселя в многоканальную одногретую карту сегментации при помощи onehotencode функция.

segMapOneHot = onehotencode(segMap,3,'single');

Создание dlarray объекты, которые вводят данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray объект.

dlSegMap = dlarray(segMapOneHot,'SSCB'); 
if canUseGPU
    dlSegMap = gpuArray(dlSegMap);
end

Сгенерируйте изображение сцены из генератора и одногретую карту сегментации с помощью predict функция.

dlGeneratedImage = predict(dlnetGenerator,dlSegMap);
generatedImage = extractdata(gather(dlGeneratedImage));

Конечный слой сети генератора производит активации в области значений [-1, 1]. Для отображения измените значения активации на область значений [0, 1].

generatedImage = rescale(generatedImage);

Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb (Image Processing Toolbox) функция.

coloredSegMap = label2rgb(segMap,cmap);

Отобразите изображение метки пикселя RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.

figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])

Получите третью основную истину изображение сцены из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.

idxToTest = 3;  
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

Чтобы получить третье изображение метки пикселя из тестовых данных и сгенерировать соответствующее изображение сцены, можно использовать функцию helper evaluatePix2PixHD. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.

The evaluatePix2PixHD функция выполняет те же операции, что и оценка первого тестового изображения:

  • Получите изображение метки пикселя из тестовых данных. Измените размер изображения метки пикселя с помощью самой близкой соседней интерполяции.

  • Преобразуйте изображение метки пикселя в многоканальную одногретую карту сегментации с помощью onehotencode функция.

  • Создайте dlarray объект для входных данных в генератор. Для вывода графический процессор преобразуйте данные в gpuArray объект.

  • Сгенерируйте изображение сцены из генератора и одногретую карту сегментации с помощью predict функция.

  • Переопределите значения активации в области значений [0, 1].

[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);

Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb (Image Processing Toolbox) функция.

coloredSegMap = label2rgb(segMap,cmap);

Отобразите изображение метки пикселя RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.

figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])

Оценка сгенерированных изображений из пользовательских изображений меток пикселей

Чтобы оценить, насколько хорошо сеть обобщает изображения меток пикселей за пределами набора данных CamVid, сгенерируйте изображения сцены из пользовательских изображений меток пикселей. Этот пример использует изображения меток пикселей, которые были созданы с помощью приложения Image Labeler (Computer Vision Toolbox). Изображения меток пикселей присоединены к примеру как вспомогательные файлы. Основные истины отсутствуют.

Создайте datastore метки пикселя, который читает и обрабатывает изображения метки пикселя в текущем примере директории.

cpxds = pixelLabelDatastore(pwd,classes,labelIDs);

Для каждого изображения метки пикселя в datastore сгенерируйте изображение сцены с помощью функции helper evaluatePix2PixHD.

for idx = 1:length(cpxds.Files)

    % Get the pixel label image and generated scene image
    [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator);
    
    % For display, convert the labels from categorical labels to RGB colors
    coloredSegMap = label2rgb(segMap);
    
    % Display the pixel label image and generated scene image in a montage
    figure
    montage({coloredSegMap generatedImage})
    title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image'])

end

Вспомогательные функции

Функция градиентов модели

The modelGradients Функция helper вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери соответствия признаков и потери VGG для генератора.

function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG)
              
    % Compute the image generated by the generator given the input semantic
    % map.
    generatedImage = forward(generator,inputSegMap);
    
    % Define the loss weights
    lambdaDiscriminator = 1;
    lambdaGenerator = 1;
    lambdaFeatureMatching = 5;
    lambdaVGG = 5;
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,inputSegMap,realImage);
    inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for first scale.
    [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1);
    
    % Scale the generated image, the real image, and the input semantic map to
    % half size
    resizedRealImage = dlresize(realImage, 'Scale',0.5, 'Method',"linear");
    resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear");
    resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest");
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage);
    inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for second scale.
    [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2);
    
    % Compute the feature matching loss for first scale.
    FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G);
    FMLossScale1 = FMLossScale1 * lambdaFeatureMatching;
    
    % Compute the feature matching loss for second scale.
    FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G);
    FMLossScale2 = FMLossScale2 * lambdaFeatureMatching;
    
    % Compute the VGG loss
    VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG);
    VGGLoss = VGGLoss * lambdaVGG;
    
    % Compute the combined generator loss
    lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss;
    lossGCombined = lossGCombined * lambdaGenerator;
    
    % Compute gradients for the generator
    gradParamsG = dlgradient(lossGCombined,generator.Learnables,'RetainData',true);
    
    % Compute the combined discriminator loss
    lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator;
    
    % Compute gradients for the discriminator scale1
    gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,'RetainData',true);
    
    % Compute gradients for the discriminator scale2
    gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables);
    
    % Log the values for displaying later
    lossD = gather(extractdata(lossDCombined));
    lossGGAN = gather(extractdata(GLossScale1 + GLossScale2));
    lossGFM  = gather(extractdata(FMLossScale1 + FMLossScale2));
    lossGVGG = gather(extractdata(VGGLoss));
end

Функция состязательных потерь

Функция помощника pix2pixHDAdverserialLoss вычисляет градиенты состязательных потерь для генератора и дискриминатора. Функция также возвращает карты признаков реального изображения и синтетические изображения.

function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator)

    % Discriminator layer names containing feature maps
    featureNames = {'act_top','act_mid_1','act_mid_2','act_tail','conv2d_final'};
    
    % Get the feature maps for the real image from the discriminator    
    realPredFtrsD = cell(size(featureNames));
    [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames);
    
    % Get the feature maps for the generated image from the discriminator    
    genPredFtrsD = cell(size(featureNames));
    [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames);
    
    % Get the feature map from the final layer to compute the loss
    realPredD = realPredFtrsD{end};
    genPredD = genPredFtrsD{end};
    
    % Compute the discriminator loss
    DLoss = (1 - realPredD).^2 + (genPredD).^2;
    DLoss = mean(DLoss,"all");
    
    % Compute the generator loss
    GLoss = (1 - genPredD).^2;
    GLoss = mean(GLoss,"all");
end

Функция соответствия признаков потерь

Функция помощника pix2pixHDFeatureMatchingLoss вычисляет потери соответствия функции между реальным изображением и синтетическим изображением, сгенерированным генератором.

function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs)

    % Number of features
    numFtrsMaps = numel(realPredFtrs);
    
    % Initialize the feature matching loss
    featureMatchingLoss = 0;
    
    for i = 1:numFtrsMaps
        % Get the feature maps of the real image
        a = extractdata(realPredFtrs{i});
        % Get the feature maps of the synthetic image
        b = genPredFtrs{i};
        
        % Compute the feature matching loss
        featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all");
    end
end

Перцептивная функция потерь VGG

Функция помощника pix2pixHDVGGLoss вычисляет восприятие потерь VGG между реальным изображением и синтетическим изображением, сгенерированным генератором.

function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG)

    featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0];
    
    % Initialize the VGG loss
    vggLoss = 0;
    
    % Specify the names of the layers with desired feature maps
    featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"];
    
    % Extract the feature maps for the real image
    activReal = cell(size(featureNames));
    [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames);
    
    % Extract the feature maps for the synthetic image
    activGenerated = cell(size(featureNames));
    [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames);
    
    % Compute the VGG loss
    for i = 1:numel(featureNames)
        vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all");
    end
end

Ссылки

[1] Ван, Тин-Чун, Мин-Ю Лю, Цзюнь-Янь Чжу, Эндрю Тао, Ян Каутц и Брайан Катандзаро. «Синтез изображений в высоком разрешении и семантическая манипуляция с условными GAN». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 8798-8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.

[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Распознавание Букв. Том 30, Выпуск 2, 2009, стр. 88-97.

См. также

| | | | | | | | | (Computer Vision Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте