В этом примере показано, как сгенерировать синтетическое изображение сцены из семантической карты сегментации с помощью pix2pixHD условной генеративной состязательной сети (CGAN).
Pix2pixHD [1] состоит из двух сетей, которые обучаются одновременно, чтобы максимизировать эффективность обеих.
Генератор является нейронной сетью в стиле энкодер-декодер, которая генерирует изображение сцены из семантической карты сегментации. Сеть CGAN обучает генератор генерировать изображение сцены, которое дискриминатор неправильно классифицирует как реальное.
Дискриминатор является полностью сверточной нейронной сетью, которая сравнивает сгенерированное изображение сцены и соответствующее реальное изображение и пытается классифицировать их как поддельные и реальные, соответственно. Сеть CGAN обучает дискриминатор правильно различать сгенерированное и реальное изображение.
Сети генератора и дискриминатора конкурируют друг с другом во время обучения. Обучение сходится, когда ни одна из сетей не может улучшиться дальше.
Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных представляет собой набор 701 изображений, содержащих представления уличного уровня, полученные во время вождения. Набор данных обеспечивает пиксельные метки для 32 семантических классов, включая автомобиль, пешехода и дорогу.
Загрузите набор данных CamVid с этих URL-адресов. Время загрузки зависит от вашего подключения к Интернету.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidData(dataDir,imageURL,labelURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full"); labelDir = fullfile(dataDir,'labels');
Создайте imageDatastore
для хранения изображений в наборе данных CamVid.
imds = imageDatastore(imgDir); imageSize = [576 768];
Определите имена классов и идентификаторы меток пикселей 32 классов в наборе данных CamVid с помощью функции helper defineCamVid32ClassesAndPixelLabelIDs
. Получите стандартную карту цветов для набора данных CamVid с помощью функции helper camvid32ColorMap
. Вспомогательные функции присоединены к примеру как вспомогательные файлы.
numClasses = 32; [classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs; cmap = camvid32ColorMap;
Создайте pixelLabelDatastore
для хранения изображений меток пикселей.
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Предварительный просмотр изображения метки пикселя и соответствующего изображения основной истины сцены. Преобразуйте метки из категорийных меток в цвета RGB с помощью label2rgb
, затем отобразите пиксельное изображение метки и основной истины изображение в монтаже.
im = preview(imds); px = preview(pxds); px = label2rgb(px,cmap); montage({px,im})
Разделите данные на обучающие и тестовые наборы с помощью функции helper partitionCamVidForPix2PixHD
. Эта функция присоединена к примеру как вспомогательный файл. Функция helper разделяет данные на 648 обучающих файлов и 32 тестовых файлов.
[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);
Используйте combine
функция для объединения изображений меток пикселей и основной истины изображений сцены в один datastore.
dsTrain = combine(pxdsTrain,imdsTrain);
Увеличение обучающих данных при помощи transform
функция с пользовательскими операциями предварительной обработки, заданными функцией helper preprocessCamVidForPix2PixHD
. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.
The preprocessCamVidForPix2PixHD
функция выполняет следующие операции:
Масштабируйте достоверные данные по области значений [-1, 1]. Эта область значений соответствует области значений конечных tanhLayer
(Deep Learning Toolbox) в сети генератора.
Измените размер изображения и меток на выходной размер сети, 576 на 768 пикселей, с помощью бикубической и ближайшей соседней соседней понижающей дискретизации, соответственно.
Преобразуйте одноканальную карту сегментации в 32-канальную однокодированную карту сегментации с горячим кодированием с помощью onehotencode
(Deep Learning Toolbox) функция.
Случайным образом разверните изображение и пиксельные пары меток в горизонтальном направлении.
dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));
Предварительный просмотр каналов однокодированной закодированной карты сегментации в монтаже. Каждый канал представляет 1-горячую карту, соответствующую пикселям уникального класса.
map = preview(dsTrain); montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')
Задайте сеть генератора pix2pixHD, которая генерирует изображение сцены из глубинной однокодированной карты сегментации. Этот вход имеет ту же высоту и ширину, что и исходная карта сегментации, и то же количество каналов, что и классы.
generatorInputSize = [imageSize numClasses];
Создайте сеть генератора pix2pixHD с помощью pix2pixHDGlobalGenerator
функция.
dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);
Отображение сетевой архитектуры.
analyzeNetwork(dlnetGenerator)
Обратите внимание, что этот пример показывает использование глобального генератора pix2pixHD для генерации изображений размером 576 на 768 пикселей. Чтобы создать локальные сети энхансеров, которые генерируют изображения с более высоким разрешением, таким как 1152 на 1536 пикселей или даже выше, можно использовать addPix2PixHDLocalEnhancer
функция. Локальные сети энхансеров помогают генерировать мелкие детали уровня при очень высоких разрешениях.
Задайте закрашенную фигуру сети дискриминатора GAN, которые классифицируют вход изображение как реальное (1) или поддельное (0). Этот пример использует две сети дискриминаторов в разных входных масштабах, также известных как многомасштабные дискриминаторы шкалы. Первая шкала совпадает с размером изображения, а вторая шкала равна половине размера изображения.
Вход дискриминатора является глубинной конкатенацией однокодированных закодированных карт сегментации и изображения сцены, которое будет классифицировано. Укажите количество каналов, вводимых в дискриминатор, как общее количество маркированных классов и цветовых каналов изображений.
numImageChannels = 3; numChannelsDiscriminator = numClasses + numImageChannels;
Задайте размер входа первого дискриминатора. Создайте закрашенную фигуру GAN с нормализацией образца с помощью patchGANDiscriminator
функция.
discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator]; dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,"NormalizationLayer","instance");
Укажите размер входа второго дискриминатора как половину размера изображения, затем создайте вторую закрашенную фигуру GAN.
discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator]; dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,"NormalizationLayer","instance");
Визуализация сетей.
analyzeNetwork(dlnetDiscriminatorScale1); analyzeNetwork(dlnetDiscriminatorScale2);
Функция помощника modelGradients
вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери соответствия признаков и потери VGG для генератора. Эта функция определяется в разделе Вспомогательные функции этого примера.
Цель генератора состоит в том, чтобы сгенерировать изображения, которые дискриминатор классифицирует как действительные (1). Потеря генератора состоит из трех потерь.
Состязательные потери вычисляются как квадратное различие между вектором таковых и предсказаниями дискриминатора на сгенерированном изображении. являются предсказаниями дискриминатора на изображении, сгенерированном генератором. Эта потеря реализована с использованием части pix2pixhdAdversarialLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.
Функция, совпадающая с потерей, штрафует расстояние между действительной и сгенерированной картами функций, полученное как предсказания от сети дискриминатора. - общее количество слоев функций дискриминатора. и являются основные истины изображениями и сгенерированными изображениями, соответственно. Эта потеря реализована с помощью pix2pixhdFeatureMatchingLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера
Восприятие потери наказывает расстояние между реальной и сгенерированной картами функций, полученное как предсказания от сети редукции данных. - общее количество слоев функций. и являются сетевыми прогнозами для основной истины изображений и сгенерированных изображений, соответственно. Эта потеря реализована с помощью pix2pixhdVggLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера. Сеть редукции данных создается в окне «Загрузка сети редукции данных».
Общие потери генератора являются взвешенной суммой всех трех потерь. , , и являются коэффициентами веса для состязательной потери, потери соответствия функций и потери восприятия, соответственно.
Обратите внимание, что состязательные потери и потери соответствия функций для генератора вычисляются для двух различных шкал.
Цель дискриминатора состоит в том, чтобы правильно различать основную истину изображения и сгенерированные изображения. Потеря дискриминатора - это сумма двух компонентов:
Квадратное различие между вектором таковых и предсказаниями дискриминатора на вещественных изображениях
Квадратное различие между нулевым вектором и предсказаниями дискриминатора на сгенерированных изображениях
Потеря дискриминатора реализована с помощью части pix2pixhdAdversarialLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера. Обратите внимание, что состязательные потери для дискриминатора вычисляются для двух различных шкал дискриминатора.
Этот пример модифицирует предварительно обученную VGG-19 глубокую нейронную сеть, чтобы извлечь функции реальных и сгенерированных изображений в различных слоях. Эти многослойные функции используются, чтобы вычислить восприятие потери генератора.
Чтобы получить предварительно обученную VGG-19 сеть, установите vgg19
(Deep Learning Toolbox). Если у вас нет установленных необходимых пакетов поддержки, то программное обеспечение предоставляет ссылку на загрузку.
netVGG = vgg19;
Визуализируйте сетевую архитектуру с помощью приложения Deep Network Designer (Deep Learning Toolbox).
deepNetworkDesigner(netVGG)
Чтобы сделать VGG-19 сеть подходящей для редукции данных, сохраните слои до 'pool5' и удалите все полносвязные слои из сети. Получившаяся сеть является полностью сверточной сетью.
netVGG = layerGraph(netVGG.Layers(1:38));
Создайте новый входной слой изображения без нормализации. Замените оригинальное изображение входа слой новым слоем.
inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input"); netVGG = replaceLayer(netVGG,"input",inp); netVGG = dlnetwork(netVGG);
Задайте опции для оптимизации Adam. Обучайте на 60 эпох. Задайте одинаковые опции для сетей генератора и дискриминатора.
Задайте равную скорость обучения 0,0002.
Инициализируйте конечный средний градиент и конечный средний градиент-квадратные скорости распада с []
.
Используйте коэффициент градиентного распада 0,5 и квадратный коэффициент градиентного распада 0,999.
Используйте мини-пакет размером 1 для обучения.
numEpochs = 60; learningRate = 0.0002; trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminatorScale1 = []; trailingAvgSqDiscriminatorScale1 = []; trailingAvgDiscriminatorScale2 = []; trailingAvgSqDiscriminatorScale2 = []; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999; miniBatchSize = 1;
Создайте minibatchqueue
(Deep Learning Toolbox) объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue
объект также переводит данные в dlarray
(Deep Learning Toolbox) объект, который включает автоматическую дифференциацию в приложениях глубокого обучения.
Задайте формат извлечения данных пакета следующим SSCB
(пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground
Аргумент пары "имя-значение" как логическое значение, возвращаемое canUseGPU
. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue
объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию пример загружает предварительно обученную версию сети генератора pix2pixHD для набора данных CamVid с помощью функции helper downloadTrainedPix2PixHDNet
. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде, для true
. Обучите модель в пользовательском цикле обучения. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функцию и modelGradients
вспомогательная функция.
Обновляйте параметры сети с помощью adamupdate
(Deep Learning Toolbox) функция.
Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.
Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
Обучение занимает около 22 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от вашего графического процессора оборудования. Если у ваш графический процессор меньше памяти, попробуйте уменьшить размер входных изображений, задав imageSize
переменная как [480 640] в разделе «Предварительная обработка обучающих данных» примера.
doTraining = false; if doTraining fig = figure; lossPlotter = configureTrainingProgressPlotter(fig); iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Reset and shuffle the data reset(mbqTrain); shuffle(mbqTrain); % Loop over each image while hasdata(mbqTrain) iteration = iteration + 1; % Read data from current mini-batch [dlInputSegMap,dlRealImage] = next(mbqTrain); % Evaluate the model gradients and the generator state using % dlfeval and the GANLoss function listed at the end of the % example [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ... @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG); % Update the generator parameters [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ... dlnetGenerator,gradParamsG, ... trailingAvgGenerator,trailingAvgSqGenerator,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale1 parameters [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ... dlnetDiscriminatorScale1,gradParamsDScale1, ... trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale2 parameters [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ... dlnetDiscriminatorScale2,gradParamsDScale2, ... trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Plot and display various losses lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ... epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG); end end save('trainedPix2PixHDNet.mat','dlnetGenerator'); else trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip'; netDir = fullfile(tempdir,'CamVid'); downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir); load(fullfile(netDir,'trainedPix2PixHDv2.mat')); end
Эффективность этой обученной Pix2PixHD сети ограничена, потому что количество обучающих изображений CamVid относительно мало. Кроме того, некоторые изображения относятся к последовательности изображений и, следовательно, коррелируют с другими изображениями в набор обучающих данных. Чтобы улучшить эффективность Pix2PixHD сети, обучите сеть с помощью другого набора данных, который имеет большее количество обучающих изображений без корреляции.
Из-за ограничений эта Pix2PixHD сеть генерирует более реалистичные изображения для одних тестовых изображений, чем для других. Чтобы продемонстрировать различие в результатах, сравните сгенерированные изображения для первого и третьего тестового изображения. Угол камеры первого тестового изображения имеет необычную точку расположения, которая обращена более перпендикулярно дороге, чем типовое обучающее изображение. В противоположность этому угол камеры третьего тестового изображения имеет типовую точку обзора, которая обращена вдоль дороги и показывает две полосы с маркерами маршрута. Сеть имеет значительно лучшую эффективность, генерируя реалистичное изображение для третьего тестового изображения, чем для первого тестового изображения.
Получите первую основную истину изображение сцены из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
Получите соответствующее изображение метки пикселя из тестовых данных. Измените размер изображения метки пикселя с помощью самой близкой соседней интерполяции.
segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");
Преобразуйте изображение метки пикселя в многоканальную одногретую карту сегментации при помощи onehotencode
(Deep Learning Toolbox) функция.
segMapOneHot = onehotencode(segMap,3,'single');
Создание dlarray
объекты, которые вводят данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray
объект.
dlSegMap = dlarray(segMapOneHot,'SSCB'); if canUseGPU dlSegMap = gpuArray(dlSegMap); end
Сгенерируйте изображение сцены из генератора и одногретую карту сегментации с помощью predict
(Deep Learning Toolbox) функция.
dlGeneratedImage = predict(dlnetGenerator,dlSegMap); generatedImage = extractdata(gather(dlGeneratedImage));
Конечный слой сети генератора производит активации в области значений [-1, 1]. Для отображения измените значения активации на область значений [0, 1].
generatedImage = rescale(generatedImage);
Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb
функция.
coloredSegMap = label2rgb(segMap,cmap);
Отобразите изображение метки пикселя RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.
figure montage({coloredSegMap generatedImage gtImage},'Size',[1 3]) title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Получите третью основную истину изображение сцены из тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 3;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
Чтобы получить третье изображение метки пикселя из тестовых данных и сгенерировать соответствующее изображение сцены, можно использовать функцию helper evaluatePix2PixHD
. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.
The evaluatePix2PixHD
функция выполняет те же операции, что и оценка первого тестового изображения:
Получите изображение метки пикселя из тестовых данных. Измените размер изображения метки пикселя с помощью самой близкой соседней интерполяции.
Преобразуйте изображение метки пикселя в многоканальную одногретую карту сегментации с помощью onehotencode
(Deep Learning Toolbox) функция.
Создайте dlarray
объект для входных данных в генератор. Для вывода графический процессор преобразуйте данные в gpuArray
объект.
Сгенерируйте изображение сцены из генератора и одногретую карту сегментации с помощью predict
(Deep Learning Toolbox) функция.
Переопределите значения активации в области значений [0, 1].
[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);
Для отображения преобразуйте метки из категориальных меток в цвета RGB с помощью label2rgb
функция.
coloredSegMap = label2rgb(segMap,cmap);
Отобразите изображение метки пикселя RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.
figure montage({coloredSegMap generatedImage gtImage},'Size',[1 3]) title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Чтобы оценить, насколько хорошо сеть обобщает изображения меток пикселей за пределами набора данных CamVid, сгенерируйте изображения сцены из пользовательских изображений меток пикселей. Этот пример использует изображения меток пикселей, которые были созданы с помощью приложения Image Labeler. Изображения меток пикселей присоединены к примеру в качестве вспомогательных файлов. Основные истины отсутствуют.
Создайте datastore метки пикселя, который читает и обрабатывает изображения метки пикселя в текущем примере директории.
cpxds = pixelLabelDatastore(pwd,classes,labelIDs);
Для каждого изображения метки пикселя в datastore сгенерируйте изображение сцены с помощью функции helper evaluatePix2PixHD
.
for idx = 1:length(cpxds.Files) % Get the pixel label image and generated scene image [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator); % For display, convert the labels from categorical labels to RGB colors coloredSegMap = label2rgb(segMap); % Display the pixel label image and generated scene image in a montage figure montage({coloredSegMap generatedImage}) title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image']) end
The modelGradients
Функция helper вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери соответствия признаков и потери VGG для генератора.
function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG) % Compute the image generated by the generator given the input semantic % map. generatedImage = forward(generator,inputSegMap); % Define the loss weights lambdaDiscriminator = 1; lambdaGenerator = 1; lambdaFeatureMatching = 5; lambdaVGG = 5; % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,inputSegMap,realImage); inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage); % Compute the adversarial loss for the discriminator and the generator % for first scale. [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1); % Scale the generated image, the real image, and the input semantic map to % half size resizedRealImage = dlresize(realImage, 'Scale',0.5, 'Method',"linear"); resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear"); resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest"); % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage); inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage); % Compute the adversarial loss for the discriminator and the generator % for second scale. [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2); % Compute the feature matching loss for first scale. FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G); FMLossScale1 = FMLossScale1 * lambdaFeatureMatching; % Compute the feature matching loss for second scale. FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G); FMLossScale2 = FMLossScale2 * lambdaFeatureMatching; % Compute the VGG loss VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG); VGGLoss = VGGLoss * lambdaVGG; % Compute the combined generator loss lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss; lossGCombined = lossGCombined * lambdaGenerator; % Compute gradients for the generator gradParamsG = dlgradient(lossGCombined,generator.Learnables,'RetainData',true); % Compute the combined discriminator loss lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator; % Compute gradients for the discriminator scale1 gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,'RetainData',true); % Compute gradients for the discriminator scale2 gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables); % Log the values for displaying later lossD = gather(extractdata(lossDCombined)); lossGGAN = gather(extractdata(GLossScale1 + GLossScale2)); lossGFM = gather(extractdata(FMLossScale1 + FMLossScale2)); lossGVGG = gather(extractdata(VGGLoss)); end
Функция помощника pix2pixHDAdverserialLoss
вычисляет градиенты состязательных потерь для генератора и дискриминатора. Функция также возвращает карты признаков реального изображения и синтетические изображения.
function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator) % Discriminator layer names containing feature maps featureNames = {'act_top','act_mid_1','act_mid_2','act_tail','conv2d_final'}; % Get the feature maps for the real image from the discriminator realPredFtrsD = cell(size(featureNames)); [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames); % Get the feature maps for the generated image from the discriminator genPredFtrsD = cell(size(featureNames)); [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames); % Get the feature map from the final layer to compute the loss realPredD = realPredFtrsD{end}; genPredD = genPredFtrsD{end}; % Compute the discriminator loss DLoss = (1 - realPredD).^2 + (genPredD).^2; DLoss = mean(DLoss,"all"); % Compute the generator loss GLoss = (1 - genPredD).^2; GLoss = mean(GLoss,"all"); end
Функция помощника pix2pixHDFeatureMatchingLoss
вычисляет потери соответствия функции между реальным изображением и синтетическим изображением, сгенерированным генератором.
function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs) % Number of features numFtrsMaps = numel(realPredFtrs); % Initialize the feature matching loss featureMatchingLoss = 0; for i = 1:numFtrsMaps % Get the feature maps of the real image a = extractdata(realPredFtrs{i}); % Get the feature maps of the synthetic image b = genPredFtrs{i}; % Compute the feature matching loss featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all"); end end
Функция помощника pix2pixHDVGGLoss
вычисляет восприятие потерь VGG между реальным изображением и синтетическим изображением, сгенерированным генератором.
function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG) featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0]; % Initialize the VGG loss vggLoss = 0; % Specify the names of the layers with desired feature maps featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"]; % Extract the feature maps for the real image activReal = cell(size(featureNames)); [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames); % Extract the feature maps for the synthetic image activGenerated = cell(size(featureNames)); [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames); % Compute the VGG loss for i = 1:numel(featureNames) vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all"); end end
[1] Ван, Тин-Чун, Мин-Ю Лю, Цзюнь-Янь Чжу, Эндрю Тао, Ян Каутц и Брайан Катандзаро. «Синтез изображений в высоком разрешении и семантическая манипуляция с условными GAN». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 8798-8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.
[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Распознавание Букв. Том 30, Выпуск 2, 2009, стр. 88-97.
combine
| imageDatastore
| pixelLabelDatastore
| transform
| trainingOptions
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox) | vgg19
(Deep Learning Toolbox)