В этом примере показано, как сгенерировать синтетическое изображение сцены из карты семантической сегментации с помощью pix2pixHD условной порождающей соперничающей сети (CGAN).
Pix2pixHD [1] состоит из двух сетей, которые обучены одновременно, чтобы максимизировать эффективность обоих.
Генератор является нейронной сетью стиля декодера энкодера, которая генерирует изображение сцены из карты семантической сегментации. Сеть CGAN обучает генератор генерировать изображение сцены, которое различитель неправильно классифицирует как действительное.
Различитель полностью сверточная нейронная сеть, которая сравнивает сгенерированное изображение сцены и соответствующее действительное изображение и пытается классифицировать их как фальшивку и действительный, соответственно. Сеть CGAN обучает различитель правильно различать сгенерированное и действительное изображение.
Генератор и сети различителя конкурируют друг против друга во время обучения. Обучение сходится, когда никакая сеть не может улучшиться далее.
Этот пример использует Набор данных CamVid 2[] из Кембриджского университета для обучения. Этот набор данных является набором 701 изображения, содержащего представления уличного уровня, полученные при управлении. Набор данных обеспечивает пиксельные метки для 32 семантических классов включая автомобиль, пешехода и дорогу.
Загрузите набор данных CamVid с этих URL. Время загрузки зависит от вашего интернет-соединения.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidData(dataDir,imageURL,labelURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full"); labelDir = fullfile(dataDir,'labels');
Создайте imageDatastore
сохранить изображения в наборе данных CamVid.
imds = imageDatastore(imgDir); imageSize = [576 768];
Задайте имена классов и пиксельную метку IDs этих 32 классов в наборе данных CamVid с помощью функции помощника defineCamVid32ClassesAndPixelLabelIDs
. Получите карту стандартного цвета для набора данных CamVid с помощью функции помощника camvid32ColorMap
. Функции помощника присоединены к примеру как к вспомогательным файлам.
numClasses = 32; [classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs; cmap = camvid32ColorMap;
Создайте pixelLabelDatastore
чтобы сохранить пиксель помечают изображения.
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);
Предварительно просмотрите пиксельное изображение метки и соответствующее изображение сцены основной истины. Преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb
функция, затем отобразите пиксельное изображение метки и изображение основной истины в монтаже.
im = preview(imds); px = preview(pxds); px = label2rgb(px,cmap); montage({px,im})
Разделите данные в наборы обучающих данных и наборы тестов с помощью функции помощника partitionCamVidForPix2PixHD
Эта функция присоединена к примеру как вспомогательный файл. Функция помощника разделяет данные в 648 учебных файлов и 32 тестовых файла.
[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);
Используйте combine
функционируйте, чтобы объединить пиксельные изображения метки и изображения сцены основной истины в один datastore.
dsTrain = combine(pxdsTrain,imdsTrain);
Увеличьте обучающие данные при помощи transform
функция с пользовательскими операциями предварительной обработки, заданными помощником, функционирует preprocessCamVidForPix2PixHD
. Эта функция помощника присоединена к примеру как к вспомогательному файлу.
preprocessCamVidForPix2PixHD
функция выполняет эти операции:
Масштабируйте достоверные данные к области значений [-1, 1]. Эта область значений совпадает с областью значений итогового tanhLayer
(Deep Learning Toolbox) в сети генератора.
Измените размер изображения и меток к выходному размеру сети, 576 768 пиксели, с помощью bicubic и самая близкая соседняя субдискретизация, соответственно.
Преобразуйте одну карту сегментации канала в одногорячую закодированную карту сегментации с 32 каналами с помощью onehotencode
(Deep Learning Toolbox) функция.
Случайным образом зеркальное изображение и пиксель помечают пары в горизонтальном направлении.
dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));
Предварительно просмотрите каналы одногорячей закодированной карты сегментации в монтаже. Каждый канал представляет одногорячую карту, соответствующую пикселям уникального класса.
map = preview(dsTrain); montage(map{1},'Size',[4 8],'Bordersize',5,'BackgroundColor','b')
Задайте pix2pixHD сеть генератора, которая генерирует изображение сцены из мудрой глубиной одногорячей закодированной карты сегментации. Этот вход имеет ту же высоту и ширину как исходная карта сегментации и то же количество каналов как классы.
generatorInputSize = [imageSize numClasses];
Создайте pix2pixHD сеть генератора использование pix2pixHDGlobalGenerator
функция.
dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);
Отобразите сетевую архитектуру.
analyzeNetwork(dlnetGenerator)
Обратите внимание на то, что этот пример показывает использование pix2pixHD глобального генератора для генерации изображений размера 576 768 пиксели. Чтобы создать локальные сети усилителя, которые генерируют изображения в более высоком разрешении такой как 1152 1536 пиксели или еще выше, можно использовать addPix2PixHDLocalEnhancer
функция. Локальные сети усилителя помогают сгенерировать прекрасные детали уровня в очень высоких разрешениях.
Задайте закрашенную фигуру сети различителя GAN, который классифицирует входное изображение или как действительное (1) или как фальшивка (0). Этот пример использует две сети различителя в различных входных шкалах, также известных как многошкальные различители шкалы. Первая шкала одного размера с размером изображения, и вторая шкала является половиной размера размера изображения.
Вход к различителю является мудрой глубиной конкатенацией одногорячих закодированных карт сегментации и изображения сцены, которое будет классифицировано. Задайте количество входа каналов к различителю как общее количество помеченных классов и каналов цвета изображения.
numImageChannels = 3; numChannelsDiscriminator = numClasses + numImageChannels;
Задайте входной размер первого различителя. Создайте закрашенную фигуру различитель GAN с нормализацией экземпляра с помощью patchGANDiscriminator
функция.
discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator]; dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,"NormalizationLayer","instance");
Задайте входной размер второго различителя как половина размера изображения, затем создайте вторую закрашенную фигуру различитель GAN.
discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator]; dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,"NormalizationLayer","instance");
Визуализируйте сети.
analyzeNetwork(dlnetDiscriminatorScale1); analyzeNetwork(dlnetDiscriminatorScale2);
Функция помощника modelGradients
вычисляет градиенты и соперничающую потерю для генератора и различителя. Функция также вычисляет потерю соответствия функции и потерю VGG для генератора. Эта функция задана в разделе Supporting Functions этого примера.
Цель генератора состоит в том, чтобы сгенерировать изображения, которые различитель классифицирует как действительные (1). Потеря генератора состоит из трех потерь.
Соперничающая потеря вычисляется как различие в квадрате между вектором из единиц и предсказаниями различителя на сгенерированном изображении. предсказания различителя на изображении, сгенерированном генератором. Эта потеря реализована с помощью части pix2pixhdAdversarialLoss
функция, определяемая помощника в разделе Supporting Functions этого примера.
Потеря соответствия функции штрафует расстояние между действительными и сгенерированными картами функции, полученными как предсказания из сети различителя. общее количество слоев функции различителя. и изображения основной истины и сгенерированные изображения, соответственно. Эта потеря реализована с помощью pix2pixhdFeatureMatchingLoss
функция, определяемая помощника в разделе Supporting Functions этого примера
Перцепционная потеря штрафует расстояние между действительными и сгенерированными картами функции, полученными как предсказания из сети извлечения признаков. общее количество слоев функции. и сетевые предсказания для изображений основной истины и сгенерированных изображений, соответственно. Эта потеря реализована с помощью pix2pixhdVggLoss
функция, определяемая помощника в разделе Supporting Functions этого примера. Сеть извлечения признаков создается в Сети Извлечения признаков Загрузки.
Полная потеря генератора является взвешенной суммой всех трех потерь. , , и весовые коэффициенты за соперничающую потерю, потерю соответствия функции и перцепционную потерю, соответственно.
Обратите внимание на то, что соперничающая потеря и потеря соответствия функции для генератора вычисляются для двух различных шкал.
Цель различителя состоит в том, чтобы правильно различать изображения основной истины и сгенерированные изображения. Потеря различителя является суммой двух компонентов:
Различие в квадрате между вектором из единиц и предсказаниями различителя на действительных изображениях
Различие в квадрате между нулевым вектором и предсказаниями различителя на сгенерированных изображениях
Потеря различителя реализована с помощью части pix2pixhdAdversarialLoss
функция, определяемая помощника в разделе Supporting Functions этого примера. Обратите внимание на то, что соперничающая потеря для различителя вычисляется для двух различных шкал различителя.
Этот пример изменяет предварительно обученную глубокую нейронную сеть VGG-19, чтобы извлечь функции действительных и сгенерированных изображений на различных слоях. Эти многоуровневые функции используются для расчета перцепционная потеря генератора.
Чтобы получить предварительно обученную сеть VGG-19, установите vgg19
(Deep Learning Toolbox). Если у вас нет необходимых пакетов поддержки установленными, то программное обеспечение обеспечивает ссылку на загрузку.
netVGG = vgg19;
Визуализируйте сетевую архитектуру с помощью Deep Network Designer (Deep Learning Toolbox) приложение.
deepNetworkDesigner(netVGG)
Чтобы сделать сеть VGG-19 подходящей для извлечения признаков, сохраните слои до 'pool5' и удалите все полносвязные слоя от сети. Получившаяся сеть является полностью сверточной сетью.
netVGG = layerGraph(netVGG.Layers(1:38));
Создайте новый входной слой изображений без нормализации. Замените входной слой оригинального изображения на новый слой.
inp = imageInputLayer([imageSize 3],"Normalization","None","Name","Input"); netVGG = replaceLayer(netVGG,"input",inp); netVGG = dlnetwork(netVGG);
Задайте опции для оптимизации Адама. Обучайтесь в течение 60 эпох. Задайте идентичные опции для сетей различителя и генератора.
Задайте равную скорость обучения 0,0002.
Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с []
.
Используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.
Используйте мини-пакетный размер 1 для обучения.
numEpochs = 60; learningRate = 0.0002; trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminatorScale1 = []; trailingAvgSqDiscriminatorScale1 = []; trailingAvgDiscriminatorScale2 = []; trailingAvgSqDiscriminatorScale2 = []; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999; miniBatchSize = 1;
Создайте minibatchqueue
Объект (Deep Learning Toolbox), который справляется с мини-пакетной обработкой наблюдений в пользовательском учебном цикле. minibatchqueue
возразите также бросает данные к dlarray
Объект (Deep Learning Toolbox), который включает автоматическое дифференцирование в применении глубокого обучения.
Задайте мини-пакетный формат экстракции данных как SSCB
(пространственный, пространственный, канал, пакет). Установите DispatchInBackground
аргумент пары "имя-значение" как булевская переменная, возвращенная canUseGPU
. Если поддерживаемый графический процессор доступен для расчета, то minibatchqueue
объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqTrain = minibatchqueue(dsTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию пример загружает предварительно обученную версию pix2pixHD сети генератора для набора данных CamVid при помощи функции помощника downloadTrainedPix2PixHDNet
. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде к true
. Обучите модель в пользовательском учебном цикле. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функция и modelGradients
функция помощника.
Обновите сетевые параметры с помощью adamupdate
(Deep Learning Toolbox) функция.
Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
Обучение занимает приблизительно 22 часа на Титане NVIDIA™ RTX и может взять еще дольше в зависимости от вашего оборудования графического процессора. Если ваше устройство графического процессора имеет меньше памяти, попытайтесь уменьшать размер входных изображений путем определения imageSize
переменная как [480 640] в разделе Preprocess Training Data примера.
doTraining = false; if doTraining fig = figure; lossPlotter = configureTrainingProgressPlotter(fig); iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Reset and shuffle the data reset(mbqTrain); shuffle(mbqTrain); % Loop over each image while hasdata(mbqTrain) iteration = iteration + 1; % Read data from current mini-batch [dlInputSegMap,dlRealImage] = next(mbqTrain); % Evaluate the model gradients and the generator state using % dlfeval and the GANLoss function listed at the end of the % example [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ... @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG); % Update the generator parameters [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ... dlnetGenerator,gradParamsG, ... trailingAvgGenerator,trailingAvgSqGenerator,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale1 parameters [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ... dlnetDiscriminatorScale1,gradParamsDScale1, ... trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Update the discriminator scale2 parameters [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ... dlnetDiscriminatorScale2,gradParamsDScale2, ... trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ... learningRate,gradientDecayFactor,squaredGradientDecayFactor); % Plot and display various losses lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ... epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG); end end save('trainedPix2PixHDNet.mat','dlnetGenerator'); else trainedPix2PixHDNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip'; netDir = fullfile(tempdir,'CamVid'); downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir); load(fullfile(netDir,'trainedPix2PixHDv2.mat')); end
Эффективность этой обученной сети Pix2PixHD ограничивается, потому что количество изображений обучения CamVid относительно мало. Кроме того, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе обучающих данных. Чтобы улучшить эффективность сети Pix2PixHD, обучите сеть с помощью различного набора данных, который имеет большее число учебных изображений без корреляции.
Из-за ограничений эта сеть Pix2PixHD генерирует более реалистические изображения для некоторых тестовых изображений, чем для других. Чтобы продемонстрировать различие в результатах, сравните сгенерированные изображения для первого и третьего тестового изображения. Угол камеры первого тестового изображения имеет редкую точку наблюдения, которая стоит перед большим количеством перпендикуляра к дороге, чем типичное учебное изображение. В отличие от этого угол камеры третьего тестового изображения имеет типичную точку наблюдения, которая стоит вдоль дороги и показывает два маршрута с маркерами маршрута. Сеть имеет значительно лучшую эффективность, генерирующую реалистическое изображение для третьего тестового изображения, чем для первого тестового изображения.
Получите первое изображение сцены основной истины от тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
Получите соответствующее пиксельное изображение метки от тестовых данных. Измените размер пиксельного изображения метки с помощью самой близкой соседней интерполяции.
segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");
Преобразуйте пиксельное изображение метки в многоканальную одногорячую карту сегментации при помощи onehotencode
(Deep Learning Toolbox) функция.
segMapOneHot = onehotencode(segMap,3,'single');
Создайте dlarray
объекты, который вводит данные к генератору. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray
объект.
dlSegMap = dlarray(segMapOneHot,'SSCB'); if canUseGPU dlSegMap = gpuArray(dlSegMap); end
Сгенерируйте изображение сцены от генератора и одногорячей карты сегментации с помощью predict
(Deep Learning Toolbox) функция.
dlGeneratedImage = predict(dlnetGenerator,dlSegMap); generatedImage = extractdata(gather(dlGeneratedImage));
Последний слой сети генератора производит активации в области значений [-1, 1]. Для отображения перемасштабируйте активации к области значений [0, 1].
generatedImage = rescale(generatedImage);
Для отображения преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb
функция.
coloredSegMap = label2rgb(segMap,cmap);
Отобразите пиксельное изображение метки RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.
figure montage({coloredSegMap generatedImage gtImage},'Size',[1 3]) title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Получите третье изображение сцены основной истины от тестовых данных. Измените размер изображения с помощью бикубической интерполяции.
idxToTest = 3;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");
Чтобы получить третий пиксель помечают изображение от тестовых данных и сгенерировать соответствующее изображение сцены, можно использовать функцию помощника evaluatePix2PixHD
. Эта функция помощника присоединена к примеру как к вспомогательному файлу.
evaluatePix2PixHD
функция выполняет те же операции как оценка первого тестового изображения:
Получите пиксельное изображение метки от тестовых данных. Измените размер пиксельного изображения метки с помощью самой близкой соседней интерполяции.
Преобразуйте пиксельное изображение метки в многоканальную одногорячую карту сегментации с помощью onehotencode
(Deep Learning Toolbox) функция.
Создайте dlarray
возразите против входных данных к генератору. Для вывода графического процессора преобразуйте данные в gpuArray
объект.
Сгенерируйте изображение сцены от генератора и одногорячей карты сегментации с помощью predict
(Deep Learning Toolbox) функция.
Перемасштабируйте активации к области значений [0, 1].
[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);
Для отображения преобразуйте метки от категориальных меток до цветов RGB при помощи label2rgb
функция.
coloredSegMap = label2rgb(segMap,cmap);
Отобразите пиксельное изображение метки RGB, сгенерированное изображение сцены и изображение сцены основной истины в монтаже.
figure montage({coloredSegMap generatedImage gtImage},'Size',[1 3]) title(['Test Pixel Label Image ',num2str(idxToTest),' with Generated and Ground Truth Scene Images'])
Чтобы оценить, как хорошо сеть делает вывод к пиксельным изображениям метки вне набора данных CamVid, сгенерируйте изображения сцены от пользовательских пиксельных изображений метки. Этот пример использует пиксельные изображения метки, которые были созданы с помощью приложения Image Labeler. Пиксельные изображения метки присоединены к примеру как к вспомогательным файлам. Никакие изображения основной истины не доступны.
Создайте пиксельный datastore метки, который читает и обрабатывает пиксельные изображения метки в текущей директории в качестве примера.
cpxds = pixelLabelDatastore(pwd,classes,labelIDs);
Для каждого пиксельного изображения метки в datastore сгенерируйте изображение сцены с помощью функции помощника evaluatePix2PixHD
.
for idx = 1:length(cpxds.Files) % Get the pixel label image and generated scene image [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator); % For display, convert the labels from categorical labels to RGB colors coloredSegMap = label2rgb(segMap); % Display the pixel label image and generated scene image in a montage figure montage({coloredSegMap generatedImage}) title(['Custom Pixel Label Image ',num2str(idx),' and Generated Scene Image']) end
modelGradients
функция помощника вычисляет градиенты и соперничающую потерю для генератора и различителя. Функция также вычисляет потерю соответствия функции и потерю VGG для генератора.
function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG) % Compute the image generated by the generator given the input semantic % map. generatedImage = forward(generator,inputSegMap); % Define the loss weights lambdaDiscriminator = 1; lambdaGenerator = 1; lambdaFeatureMatching = 5; lambdaVGG = 5; % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,inputSegMap,realImage); inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage); % Compute the adversarial loss for the discriminator and the generator % for first scale. [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1); % Scale the generated image, the real image, and the input semantic map to % half size resizedRealImage = dlresize(realImage, 'Scale',0.5, 'Method',"linear"); resizedGeneratedImage = dlresize(generatedImage,'Scale',0.5,'Method',"linear"); resizedinputSegMap = dlresize(inputSegMap,'Scale',0.5,'Method',"nearest"); % Concatenate the image to be classified and the semantic map inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage); inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage); % Compute the adversarial loss for the discriminator and the generator % for second scale. [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2); % Compute the feature matching loss for first scale. FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G); FMLossScale1 = FMLossScale1 * lambdaFeatureMatching; % Compute the feature matching loss for second scale. FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G); FMLossScale2 = FMLossScale2 * lambdaFeatureMatching; % Compute the VGG loss VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG); VGGLoss = VGGLoss * lambdaVGG; % Compute the combined generator loss lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss; lossGCombined = lossGCombined * lambdaGenerator; % Compute gradients for the generator gradParamsG = dlgradient(lossGCombined,generator.Learnables,'RetainData',true); % Compute the combined discriminator loss lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator; % Compute gradients for the discriminator scale1 gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,'RetainData',true); % Compute gradients for the discriminator scale2 gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables); % Log the values for displaying later lossD = gather(extractdata(lossDCombined)); lossGGAN = gather(extractdata(GLossScale1 + GLossScale2)); lossGFM = gather(extractdata(FMLossScale1 + FMLossScale2)); lossGVGG = gather(extractdata(VGGLoss)); end
Функция помощника pix2pixHDAdverserialLoss
вычисляет соперничающие градиенты потерь для генератора и различителя. Функция также возвращает карты функции действительного изображения и синтетических изображений.
function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator) % Discriminator layer names containing feature maps featureNames = {'act_top','act_mid_1','act_mid_2','act_tail','conv2d_final'}; % Get the feature maps for the real image from the discriminator realPredFtrsD = cell(size(featureNames)); [realPredFtrsD{:}] = forward(discriminator,inpReal,"Outputs",featureNames); % Get the feature maps for the generated image from the discriminator genPredFtrsD = cell(size(featureNames)); [genPredFtrsD{:}] = forward(discriminator,inpGenerated,"Outputs",featureNames); % Get the feature map from the final layer to compute the loss realPredD = realPredFtrsD{end}; genPredD = genPredFtrsD{end}; % Compute the discriminator loss DLoss = (1 - realPredD).^2 + (genPredD).^2; DLoss = mean(DLoss,"all"); % Compute the generator loss GLoss = (1 - genPredD).^2; GLoss = mean(GLoss,"all"); end
Функция помощника pix2pixHDFeatureMatchingLoss
вычисляет потерю соответствия функции между действительным изображением и синтетическим изображением, сгенерированным генератором.
function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs) % Number of features numFtrsMaps = numel(realPredFtrs); % Initialize the feature matching loss featureMatchingLoss = 0; for i = 1:numFtrsMaps % Get the feature maps of the real image a = extractdata(realPredFtrs{i}); % Get the feature maps of the synthetic image b = genPredFtrs{i}; % Compute the feature matching loss featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all"); end end
Функция помощника pix2pixHDVGGLoss
вычисляет перцепционную потерю VGG между действительным изображением и синтетическим изображением, сгенерированным генератором.
function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG) featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0]; % Initialize the VGG loss vggLoss = 0; % Specify the names of the layers with desired feature maps featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"]; % Extract the feature maps for the real image activReal = cell(size(featureNames)); [activReal{:}] = forward(netVGG,realImage,"Outputs",featureNames); % Extract the feature maps for the synthetic image activGenerated = cell(size(featureNames)); [activGenerated{:}] = forward(netVGG,generatedImage,"Outputs",featureNames); % Compute the VGG loss for i = 1:numel(featureNames) vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all"); end end
[1] Ван, Звон-Chun, Мин-Юй Лю, июнь-Yan Чжу, Эндрю Тао, Ян Коц и Брайан Кэйтанзаро. "Синтез изображений с высоким разрешением и Семантическая Манипуляция с Условным GANs". На 2018 Конференциях IEEE/CVF по Компьютерному зрению и Распознаванию образов, 8798–8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.
[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. "Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости". Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.
combine
| imageDatastore
| pixelLabelDatastore
| transform
| trainingOptions
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox) | vgg19
(Deep Learning Toolbox)