Сгенерируйте изображение из карты сегментации Используя глубокое обучение

В этом примере показано, как сгенерировать синтетическое изображение сцены из карты семантической сегментации с помощью pix2pixHD условной порождающей соперничающей сети (CGAN).

Pix2pixHD [1] состоит из двух сетей, которые обучены одновременно, чтобы максимизировать эффективность обоих.

  1. Генератор является нейронной сетью стиля декодера энкодера, которая генерирует изображение сцены из карты семантической сегментации. Сеть CGAN обучает генератор генерировать изображение сцены, которое различитель неправильно классифицирует как действительное.

  2. Различитель полностью сверточная нейронная сеть, которая сравнивает сгенерированное изображение сцены и соответствующее действительное изображение и пытается классифицировать их как фальшивку и действительный, соответственно. Сеть CGAN обучает различитель правильно различать сгенерированное и действительное изображение.

Генератор и сети различителя конкурируют друг против друга во время обучения. Обучение сходится, когда никакая сеть не может улучшиться далее.

Загрузите набор данных CamVid

Этот пример использует Набор данных CamVid 2[] из Кембриджского университета для обучения. Этот набор данных является набором 701 изображения, содержащего представления уличного уровня, полученные при управлении. Набор данных обеспечивает пиксельные метки для 32 семантических классов включая автомобиль, пешехода и дорогу.

Загрузите набор данных CamVid с этих URL. Время загрузки зависит от вашего интернет-соединения.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

dataDir = fullfile(tempdir,'CamVid'); 
downloadCamVidData(dataDir,imageURL,labelURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
labelDir = fullfile(dataDir,'labels');

Предварительно обработайте обучающие данные

Создайте imageDatastore сохранить изображения в наборе данных CamVid.

imds = imageDatastore(imgDir);
imageSize = [576 768];

Задайте имена классов и пиксельную метку IDs этих 32 классов в наборе данных CamVid с помощью функции помощника defineCamVid32ClassesAndPixelLabelIDs. Получите стандартную палитру для набора данных CamVid с помощью функции помощника camvid32ColorMap. Функции помощника присоединены к примеру как к вспомогательным файлам.

numClasses = 32;
[classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs;
cmap = camvid32ColorMap;

Создайте pixelLabelDatastore (Computer Vision Toolbox), чтобы сохранить пиксель помечает изображения.

pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

Предварительно просмотрите пиксельное изображение метки и соответствующее изображение сцены основной истины. Преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb (Image Processing Toolbox) функция, затем отобразите пиксельное изображение метки и изображение основной истины в монтаже.

im = preview(imds);
px = preview(pxds);
px = label2rgb(px,cmap);
montage({px,im})

Разделите данные в наборы обучающих данных и наборы тестов с помощью функции помощника partitionCamVidForPix2PixHDЭта функция присоединена к примеру как вспомогательный файл. Функция помощника разделяет данные в 648 учебных файлов и 32 тестовых файла.

[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);

Используйте combine функционируйте, чтобы объединить пиксельные изображения метки и изображения сцены основной истины в один datastore.

dsTrain = combine(pxdsTrain,imdsTrain);

Увеличьте обучающие данные при помощи transform функция с пользовательскими операциями предварительной обработки, заданными помощником, функционирует preprocessCamVidForPix2PixHD. Эта функция помощника присоединена к примеру как к вспомогательному файлу.

preprocessCamVidForPix2PixHD функция выполняет эти операции:

  1. Масштабируйте достоверные данные к области значений [-1, 1]. Эта область значений совпадает с областью значений итогового tanhLayer в сети генератора.

  2. Измените размер изображения и меток к выходному размеру сети, 576 768 пиксели, с помощью bicubic и самая близкая соседняя субдискретизация, соответственно.

  3. Преобразуйте одну карту сегментации канала в одногорячую закодированную карту сегментации с 32 каналами с помощью onehotencode функция.

  4. Случайным образом зеркальное изображение и пиксель помечают пары в горизонтальном направлении.

dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));

Предварительно просмотрите каналы одногорячей закодированной карты сегментации в монтаже. Каждый канал представляет одногорячую карту, соответствующую пикселям уникального класса.

map = preview(dsTrain);
montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')

Создайте сеть генератора

Задайте pix2pixHD сеть генератора, которая генерирует изображение сцены из мудрой глубиной одногорячей закодированной карты сегментации. Этот вход имеет ту же высоту и ширину как исходная карта сегментации и то же количество каналов как классы.

generatorInputSize = [imageSize numClasses];

Создайте pix2pixHD сеть генератора использование pix2pixHDGlobalGenerator (Image Processing Toolbox) функция.

dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);

Отобразите сетевую архитектуру.

analyzeNetwork(dlnetGenerator)

Обратите внимание на то, что этот пример показывает использование pix2pixHD глобального генератора для генерации изображений размера 576 768 пиксели. Чтобы создать локальные сети усилителя, которые генерируют изображения в более высоком разрешении такой как 1152 1536 пиксели или еще выше, можно использовать addPix2PixHDLocalEnhancer (Image Processing Toolbox) функция. Локальные сети усилителя помогают сгенерировать прекрасные детали уровня в очень высоких разрешениях.

Создайте сеть различителя

Задайте закрашенную фигуру сети различителя GAN, который классифицирует входное изображение или как действительное (1) или как фальшивка (0). Этот пример использует две сети различителя в различных входных шкалах, также известных как многошкальные различители. Первая шкала одного размера с размером изображения, и вторая шкала является половиной размера размера изображения.

Вход к различителю является мудрой глубиной конкатенацией одногорячих закодированных карт сегментации и изображения сцены, которое будет классифицировано. Задайте количество входа каналов к различителю как общее количество помеченных классов и каналов цвета изображения.

numImageChannels = 3;
numChannelsDiscriminator = numClasses + numImageChannels;

Задайте входной размер первого различителя. Создайте закрашенную фигуру различитель GAN с нормализацией экземпляра с помощью patchGANDiscriminator (Image Processing Toolbox) функция.

discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator];
dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,"NormalizationLayer","instance");

Задайте входной размер второго различителя как половина размера изображения, затем создайте вторую закрашенную фигуру различитель GAN.

discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator];
dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,"NormalizationLayer","instance");

Визуализируйте сети.

analyzeNetwork(dlnetDiscriminatorScale1);
analyzeNetwork(dlnetDiscriminatorScale2);

Градиенты модели Define и функции потерь

Функция помощника modelGradients вычисляет градиенты и соперничающую потерю для генератора и различителя. Функция также вычисляет потерю соответствия функции и потерю VGG для генератора. Эта функция задана в разделе Supporting Functions этого примера.

Потеря генератора

Цель генератора состоит в том, чтобы сгенерировать изображения, которые различитель классифицирует как действительные (1). Потеря генератора состоит из трех потерь.

  • Соперничающая потеря вычисляется как различие в квадрате между вектором из единиц и предсказаниями различителя на сгенерированном изображении. Yˆgenerated предсказания различителя на изображении, сгенерированном генератором. Эта потеря реализована с помощью части pix2pixhdAdversarialLoss функция, определяемая помощника в разделе Supporting Functions этого примера.

lossAdversarialGenerator=(1-Yˆgenerated)2

  • Потеря соответствия функции штрафует L1 расстояние между действительными и сгенерированными картами функции, полученными как предсказания из сети различителя. T общее количество слоев функции различителя. Yreal и Yˆgenerated изображения основной истины и сгенерированные изображения, соответственно. Эта потеря реализована с помощью pix2pixhdFeatureMatchingLoss функция, определяемая помощника в разделе Supporting Functions этого примера

lossFeatureMatching=i=1T||Yreal-Yˆgenerated||1

  • Перцепционная потеря штрафует L1 расстояние между действительными и сгенерированными картами функции, полученными как предсказания из сети извлечения признаков. T общее количество слоев функции. YVggReal и YˆVggGenerated сетевые предсказания для изображений основной истины и сгенерированных изображений, соответственно. Эта потеря реализована с помощью pix2pixhdVggLoss функция, определяемая помощника в разделе Supporting Functions этого примера. Сеть извлечения признаков создается в Сети Извлечения признаков Загрузки.

lossVgg=i=1T||YVggReal-YˆVggGenerated||1

Полная потеря генератора является взвешенной суммой всех трех потерь. λ1, λ2, и λ3 весовые коэффициенты за соперничающую потерю, потерю соответствия функции и перцепционную потерю, соответственно.

lossGenerator=λ1*lossAdversarialGenerator+λ2*lossFeatureMatching+λ3*lossPerceptual

Обратите внимание на то, что соперничающая потеря и потеря соответствия функции для генератора вычисляются для двух различных шкал.

Потеря различителя

Цель различителя состоит в том, чтобы правильно различать изображения основной истины и сгенерированные изображения. Потеря различителя является суммой двух компонентов:

  • Различие в квадрате между вектором из единиц и предсказаниями различителя на действительных изображениях

  • Различие в квадрате между нулевым вектором и предсказаниями различителя на сгенерированных изображениях

lossDiscriminator=(1-Yreal)2+(0-Yˆgenerated)2

Потеря различителя реализована с помощью части pix2pixhdAdversarialLoss функция, определяемая помощника в разделе Supporting Functions этого примера. Обратите внимание на то, что соперничающая потеря для различителя вычисляется для двух различных шкал различителя.

Загрузите сеть извлечения признаков

Этот пример изменяет предварительно обученную глубокую нейронную сеть VGG-19, чтобы извлечь функции действительных и сгенерированных изображений на различных слоях. Эти многоуровневые функции используются для расчета перцепционная потеря генератора.

Чтобы получить предварительно обученную сеть VGG-19, установите vgg19. Если вам не установили необходимые пакеты поддержки, то программное обеспечение обеспечивает ссылку на загрузку.

netVGG = vgg19;

Визуализируйте сетевую архитектуру с помощью приложения Deep Network Designer.

deepNetworkDesigner(netVGG)

Чтобы сделать сеть VGG-19 подходящей для извлечения признаков, сохраните слои до 'pool5' и удалите все полносвязные слоя от сети. Получившаяся сеть является полностью сверточной сетью.

netVGG = layerGraph(netVGG.Layers(1:38));

Создайте новый входной слой изображений без нормализации. Замените входной слой оригинального изображения на новый слой.

inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input");
netVGG = replaceLayer(netVGG,"input",inp);
netVGG = dlnetwork(netVGG);

Задайте опции обучения

Задайте опции для оптимизации Адама. Обучайтесь в течение 60 эпох. Задайте идентичные опции для сетей различителя и генератора.

  • Задайте равную скорость обучения 0,0002.

  • Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [].

  • Используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.

  • Используйте мини-пакетный размер 1 для обучения.

numEpochs = 60;
learningRate = 0.0002;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminatorScale1 = [];
trailingAvgSqDiscriminatorScale1 = [];
trailingAvgDiscriminatorScale2 = [];
trailingAvgSqDiscriminatorScale2 = [];
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
miniBatchSize = 1;

Создайте minibatchqueue объект, который справляется с мини-пакетной обработкой наблюдений в пользовательском учебном цикле. minibatchqueue возразите также бросает данные к dlarray объект, который включает автоматическое дифференцирование в применении глубокого обучения.

Задайте мини-пакетный формат экстракции данных как SSCB (пространственный, пространственный, канал, пакет). Установите DispatchInBackground аргумент пары "имя-значение" как boolean, возвращенный canUseGPU. Если поддерживаемый графический процессор доступен для расчета, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ...
   "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);

Обучите сеть

По умолчанию пример загружает предварительно обученную версию pix2pixHD сети генератора для набора данных CamVid при помощи функции помощника downloadTrainedPix2PixHDNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.

Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval функционируйте и modelGradients функция помощника.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

Обучение занимает приблизительно 22 часа на Титане NVIDIA™ RTX и может взять еще дольше в зависимости от вашего оборудования графического процессора. Если ваше устройство графического процессора имеет меньше памяти, попытайтесь уменьшать размер входных изображений путем определения imageSize переменная как [480 640] в разделе Preprocess Training Data примера.

doTraining = false;
if doTraining
    fig = figure;    
    
    lossPlotter = configureTrainingProgressPlotter(fig);
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Reset and shuffle the data
        reset(mbqTrain);
        shuffle(mbqTrain);
 
        % Loop over each image
        while hasdata(mbqTrain)
            iteration = iteration + 1;
            
            % Read data from current mini-batch
            [dlInputSegMap,dlRealImage] = next(mbqTrain);
            
            % Evaluate the model gradients and the generator state using
            % dlfeval and the GANLoss function listed at the end of the
            % example
            [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ...
                @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG);
            
            % Update the generator parameters
            [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ...
                dlnetGenerator,gradParamsG, ...
                trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale1 parameters
            [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ...
                dlnetDiscriminatorScale1,gradParamsDScale1, ...
                trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale2 parameters
            [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ...
                dlnetDiscriminatorScale2,gradParamsDScale2, ...
                trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Plot and display various losses
            lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ...
                epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG);
        end
    end
    save('trainedPix2PixHDNet.mat','dlnetGenerator');
    
else    
    trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip';
    netDir = fullfile(tempdir,'CamVid');
    downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir);
    load(fullfile(netDir,'trainedPix2PixHDv2.mat'));
end

Оцените сгенерированные изображения от тестовых данных

Эффективность этой обученной сети Pix2PixHD ограничивается, потому что количество изображений обучения CamVid относительно мало. Кроме того, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе обучающих данных. Чтобы улучшить эффективность сети Pix2PixHD, обучите сеть с помощью различного набора данных, который имеет большее число учебных изображений без корреляции.

Из-за ограничений эта сеть Pix2PixHD генерирует более реалистические изображения для некоторых тестовых изображений, чем для других. Чтобы продемонстрировать различие в результатах, сравните сгенерированные изображения для первого и третьего тестового изображения. Угол камеры первого тестового изображения имеет редкую точку наблюдения, которая стоит перед большим количеством перпендикуляра к дороге, чем типичное учебное изображение. В отличие от этого угол камеры третьего тестового изображения имеет типичную точку наблюдения, которая стоит вдоль дороги и показывает два маршрута с маркерами маршрута. Сеть имеет значительно лучшую эффективность, генерирующую реалистическое изображение для третьего тестового изображения, чем для первого тестового изображения.

Получите первое изображение сцены основной истины от тестовых данных. Измените размер изображения с помощью бикубической интерполяции.

idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

Получите соответствующее пиксельное изображение метки от тестовых данных. Измените размер пиксельного изображения метки с помощью самой близкой соседней интерполяции.

segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");

Преобразуйте пиксельное изображение метки в многоканальную одногорячую карту сегментации при помощи onehotencode функция.

segMapOneHot = onehotencode(segMap,3,'single');

Создайте dlarray объекты, который вводит данные к генератору. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.

dlSegMap = dlarray(segMapOneHot,'SSCB'); 
if canUseGPU
    dlSegMap = gpuArray(dlSegMap);
end

Сгенерируйте изображение сцены от генератора и одногорячей карты сегментации с помощью predict функция.

dlGeneratedImage = predict(dlnetGenerator,dlSegMap);
generatedImage = extractdata(gather(dlGeneratedImage));

Последний слой сети генератора производит активации в области значений [-1, 1]. Для отображения перемасштабируйте активации к области значений [0, 1].

generatedImage = rescale(generatedImage);

Для отображения преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb (Image Processing Toolbox) функция.

coloredSegMap = label2rgb(segMap,cmap);

Отобразите пиксельное изображение метки RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.

figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])

Получите третье изображение сцены основной истины от тестовых данных. Измените размер изображения с помощью бикубической интерполяции.

idxToTest = 3;  
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

Чтобы получить третий пиксель помечают изображение от тестовых данных и сгенерировать соответствующее изображение сцены, можно использовать функцию помощника evaluatePix2PixHD. Эта функция помощника присоединена к примеру как к вспомогательному файлу.

evaluatePix2PixHD функция выполняет те же операции как оценка первого тестового изображения:

  • Получите пиксельное изображение метки от тестовых данных. Измените размер пиксельного изображения метки с помощью самой близкой соседней интерполяции.

  • Преобразуйте пиксельное изображение метки в многоканальную одногорячую карту сегментации с помощью onehotencode функция.

  • Создайте dlarray возразите против входных данных к генератору. Для вывода графического процессора преобразуйте данные в gpuArray объект.

  • Сгенерируйте изображение сцены от генератора и одногорячей карты сегментации с помощью predict функция.

  • Перемасштабируйте активации к области значений [0, 1].

[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);

Для отображения преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb (Image Processing Toolbox) функция.

coloredSegMap = label2rgb(segMap,cmap);

Отобразите пиксельное изображение метки RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.

figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])

Оцените сгенерированные изображения от пользовательских пиксельных изображений метки

Чтобы оценить, как хорошо сеть делает вывод к пиксельным изображениям метки вне набора данных CamVid, сгенерируйте изображения сцены от пользовательских пиксельных изображений метки. Этот пример использует пиксельные изображения метки, которые были созданы с помощью Image Labeler (Computer Vision Toolbox) приложение. Пиксельные изображения метки присоединены к примеру как к вспомогательным файлам. Никакие изображения основной истины не доступны.

Создайте пиксельный datastore метки, который читает и обрабатывает пиксельные изображения метки в текущей директории в качестве примера.

cpxds = pixelLabelDatastore(pwd,classes,labelIDs);

Для каждого пиксельного изображения метки в datastore сгенерируйте изображение сцены с помощью функции помощника evaluatePix2PixHD.

for idx = 1:length(cpxds.Files)

    % Get the pixel label image and generated scene image
    [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator);
    
    % For display, convert the labels from categorical labels to RGB colors
    coloredSegMap = label2rgb(segMap);
    
    % Display the pixel label image and generated scene image in a montage
    figure
    montage({coloredSegMap generatedImage})
    title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image'])

end

Вспомогательные Функции

Функция градиентов модели

modelGradients функция помощника вычисляет градиенты и соперничающую потерю для генератора и различителя. Функция также вычисляет потерю соответствия функции и потерю VGG для генератора.

function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG)
              
    % Compute the image generated by the generator given the input semantic
    % map.
    generatedImage = forward(generator,inputSegMap);
    
    % Define the loss weights
    lambdaDiscriminator = 1;
    lambdaGenerator = 1;
    lambdaFeatureMatching = 5;
    lambdaVGG = 5;
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,inputSegMap,realImage);
    inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for first scale.
    [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1);
    
    % Scale the generated image, the real image, and the input semantic map to
    % half size
    resizedRealImage = dlresize(realImage, 'Scale',0.5, 'Method',"linear");
    resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear");
    resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest");
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage);
    inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for second scale.
    [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2);
    
    % Compute the feature matching loss for first scale.
    FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G);
    FMLossScale1 = FMLossScale1 * lambdaFeatureMatching;
    
    % Compute the feature matching loss for second scale.
    FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G);
    FMLossScale2 = FMLossScale2 * lambdaFeatureMatching;
    
    % Compute the VGG loss
    VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG);
    VGGLoss = VGGLoss * lambdaVGG;
    
    % Compute the combined generator loss
    lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss;
    lossGCombined = lossGCombined * lambdaGenerator;
    
    % Compute gradients for the generator
    gradParamsG = dlgradient(lossGCombined,generator.Learnables,'RetainData',true);
    
    % Compute the combined discriminator loss
    lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator;
    
    % Compute gradients for the discriminator scale1
    gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,'RetainData',true);
    
    % Compute gradients for the discriminator scale2
    gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables);
    
    % Log the values for displaying later
    lossD = gather(extractdata(lossDCombined));
    lossGGAN = gather(extractdata(GLossScale1 + GLossScale2));
    lossGFM  = gather(extractdata(FMLossScale1 + FMLossScale2));
    lossGVGG = gather(extractdata(VGGLoss));
end

Соперничающая функция потерь

Функция помощника pix2pixHDAdverserialLoss вычисляет соперничающие градиенты потерь для генератора и различителя. Функция также возвращает карты функции действительного изображения и синтетических изображений.

function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator)

    % Discriminator layer names containing feature maps
    featureNames = {'act_top','act_mid_1','act_mid_2','act_tail','conv2d_final'};
    
    % Get the feature maps for the real image from the discriminator    
    realPredFtrsD = cell(size(featureNames));
    [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames);
    
    % Get the feature maps for the generated image from the discriminator    
    genPredFtrsD = cell(size(featureNames));
    [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames);
    
    % Get the feature map from the final layer to compute the loss
    realPredD = realPredFtrsD{end};
    genPredD = genPredFtrsD{end};
    
    % Compute the discriminator loss
    DLoss = (1 - realPredD).^2 + (genPredD).^2;
    DLoss = mean(DLoss,"all");
    
    % Compute the generator loss
    GLoss = (1 - genPredD).^2;
    GLoss = mean(GLoss,"all");
end

Покажите соответствие с функцией потерь

Функция помощника pix2pixHDFeatureMatchingLoss вычисляет потерю соответствия функции между действительным изображением и синтетическим изображением, сгенерированным генератором.

function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs)

    % Number of features
    numFtrsMaps = numel(realPredFtrs);
    
    % Initialize the feature matching loss
    featureMatchingLoss = 0;
    
    for i = 1:numFtrsMaps
        % Get the feature maps of the real image
        a = extractdata(realPredFtrs{i});
        % Get the feature maps of the synthetic image
        b = genPredFtrs{i};
        
        % Compute the feature matching loss
        featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all");
    end
end

Перцепционная функция потерь VGG

Функция помощника pix2pixHDVGGLoss вычисляет перцепционную потерю VGG между действительным изображением и синтетическим изображением, сгенерированным генератором.

function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG)

    featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0];
    
    % Initialize the VGG loss
    vggLoss = 0;
    
    % Specify the names of the layers with desired feature maps
    featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"];
    
    % Extract the feature maps for the real image
    activReal = cell(size(featureNames));
    [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames);
    
    % Extract the feature maps for the synthetic image
    activGenerated = cell(size(featureNames));
    [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames);
    
    % Compute the VGG loss
    for i = 1:numel(featureNames)
        vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all");
    end
end

Ссылки

[1] Ван, Звон-Chun, Мин-Юй Лю, июнь-Yan Чжу, Эндрю Тао, Ян Коц и Брайан Кэйтанзаро. "Синтез изображений с высоким разрешением и Семантическая Манипуляция с Условным GANs". На 2018 Конференциях IEEE/CVF по Компьютерному зрению и Распознаванию образов, 8798–8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.

[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. "Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости". Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.

Смотрите также

| | (Computer Vision Toolbox) | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте