В этом примере показано, как генерировать синтетическое изображение сцены из семантической карты сегментации с использованием pix2pixHD условной генеративной состязательной сети (CGAN).
Pix2pixHD [1] состоит из двух сетей, которые обучаются одновременно, чтобы максимально повысить производительность обеих сетей.
Генератор представляет собой нейронную сеть в стиле кодер-декодер, которая генерирует изображение сцены из семантической карты сегментации. Сеть CGAN обучает генератор генерировать изображение сцены, которое дискриминатор неправильно классифицирует как действительное.
Дискриминатор представляет собой полностью сверточную нейронную сеть, которая сравнивает сгенерированное изображение сцены и соответствующее реальное изображение и пытается классифицировать их как фальшивые и реальные соответственно. Сеть CGAN обучает дискриминатор правильно различать сгенерированное и реальное изображение.
Генераторные и дискриминаторные сети конкурируют друг с другом во время обучения. Обучение сходится, когда ни одна из сетей не может улучшиться.
В этом примере для Набор данных CamVid2обучения используется [] Кембриджского университета. Этот набор данных представляет собой набор из 701 изображений, содержащих виды на уровне улиц, полученные во время вождения. Набор данных предоставляет пиксельные метки для 32 семантических классов, включая автомобильный, пешеходный и дорожный.
Загрузите набор данных CamVid с этих URL-адресов. Время загрузки зависит от вашего подключения к Интернету.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidData(dataDir,imageURL,labelURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full"); labelDir = fullfile(dataDir,'labels');
Создание imageDatastore для сохранения изображений в наборе данных CamVid.
imds = imageDatastore(imgDir); imageSize = [576 768];
Определите имена классов и идентификаторы меток пикселей 32 классов в наборе данных CamVid с помощью вспомогательной функции defineCamVid32ClassesAndPixelLabelIDs. Получение стандартной карты цветов для набора данных CamVid с помощью вспомогательной функции camvid32ColorMap. Вспомогательные функции прикрепляются к примеру как вспомогательные файлы.
numClasses = 32; [classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs; cmap = camvid32ColorMap;
Создать pixelLabelDatastore для хранения изображений меток пикселей.
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Предварительный просмотр изображения метки пикселя и соответствующего изображения сцены истинности земли. Преобразование меток из категориальных меток в цвета RGB с помощью label2rgb затем отображает изображение метки пикселя и изображение истинности земли в монтаже.
im = preview(imds);
px = preview(pxds);
px = label2rgb(px,cmap);
montage({px,im})
Разбиение данных на обучающие и тестовые наборы с помощью функции помощника partitionCamVidForPix2PixHD. Эта функция присоединена к примеру как вспомогательный файл. Вспомогательная функция разбивает данные на 648 учебных файлов и 32 тестовых файла.
[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);
Используйте combine функция для объединения изображений меток пикселей и фоновых изображений сцены истинности в одном хранилище данных.
dsTrain = combine(pxdsTrain,imdsTrain);
Дополните учебные данные с помощью transform функция с пользовательскими операциями предварительной обработки, заданными функцией помощника preprocessCamVidForPix2PixHD. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.
preprocessCamVidForPix2PixHD функция выполняет следующие операции:
Масштабируйте истинные данные земли до диапазона [-1, 1]. Этот диапазон соответствует диапазону конечного tanhLayer (Deep Learning Toolbox) в сети генератора.
Измените размер изображения и меток до выходного размера сети, 576 на 768 пикселей, используя двойную и ближайшую соседнюю понижающую дискретизацию, соответственно.
Преобразование одноканальной карты сегментации в 32-канальную одноканальную кодированную карту сегментации с использованием onehotencode (Deep Learning Toolbox).
Случайное отражение пар изображений и меток пикселей в горизонтальном направлении.
dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));
Предварительный просмотр каналов однокодированной карты сегментации в монтаже. Каждый канал представляет одноконтурную карту, соответствующую пикселям уникального класса.
map = preview(dsTrain);
montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')
Определите сеть генератора pix2pixHD, которая генерирует изображение сцены из карты сегментации с одноступенчатой кодировкой. Этот вход имеет ту же высоту и ширину, что и исходная карта сегментации, и то же количество каналов, что и классы.
generatorInputSize = [imageSize numClasses];
Создайте сеть генератора pix2pixHD с помощью pix2pixHDGlobalGenerator функция.
dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);
Отображение сетевой архитектуры.
analyzeNetwork(dlnetGenerator)

Следует отметить, что в этом примере показано использование глобального генератора pix2pixHD для генерации изображений размером 576 на 768 пикселей. Чтобы создать локальные сети усилителей, которые генерируют изображения с более высоким разрешением, например 1152 на 1536 пикселей или даже выше, можно использовать addPix2PixHDLocalEnhancer функция. Локальные сети усилителей помогают генерировать тонкие детали уровня при очень высоких разрешениях.
Определите сети дискриминаторов GAN с исправлениями, которые классифицируют входное изображение как вещественное (1) или фальшивое (0). В этом примере используются две дискриминаторные сети с различными входными масштабами, также известные как многомасштабные дискриминаторы. Первый масштаб совпадает с размером изображения, а второй масштаб равен половине размера изображения.
Входным сигналом для дискриминатора является глубинная конкатенация карт одноступенчатой кодированной сегментации и изображения сцены, подлежащего классификации. Укажите количество каналов, вводимых в дискриминатор, как общее количество помеченных классов и цветовых каналов изображения.
numImageChannels = 3; numChannelsDiscriminator = numClasses + numImageChannels;
Укажите входной размер первого дискриминатора. Создайте дискриминатор GAN исправления с нормализацией экземпляра с помощью patchGANDiscriminator функция.
discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator]; dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,"NormalizationLayer","instance");
Укажите входной размер второго дискриминатора как половину размера изображения, а затем создайте второй дискриминатор GAN патча.
discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator]; dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,"NormalizationLayer","instance");
Визуализация сетей.
analyzeNetwork(dlnetDiscriminatorScale1); analyzeNetwork(dlnetDiscriminatorScale2);
Вспомогательная функция modelGradients вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери согласования функций и потери VGG для генератора. Эта функция определена в разделе «Вспомогательные функции» данного примера.
Целью генератора является формирование изображений, которые дискриминатор классифицирует как вещественные (1). Потери генератора состоят из трех потерь.
Состязательная потеря вычисляется как квадрат разности между вектором единиц и предсказаниями дискриминатора на сгенерированном изображении. являются предсказаниями дискриминатора на изображении, сгенерированном генератором. Эта потеря реализуется с использованием части pix2pixhdAdversarialLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера.
) 2
Потеря согласования признаков штрафует за расстояние между реальными и сгенерированными картами признаков, полученными в виде прогнозов из сети дискриминаторов. - общее число уровней признаков дискриминатора. и - это изображения истины и сгенерированные изображения соответственно. Эта потеря реализуется с помощью pix2pixhdFeatureMatchingLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера
Потеря восприятия штрафует за расстояние между реальными и сгенерированными картами признаков, полученными в виде прогнозов из сети извлечения признаков. - общее количество слоев элементов. и являются сетевыми прогнозами для фоновых изображений истины и сгенерированных изображений, соответственно. Эта потеря реализуется с помощью pix2pixhdVggLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера. Сеть извлечения элементов создается в окне Загрузить сеть извлечения элементов.
Общая потеря генератора представляет собой взвешенную сумму всех трех потерь. , и являются весовыми факторами для состязательной потери, потери совпадения признаков и потери восприятия, соответственно.
lossPerceptual
Обратите внимание, что потери состязательности и потери соответствия элементов для генератора вычисляются для двух различных шкал.
Целью дискриминатора является правильное различие между фоновыми образами истинности и сгенерированными изображениями. Потеря дискриминатора представляет собой сумму двух компонентов:
Квадрат разности между вектором единиц и предсказаниями дискриминатора на реальных изображениях
Квадрат разности между вектором нулей и предсказаниями дискриминатора на сгенерированных изображениях
0-Yˆgenerated) 2
Потеря дискриминатора реализуется с использованием части pix2pixhdAdversarialLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера. Следует отметить, что состязательные потери для дискриминатора вычисляются для двух различных шкал дискриминатора.
В этом примере модифицируется заранее подготовленная VGG-19 глубокая нейронная сеть для извлечения особенностей реальных и сгенерированных изображений на различных уровнях. Эти многослойные элементы используются для вычисления потери восприятия генератора.
Чтобы получить предварительно обученную сеть VGG-19, установите vgg19 (инструментарий глубокого обучения). Если необходимые пакеты поддержки не установлены, программа предоставляет ссылку для загрузки.
netVGG = vgg19;
Визуализация сетевой архитектуры с помощью приложения Deep Network Designer (Deep Learning Toolbox).
deepNetworkDesigner(netVGG)
Чтобы сделать сеть VGG-19 пригодной для извлечения элементов, поддерживайте уровень до 'pool5' и удалите все полностью подключенные слои из сети. Результирующая сеть представляет собой полностью сверточную сеть.
netVGG = layerGraph(netVGG.Layers(1:38));
Создание нового слоя ввода изображения без нормализации. Замените исходный слой ввода изображения новым слоем.
inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input"); netVGG = replaceLayer(netVGG,"input",inp); netVGG = dlnetwork(netVGG);
Укажите параметры оптимизации Adam. Поезд на 60 эпох. Укажите идентичные опции для сетей генератора и дискриминатора.
Укажите равную скорость обучения 0,0002.
Инициализируйте среднюю градиентную скорость и среднюю градиентно-квадратную скорость затухания с помощью [].
Используйте градиентный коэффициент затухания 0,5 и квадрат градиентного коэффициента затухания 0,999.
Для обучения используйте мини-пакет размером 1.
numEpochs = 60; learningRate = 0.0002; trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminatorScale1 = []; trailingAvgSqDiscriminatorScale1 = []; trailingAvgDiscriminatorScale2 = []; trailingAvgSqDiscriminatorScale2 = []; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999; miniBatchSize = 1;
Создать minibatchqueue Объект (Deep Learning Toolbox), который управляет мини-пакетами наблюдений в пользовательском цикле обучения. minibatchqueue объект также передает данные в dlarray Объект (Deep Learning Toolbox), обеспечивающий автоматическую дифференциацию в приложениях для глубокого обучения.
Укажите формат извлечения данных мини-партии как SSCB (пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground аргумент пары имя-значение как логическое значение, возвращаемое canUseGPU. Если поддерживаемый графический процессор доступен для вычисления, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию в примере загружается предварительно подготовленная версия сети генератора pix2pixHD для набора данных CamVid с помощью вспомогательной функции. downloadTrainedPix2PixHDNet. Вспомогательная функция прикрепляется к примеру как вспомогательный файл. Предварительно обученная сеть позволяет выполнять весь пример без ожидания завершения обучения.
Для обучения сети установите doTraining переменная в следующем коде true. Обучение модели в индивидуальном цикле обучения. Для каждой итерации:
Считывание данных для текущего мини-пакета с помощью next (Deep Learning Toolbox).
Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) и modelGradients функция помощника.
Обновление параметров сети с помощью adamupdate (Deep Learning Toolbox).
Обновите график хода обучения для каждой итерации и просмотрите различные вычисленные потери.
Обучение на GPU, если он доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Дополнительные сведения см. в разделе Поддержка графического процессора по выпуску (Панель инструментов параллельных вычислений).
Обучение занимает около 22 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от оборудования графического процессора. Если на устройстве графического процессора меньше памяти, попробуйте уменьшить размер входных изображений, указав imageSize переменная [480 640] в разделе Preprocess Training Data примера.
doTraining = false; if doTraining fig = figure; lossPlotter = configureTrainingProgressPlotter(fig); iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Reset and shuffle the data reset(mbqTrain); shuffle(mbqTrain); % Loop over each image while hasdata(mbqTrain) iteration = iteration + 1; % Read data from current mini-batch [dlInputSegMap,dlRealImage] = next(mbqTrain); % Evaluate the model gradients and the generator state using % dlfeval and the GANLoss function listed at the end of the % example [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ... @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG); % Update the generator parameters [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ... dlnetGenerator,gradParamsG, ... trailingAvgGenerator,trailingAvgSqGenerator,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale1 parameters [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ... dlnetDiscriminatorScale1,gradParamsDScale1, ... trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale2 parameters [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ... dlnetDiscriminatorScale2,gradParamsDScale2, ... trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Plot and display various losses lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ... epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG); end end save('trainedPix2PixHDNet.mat','dlnetGenerator'); else trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip'; netDir = fullfile(tempdir,'CamVid'); downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir); load(fullfile(netDir,'trainedPix2PixHDv2.mat')); end
Производительность этой обучаемой сети Pix2PixHD ограничена, поскольку количество обучающих изображений CamVid относительно невелико. Кроме того, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в обучающем наборе. Чтобы повысить эффективность сети Pix2PixHD, обучайте сеть, используя другой набор данных, который имеет большее количество обучающих изображений без корреляции.
Из-за ограничений эта Pix2PixHD сеть создает более реалистичные изображения для некоторых тестовых изображений, чем для других. Чтобы продемонстрировать разницу в результатах, сравните созданные изображения для первого и третьего тестовых изображений. Угол камеры первого тестового изображения имеет необычную точку обзора, которая обращена более перпендикулярно к дороге, чем обычное тренировочное изображение. Напротив, угол камеры третьего тестового изображения имеет типичную точку обзора, которая обращена вдоль дороги и показывает две полосы с маркерами полос. Сеть имеет значительно более высокую производительность, генерируя реалистичное изображение для третьего тестового изображения, чем для первого тестового изображения.
Получение первого изображения сцены истинности земли из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");Получение соответствующего изображения метки пикселя из тестовых данных. Изменение размера изображения метки пикселя с помощью интерполяции ближайшего соседа.
segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");Преобразование изображения метки пикселя в многоканальную одноканальную карту сегментации с помощью onehotencode (Deep Learning Toolbox).
segMapOneHot = onehotencode(segMap,3,'single');Создать dlarray объекты, которые вводят данные в генератор. Если поддерживаемый графический процессор доступен для вычисления, выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.
dlSegMap = dlarray(segMapOneHot,'SSCB'); if canUseGPU dlSegMap = gpuArray(dlSegMap); end
Создание изображения сцены из генератора и карты разовой сегментации с помощью predict (Deep Learning Toolbox).
dlGeneratedImage = predict(dlnetGenerator,dlSegMap); generatedImage = extractdata(gather(dlGeneratedImage));
Конечный слой генераторной сети создает активизации в диапазоне [-1, 1]. Для отображения измените масштаб активизаций на диапазон [0, 1].
generatedImage = rescale(generatedImage);
Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb функция.
coloredSegMap = label2rgb(segMap,cmap);
Отображение изображения метки пикселя RGB, сформированного изображения сцены и изображения сцены истинности земли в монтаже.
figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Получение третьего изображения сцены истинности земли из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 3;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");Чтобы получить изображение метки третьего пикселя из тестовых данных и создать соответствующее изображение сцены, можно использовать функцию помощника evaluatePix2PixHD. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.
evaluatePix2PixHD функция выполняет те же операции, что и оценка первого тестового изображения:
Получение изображения пиксельной метки из тестовых данных. Изменение размера изображения метки пикселя с помощью интерполяции ближайшего соседа.
Преобразование изображения метки пикселя в многоканальную одноканальную карту сегментации с использованием onehotencode (Deep Learning Toolbox).
Создать dlarray объект для ввода данных в генератор. Для вывода графического процессора преобразуйте данные в gpuArray объект.
Создание изображения сцены из генератора и карты разовой сегментации с помощью predict (Deep Learning Toolbox).
Выполните масштабирование активизаций до диапазона [0, 1].
[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);
Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb функция.
coloredSegMap = label2rgb(segMap,cmap);
Отображение изображения метки пикселя RGB, сформированного изображения сцены и изображения сцены истинности земли в монтаже.
figure
montage({coloredSegMap generatedImage gtImage},'Size',[1 3])
title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Чтобы оценить, насколько хорошо сеть обобщает изображения меток пикселей за пределами набора данных CamVid, создайте изображения сцен из пользовательских изображений меток пикселей. В этом примере используются изображения меток пикселей, созданные с помощью приложения Image Labeler. Изображения меток пикселей прикрепляются к примеру как вспомогательные файлы. Изображения истины на земле не доступны.
Создайте хранилище данных меток пикселей, которое считывает и обрабатывает изображения меток пикселей в текущей папке примера.
cpxds = pixelLabelDatastore(pwd,classes,labelIDs);
Для каждого пиксельного изображения метки в хранилище данных создайте изображение сцены с помощью вспомогательной функции evaluatePix2PixHD.
for idx = 1:length(cpxds.Files) % Get the pixel label image and generated scene image [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator); % For display, convert the labels from categorical labels to RGB colors coloredSegMap = label2rgb(segMap); % Display the pixel label image and generated scene image in a montage figure montage({coloredSegMap generatedImage}) title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image']) end


modelGradients вспомогательная функция вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери согласования функций и потери VGG для генератора.
function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG) % Compute the image generated by the generator given the input semantic % map. generatedImage = forward(generator,inputSegMap); % Define the loss weights lambdaDiscriminator = 1; lambdaGenerator = 1; lambdaFeatureMatching = 5; lambdaVGG = 5; % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,inputSegMap,realImage); inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage); % Compute the adversarial loss for the discriminator and the generator % for first scale. [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1); % Scale the generated image, the real image, and the input semantic map to % half size resizedRealImage = dlresize(realImage, 'Scale',0.5, 'Method',"linear"); resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear"); resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest"); % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage); inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage); % Compute the adversarial loss for the discriminator and the generator % for second scale. [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2); % Compute the feature matching loss for first scale. FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G); FMLossScale1 = FMLossScale1 * lambdaFeatureMatching; % Compute the feature matching loss for second scale. FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G); FMLossScale2 = FMLossScale2 * lambdaFeatureMatching; % Compute the VGG loss VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG); VGGLoss = VGGLoss * lambdaVGG; % Compute the combined generator loss lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss; lossGCombined = lossGCombined * lambdaGenerator; % Compute gradients for the generator gradParamsG = dlgradient(lossGCombined,generator.Learnables,'RetainData',true); % Compute the combined discriminator loss lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator; % Compute gradients for the discriminator scale1 gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,'RetainData',true); % Compute gradients for the discriminator scale2 gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables); % Log the values for displaying later lossD = gather(extractdata(lossDCombined)); lossGGAN = gather(extractdata(GLossScale1 + GLossScale2)); lossGFM = gather(extractdata(FMLossScale1 + FMLossScale2)); lossGVGG = gather(extractdata(VGGLoss)); end
Вспомогательная функция pix2pixHDAdverserialLoss вычисляет градиенты состязательных потерь для генератора и дискриминатора. Функция также возвращает карты характеристик реального изображения и синтетических изображений.
function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator) % Discriminator layer names containing feature maps featureNames = {'act_top','act_mid_1','act_mid_2','act_tail','conv2d_final'}; % Get the feature maps for the real image from the discriminator realPredFtrsD = cell(size(featureNames)); [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames); % Get the feature maps for the generated image from the discriminator genPredFtrsD = cell(size(featureNames)); [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames); % Get the feature map from the final layer to compute the loss realPredD = realPredFtrsD{end}; genPredD = genPredFtrsD{end}; % Compute the discriminator loss DLoss = (1 - realPredD).^2 + (genPredD).^2; DLoss = mean(DLoss,"all"); % Compute the generator loss GLoss = (1 - genPredD).^2; GLoss = mean(GLoss,"all"); end
Вспомогательная функция pix2pixHDFeatureMatchingLoss вычисляет потери согласования признаков между реальным изображением и синтетическим изображением, генерируемым генератором.
function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs) % Number of features numFtrsMaps = numel(realPredFtrs); % Initialize the feature matching loss featureMatchingLoss = 0; for i = 1:numFtrsMaps % Get the feature maps of the real image a = extractdata(realPredFtrs{i}); % Get the feature maps of the synthetic image b = genPredFtrs{i}; % Compute the feature matching loss featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all"); end end
Вспомогательная функция pix2pixHDVGGLoss вычисляет перцептивные потери VGG между реальным изображением и синтетическим изображением, генерируемым генератором.
function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG) featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0]; % Initialize the VGG loss vggLoss = 0; % Specify the names of the layers with desired feature maps featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"]; % Extract the feature maps for the real image activReal = cell(size(featureNames)); [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames); % Extract the feature maps for the synthetic image activGenerated = cell(size(featureNames)); [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames); % Compute the VGG loss for i = 1:numel(featureNames) vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all"); end end
[1] Ван, Тин-Чунь, Мин-Ю Лю, Цзюнь-Янь Чжу, Эндрю Тао, Ян Каутц и Брайан Катандзаро. «Синтез изображений высокого разрешения и семантическая манипуляция с условными GAN». В 2018 году Конференция IEEE/CVF по компьютерному зрению и распознаванию образов, 8798-8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.
[2] Бростоу, Габриэль Дж., Жюльен Фокер и Роберто Чиполла. «Классы семантических объектов в видео: база данных истинности земли высокой четкости». Буквы распознавания образов. Том 30, выпуск 2, 2009, стр. 88-97.
combine | imageDatastore | pixelLabelDatastore | transform | trainingOptions (инструментарий для глубокого обучения) | trainNetwork (инструментарий для глубокого обучения) | vgg19 (инструментарий для глубокого обучения)