В этом примере показано, как выполнять преобразование домена между изображениями, полученными в дневное время, и в сумерках с использованием неподконтрольной сети преобразования изображения в изображение (UNIT).
Перевод домена - задача переноса стилей и характеристик из одного домена изображения в другой. Этот метод может быть расширен для других операций обучения изображению-изображению, таких как улучшение изображения, окрашивание изображения, генерация дефектов и медицинский анализ изображения.
UNIT [1] - тип генеративной состязательной сети (GAN), состоящий из одной генераторной сети и двух дискриминаторных сетей, которые обучаются одновременно для максимизации общей производительности. Дополнительные сведения о UNIT см. в разделе Начало работы с GAN для преобразования изображения в изображение (панель инструментов обработки изображения).
В этом примере для Набор данных CamVid[2]обучения используется Кембриджского университета. Этот набор данных представляет собой набор из 701 изображений, содержащих виды на уровне улиц, полученные во время вождения.
Загрузите набор данных CamVid с этого URL-адреса. Время загрузки зависит от вашего подключения к Интернету.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidImageData(dataDir,imageURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
Набор данных изображения CamVid включает в себя 497 изображений, полученных в дневное время, и 124 изображения, полученных в сумерках. Производительность обучаемой сети UNIT ограничена, поскольку количество обучающих изображений CamVid относительно невелико, что ограничивает производительность обучаемой сети. Кроме того, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе данных. Чтобы минимизировать влияние этих ограничений, в этом примере данные вручную разбиваются на наборы обучающих и тестовых данных таким образом, чтобы максимизировать изменчивость обучающих данных.
Получите имена файлов дневных и сумерковых изображений для обучения и тестирования, загрузив файл camvidDayDuskDatasetFileNames.mat. Наборы обучающих данных состоят из 263 дневных изображений и 107 сумерков. Наборы тестовых данных состоят из 234 дневных изображений и 17 сумерков.
load('camvidDayDuskDatasetFileNames.mat');Создать imageDatastore объекты, управляющие изображениями дня и сумерков для обучения и тестирования.
imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames)); imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames)); imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames)); imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));
Предварительный просмотр обучающего изображения из наборов данных обучения в дневные и сумерки.
day = preview(imdsDayTrain);
dusk = preview(imdsDuskTrain);
montage({day,dusk})
Укажите размер входного изображения для исходного и целевого изображений.
inputSize = [256,256,3];
Увеличение и предварительная обработка данных обучения с помощью transform функция с пользовательскими операциями предварительной обработки, заданными функцией помощника augmentDataForDayToDusk. Эта функция присоединена к примеру как вспомогательный файл.
augmentDataForDayToDusk функция выполняет следующие операции:
Измените размер изображения до указанного входного размера с помощью бикубической интерполяции.
Случайное разворот изображения в горизонтальном направлении.
Масштабируйте изображение до диапазона [-1, 1]. Этот диапазон соответствует диапазону конечного tanhLayer используется в генераторе.
imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize)); imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));
Создание сети генератора UNIT с помощью unitGenerator(Панель инструментов обработки изображений). Каждая из секций источника и целевого кодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков. Секции кодера совместно используют два из пяти остаточных блоков. Аналогично, каждая секция исходного и целевого декодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков, и секции декодера совместно используют два из пяти остаточных блоков.
gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);
Визуализация сети генератора.
analyzeNetwork(gen)
Существует две дискриминаторные сети, по одной для каждого из доменов изображения (день и сумерки). Создайте дискриминаторы для исходного и целевого доменов с помощью patchGANDiscriminator(Панель инструментов обработки изображений).
discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none"); discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
Визуализация дискриминаторных сетей.
analyzeNetwork(discDay); analyzeNetwork(discDusk);
modelGradientsDisc и modelGradientGen вспомогательные функции вычисляют градиенты и потери для дискриминаторов и генератора соответственно. Эти функции определены в разделе «Вспомогательные функции» данного примера.
Задача каждого дискриминатора состоит в том, чтобы правильно различать реальные изображения (1) и преобразованные изображения (0) для изображений в своей области. Каждый дискриминатор имеет одну функцию потерь.
Целью генератора является формирование преобразованных изображений, которые дискриминаторы классифицируют как реальные. Потери генератора представляют собой взвешенную сумму пяти типов потерь: потери самовосстановления, потери согласованности цикла, скрытые потери KL, потери KL цикла и потери противника.
Укажите весовые коэффициенты для различных потерь.
lossWeights.selfReconLossWeight = 10; lossWeights.hiddenKLLossWeight = 0.01; lossWeights.cycleConsisLossWeight = 10; lossWeights.cycleHiddenKLLossWeight = 0.01; lossWeights.advLossWeight = 1; lossWeights.discLossWeight = 0.5;
Укажите параметры оптимизации Adam. Тренировать сеть в течение 35 эпох. Укажите идентичные опции для сетей генератора и дискриминатора.
Укажите равную скорость обучения 0,0001.
Инициализируйте среднюю градиентную скорость и среднюю градиентно-квадратную скорость затухания с помощью [].
Используйте градиентный коэффициент затухания 0,5 и квадрат градиентного коэффициента затухания 0,999.
Используйте регуляризацию распада веса с коэффициентом 0,0001.
Для обучения используйте мини-пакет размером 1.
learnRate = 0.0001; gradDecay = 0.5; sqGradDecay = 0.999; weightDecay = 0.0001; genAvgGradient = []; genAvgGradientSq = []; discDayAvgGradient = []; discDayAvgGradientSq = []; discDuskAvgGradient = []; discDuskAvgGradientSq = []; miniBatchSize = 1; numEpochs = 35;
Создать minibatchqueue объект, который управляет мини-пакетами наблюдений в пользовательском учебном цикле. minibatchqueue объект также передает данные в dlarray объект, обеспечивающий автоматическую дифференциацию в приложениях глубокого обучения.
Укажите формат извлечения данных мини-партии как SSCB (пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground аргумент name-value как логическое значение, возвращенное canUseGPU. Если поддерживаемый графический процессор доступен для вычисления, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU); mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию в примере загружается предварительно подготовленная версия генератора UNIT для набора данных CamVid с помощью вспомогательной функции. downloadTrainedDayDuskGeneratorNet. Вспомогательная функция прикрепляется к примеру как вспомогательный файл. Предварительно обученная сеть позволяет выполнять весь пример без ожидания завершения обучения.
Для обучения сети установите doTraining переменная в следующем коде true. Обучение модели в индивидуальном цикле обучения. Для каждой итерации:
Считывание данных для текущего мини-пакета с помощью next функция.
Оцените градиенты модели с помощью dlfeval функции и modelGradientsDisc и modelGradientGen вспомогательные функции.
Обновление параметров сети с помощью adamupdate функция.
Отображение входных и преобразованных изображений для исходного и целевого доменов после каждой эпохи.
Обучение на GPU, если он доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Дополнительные сведения см. в разделе Поддержка графического процессора по выпуску (Панель инструментов параллельных вычислений). Обучение занимает около 88 часов на устройстве NVIDIA Titan RTX.
doTraining = false; if doTraining % Create a figure to show the results figure("Units","Normalized"); for iPlot = 1:4 ax(iPlot) = subplot(2,2,iPlot); end iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Shuffle data every epoch reset(mbqDayTrain); shuffle(mbqDayTrain); reset(mbqDuskTrain); shuffle(mbqDuskTrain); % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed while hasdata(mbqDayTrain) iteration = iteration + 1; % Read data from the day domain imDay = next(mbqDayTrain); % Read data from the dusk domain if hasdata(mbqDuskTrain) == 0 reset(mbqDuskTrain); shuffle(mbqDuskTrain); end imDusk = next(mbqDuskTrain); % Calculate discriminator gradients and losses [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ... gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight); % Apply weight decay regularization on day discriminator gradients discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables); % Update parameters of day discriminator [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ... discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Apply weight decay regularization on dusk discriminator gradients discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables); % Update parameters of dusk discriminator [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ... discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Calculate generator gradient and loss [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights); % Apply weight decay regularization on generator gradients genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables); % Update parameters of generator [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ... genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); end % Display the results updateTrainingPlotDayToDusk(ax,images{:}); end % Save the trained network modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen'); else net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip'; downloadTrainedDayDuskGeneratorNet(net_url,dataDir); load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat')); end
Преобразование «источник-цель» использует генератор UNIT для генерации изображения в целевом (сумерках) домене из изображения в исходном (дневном) домене.
Считывание изображения из хранилища данных дневных тестовых изображений.
idxToTest = 1; dayTestImage = readimage(imdsDayTest,idxToTest);
Преобразование изображения в тип данных single и нормализовать изображение к диапазону [-1, 1].
dayTestImage = im2single(dayTestImage); dayTestImage = (dayTestImage-0.5)/0.5;
Создать dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для вычисления, выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.
dlDayImage = dlarray(dayTestImage,'SSCB'); if canUseGPU dlDayImage = gpuArray(dlDayImage); end
Перевести изображение дня ввода в область сумерков с помощью unitPredict(Панель инструментов обработки изображений).
dlDayToDuskImage = unitPredict(gen,dlDayImage); dayToDuskImage = extractdata(gather(dlDayToDuskImage));
Конечный слой генераторной сети создает активизации в диапазоне [-1, 1]. Для отображения измените масштаб активизаций на диапазон [0, 1]. Кроме того, перед отображением измените масштаб входного дневного изображения.
dayToDuskImage = rescale(dayToDuskImage); dayTestImage = rescale(dayTestImage);
Отображение изображения дня ввода и его переведенной версии в сумерках в монтаже.
figure
montage({dayTestImage dayToDuskImage})
title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])
При преобразовании целевого изображения в исходное используется генератор UNIT для генерации изображения в исходном (дневном) домене из изображения в целевом (в сумерках) домене.
Считывание изображения из хранилища данных тестовых изображений в сумерках.
idxToTest = 1; duskTestImage = readimage(imdsDuskTest,idxToTest);
Преобразование изображения в тип данных single и нормализовать изображение к диапазону [-1, 1].
duskTestImage = im2single(duskTestImage); duskTestImage = (duskTestImage-0.5)/0.5;
Создать dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для вычисления, выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.
dlDuskImage = dlarray(duskTestImage,'SSCB'); if canUseGPU dlDuskImage = gpuArray(dlDuskImage); end
Перевести входное изображение в сумерках в дневной домен с помощью unitPredict(Панель инструментов обработки изображений).
dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource"); duskToDayImage = extractdata(gather(dlDuskToDayImage));
Для отображения измените масштаб активизаций на диапазон [0, 1]. Также перед отображением измените масштаб входного изображения в сумерках.
duskToDayImage = rescale(duskToDayImage); duskTestImage = rescale(duskTestImage);
Отображение входного изображения в сумерках и его переведенной дневной версии в монтаже.
montage({duskTestImage duskToDayImage})
title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])
modelGradientDisc вспомогательная функция вычисляет градиенты и потери для двух дискриминаторов.
function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ... discA,discB,ImageA,ImageB,discLossWeight) [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB); % Calculate loss of the discriminator for X_A outA = forward(discA,ImageA); outfA = forward(discA,fakeA); discALoss = discLossWeight*computeDiscLoss(outA,outfA); % Update parameters of the discriminator for X discAGrads = dlgradient(discALoss,discA.Learnables); % Calculate loss of the discriminator for X_B outB = forward(discB,ImageB); outfB = forward(discB,fakeB); discBLoss = discLossWeight*computeDiscLoss(outB,outfB); % Update parameters of the discriminator for Y discBGrads = dlgradient(discBLoss,discB.Learnables); % Convert the data type from dlarray to single discALoss = extractdata(discALoss); discBLoss = extractdata(discBLoss); end
modelGradientGen вспомогательная функция вычисляет градиенты и потери для генератора.
function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights) [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB); hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock'); [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB); cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock'); % Calculate different losses selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB); hiddenKLLoss = computeKLLoss(hidden); cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB); cycleHiddenKLLoss = computeKLLoss(cycle_hidden); outA = forward(discA,ImageBA); outB = forward(discB,ImageAB); advLoss = computeAdvLoss(outA) + computeAdvLoss(outB); % Calculate the total loss of generator as a weighted sum of five % losses genTotalLoss = ... selfReconLoss*lossWeights.selfReconLossWeight + ... hiddenKLLoss*lossWeights.hiddenKLLossWeight + ... cycleReconLoss*lossWeights.cycleConsisLossWeight + ... cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ... advLoss*lossWeights.advLossWeight; % Update the parameters of generator genGrad = dlgradient(genTotalLoss,gen.Learnables); % Convert the data type from dlarray to single genLoss = extractdata(genTotalLoss); images = {ImageA,ImageAB,ImageB,ImageBA}; end
computeDiscLoss вспомогательная функция вычисляет потерю дискриминатора. Каждая потеря дискриминатора представляет собой сумму двух компонентов:
Квадрат разности между вектором единиц и предсказаниями дискриминатора на реальных изображениях,
Квадрат разности между вектором нулей и предсказаниями дискриминатора на сгенерированных изображениях,
0-Yˆtranslated) 2
function discLoss = computeDiscLoss(Yreal,Ytranslated) discLoss = mean(((1-Yreal).^2),"all") + ... mean(((0-Ytranslated).^2),"all"); end
computeAdvLoss вспомогательная функция вычисляет состязательные потери для генератора. Состязательная потеря - это квадрат разности между вектором единиц и предсказаниями дискриминатора на преобразованном изображении.
) 2
function advLoss = computeAdvLoss(Ytranslated) advLoss = mean(((Ytranslated-1).^2),"all"); end
computeReconLoss вспомогательная функция вычисляет потери самовосстановления и потери последовательности циклов для генератора. Потеря самовосстановления - это расстояние между входными изображениями и их самовосстановленными версиями. Потеря согласованности циклов - это расстояние между входными изображениями и их восстановленными версиями.
) ‖ 1
) ‖ 1
function reconLoss = computeReconLoss(Yreal,Yrecon) reconLoss = mean(abs(Yreal-Yrecon),"all"); end
computeKLLoss вспомогательная функция вычисляет скрытые потери KL и циклические потери KL для генератора. Скрытая потеря KL - это квадрат разности между вектором нулей и 'encoderSharedBlock"активация для потока самовосстановления. Скрытая в цикле потеря KL - это квадрат разности между вектором нулей и 'encoderSharedBlock"активация для потока реконструкции цикла.
) 2
) 2
function klLoss = computeKLLoss(hidden) klLoss = mean(abs(hidden.^2),"all"); end
[1] Лю, Мин-Ю, Томас Бреуэль и Ян Каутц, «Неподконтрольные сети перевода изображений на изображения». In Advances in Neural Information Processing Systems, 2017. https://arxiv.org/abs/1703.00848.
[2] Бростоу, Габриэль Дж., Жюльен Фокер и Роберто Чиполла. «Классы семантических объектов в видео: база данных истинности земли высокой четкости». Буквы распознавания образов. Том 30, выпуск 2, 2009, стр. 88-97.
adamupdate | dlarray | dlfeval | minibatchqueue | transform | patchGANDiscriminator (Панель инструментов обработки изображений) | unitGenerator(Панель инструментов обработки изображений) | unitPredict(Панель инструментов обработки изображений)