В этом примере показано, как выполнить доменный перевод между изображениями, полученными во время дневного времени и условий сумрака с помощью безнадзорной сети перевода от изображения к изображению (МОДУЛЬ).
Доменный перевод является задачей передачи стилей и характеристик от одной области изображения до другого. Этот метод может быть расширен к другим операциям изучения от изображения к изображению, таким как повышение качества изображения, колоризация изображений, дефектная генерация и медицинский анализ изображения.
МОДУЛЬ [1] является типом порождающей соперничающей сети (GAN), которая состоит из одной сети генератора и двух сетей различителя, которые вы обучаете одновременно, чтобы максимизировать общую производительность. Для получения дополнительной информации о МОДУЛЕ, смотрите Начало работы с GANs для Перевода От изображения к изображению (Image Processing Toolbox).
Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором 701 изображения, содержащего представления уличного уровня, полученные при управлении.
Загрузите набор данных CamVid с этого URL. Время загрузки зависит от вашего интернет-соединения.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidImageData(dataDir,imageURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
Набор данных изображения CamVid включает 497 изображений, полученных в дневное время и 124 изображения, полученные в сумраке. Эффективность обученной сети UNIT ограничивается, потому что количество изображений обучения CamVid относительно мало, который ограничивает эффективность обучившего сеть. Далее, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе данных. Чтобы минимизировать удар этих ограничений, этот пример вручную делит данные в обучение и наборы тестовых данных способом, который максимизирует изменчивость обучающих данных.
Получите имена файлов дня и изображений сумрака для обучения и тестирования путем загрузки файла 'camvidDayDuskDatasetFileNames.mat'. Обучающие наборы данных состоят из 263-дневных изображений и 107 изображений сумрака. Наборы тестовых данных состоят из 234-дневных изображений и 17 изображений сумрака.
load('camvidDayDuskDatasetFileNames.mat');Создайте imageDatastore объекты, которые управляют днем и изображениями сумрака для обучения и тестирования.
imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames)); imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames)); imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames)); imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));
Предварительно просмотрите учебное изображение с обучающих наборов данных сумрака и дня.
day = preview(imdsDayTrain);
dusk = preview(imdsDuskTrain);
montage({day,dusk})
Задайте входной размер изображений для входных и выходных изображений.
inputSize = [256,256,3];
Увеличьте и предварительно обработайте обучающие данные при помощи transform функция с пользовательскими операциями предварительной обработки, заданными помощником, функционирует augmentDataForDayToDuskЭта функция присоединена к примеру как вспомогательный файл.
augmentDataForDayToDusk функция выполняет эти операции:
Измените размер изображения к заданному входному размеру с помощью бикубической интерполяции.
Случайным образом инвертируйте изображение в горизонтальном направлении.
Масштабируйте изображение к области значений [-1, 1]. Эта область значений совпадает с областью значений итогового tanhLayer используемый в генераторе.
imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize)); imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));
Создайте МОДУЛЬНУЮ сеть генератора использование unitGenerator (Image Processing Toolbox) функция. Входные и выходные разделы энкодера генератора каждый состоит из двух блоков субдискретизации и пяти остаточных блоков. Разделы энкодера совместно используют два из пяти остаточных блоков. Аналогично, входные и выходные разделы декодера генератора, каждый состоит из двух блоков субдискретизации и пяти остаточных блоков и разделов декодера, совместно используют два из пяти остаточных блоков.
gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);
Визуализируйте сеть генератора.
analyzeNetwork(gen)
Существует две сети различителя, один для каждой из областей изображения (день и сумрак). Создайте различители для входных и выходных областей с помощью patchGANDiscriminator (Image Processing Toolbox) функция.
discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none"); discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
Визуализируйте сети различителя.
analyzeNetwork(discDay); analyzeNetwork(discDusk);
modelGradientsDisc и modelGradientGen функции помощника вычисляют градиенты и потери для различителей и генератора, соответственно. Эти функции заданы в разделе Supporting Functions этого примера.
Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь.
Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой пяти типов потерь: потеря самореконструкции, потеря непротиворечивости цикла, скрытая потеря KL, цикл скрытая потеря KL и adverserial потеря.
Задайте весовые коэффициенты за различные потери.
lossWeights.selfReconLossWeight = 10; lossWeights.hiddenKLLossWeight = 0.01; lossWeights.cycleConsisLossWeight = 10; lossWeights.cycleHiddenKLLossWeight = 0.01; lossWeights.advLossWeight = 1; lossWeights.discLossWeight = 0.5;
Задайте опции для оптимизации Адама. Обучите сеть в течение 35 эпох. Задайте идентичные опции для сетей различителя и генератора.
Задайте равную скорость обучения 0,0001.
Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [].
Используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.
Используйте регуляризацию затухания веса с фактором 0,0001.
Используйте мини-пакетный размер 1 для обучения.
learnRate = 0.0001; gradDecay = 0.5; sqGradDecay = 0.999; weightDecay = 0.0001; genAvgGradient = []; genAvgGradientSq = []; discDayAvgGradient = []; discDayAvgGradientSq = []; discDuskAvgGradient = []; discDuskAvgGradientSq = []; miniBatchSize = 1; numEpochs = 35;
Создайте minibatchqueue объект, который справляется с мини-пакетной обработкой наблюдений в пользовательском учебном цикле. minibatchqueue возразите также бросает данные к dlarray объект, который включает автоматическое дифференцирование в применении глубокого обучения.
Задайте мини-пакетный формат экстракции данных как SSCB (пространственный, пространственный, канал, пакет). Установите DispatchInBackground аргумент значения имени как булевская переменная, возвращенная canUseGPU. Если поддерживаемый графический процессор доступен для расчета, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU); mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию пример загружает предварительно обученную версию МОДУЛЬНОГО генератора для набора данных CamVid при помощи функции помощника downloadTrainedDayDuskGeneratorNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.
Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next функция.
Оцените градиенты модели с помощью dlfeval функционируйте и modelGradientsDisc и modelGradientGen функции помощника.
Обновите сетевые параметры с помощью adamupdate функция.
Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 88 часов на Титане NVIDIA RTX.
doTraining = false; if doTraining % Create a figure to show the results figure("Units","Normalized"); for iPlot = 1:4 ax(iPlot) = subplot(2,2,iPlot); end iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Shuffle data every epoch reset(mbqDayTrain); shuffle(mbqDayTrain); reset(mbqDuskTrain); shuffle(mbqDuskTrain); % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed while hasdata(mbqDayTrain) iteration = iteration + 1; % Read data from the day domain imDay = next(mbqDayTrain); % Read data from the dusk domain if hasdata(mbqDuskTrain) == 0 reset(mbqDuskTrain); shuffle(mbqDuskTrain); end imDusk = next(mbqDuskTrain); % Calculate discriminator gradients and losses [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ... gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight); % Apply weight decay regularization on day discriminator gradients discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables); % Update parameters of day discriminator [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ... discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Apply weight decay regularization on dusk discriminator gradients discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables); % Update parameters of dusk discriminator [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ... discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Calculate generator gradient and loss [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights); % Apply weight decay regularization on generator gradients genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables); % Update parameters of generator [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ... genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); end % Display the results updateTrainingPlotDayToDusk(ax,images{:}); end % Save the trained network modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen'); else net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip'; downloadTrainedDayDuskGeneratorNet(net_url,dataDir); load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat')); end
Перевод источника к целевому изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в цели (сумрак) область от изображения в источнике (день) область.
Считайте изображение из datastore дневных тестовых изображений.
idxToTest = 1; dayTestImage = readimage(imdsDayTest,idxToTest);
Преобразуйте изображение в тип данных single и нормируйте изображение к области значений [-1, 1].
dayTestImage = im2single(dayTestImage); dayTestImage = (dayTestImage-0.5)/0.5;
Создайте dlarray возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.
dlDayImage = dlarray(dayTestImage,'SSCB'); if canUseGPU dlDayImage = gpuArray(dlDayImage); end
Переведите входное дневное изображение в область сумрака использование unitPredict (Image Processing Toolbox) функция.
dlDayToDuskImage = unitPredict(gen,dlDayImage); dayToDuskImage = extractdata(gather(dlDayToDuskImage));
Последний слой сети генератора производит активации в области значений [-1, 1]. Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное дневное изображение перед отображением.
dayToDuskImage = rescale(dayToDuskImage); dayTestImage = rescale(dayTestImage);
Отобразите входное дневное изображение и его переведенную версию сумрака в монтаже.
figure
montage({dayTestImage dayToDuskImage})
title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])
Перевод цели к исходному изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в источнике (день) область от изображения в цели (сумрак) область.
Считайте изображение из datastore тестовых изображений сумрака.
idxToTest = 1; duskTestImage = readimage(imdsDuskTest,idxToTest);
Преобразуйте изображение в тип данных single и нормируйте изображение к области значений [-1, 1].
duskTestImage = im2single(duskTestImage); duskTestImage = (duskTestImage-0.5)/0.5;
Создайте dlarray возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.
dlDuskImage = dlarray(duskTestImage,'SSCB'); if canUseGPU dlDuskImage = gpuArray(dlDuskImage); end
Переведите входное изображение сумрака в дневную область использование unitPredict (Image Processing Toolbox) функция.
dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource"); duskToDayImage = extractdata(gather(dlDuskToDayImage));
Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное изображение сумрака перед отображением.
duskToDayImage = rescale(duskToDayImage); duskTestImage = rescale(duskTestImage);
Отобразите входное изображение сумрака и его переведенную дневную версию в монтаже.
montage({duskTestImage duskToDayImage})
title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])
modelGradientDisc функция помощника вычисляет градиенты и потерю для этих двух различителей.
function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ... discA,discB,ImageA,ImageB,discLossWeight) [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB); % Calculate loss of the discriminator for X_A outA = forward(discA,ImageA); outfA = forward(discA,fakeA); discALoss = discLossWeight*computeDiscLoss(outA,outfA); % Update parameters of the discriminator for X discAGrads = dlgradient(discALoss,discA.Learnables); % Calculate loss of the discriminator for X_B outB = forward(discB,ImageB); outfB = forward(discB,fakeB); discBLoss = discLossWeight*computeDiscLoss(outB,outfB); % Update parameters of the discriminator for Y discBGrads = dlgradient(discBLoss,discB.Learnables); % Convert the data type from dlarray to single discALoss = extractdata(discALoss); discBLoss = extractdata(discBLoss); end
modelGradientGen функция помощника вычисляет градиенты и потерю для генератора.
function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights) [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB); hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock'); [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB); cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock'); % Calculate different losses selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB); hiddenKLLoss = computeKLLoss(hidden); cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB); cycleHiddenKLLoss = computeKLLoss(cycle_hidden); outA = forward(discA,ImageBA); outB = forward(discB,ImageAB); advLoss = computeAdvLoss(outA) + computeAdvLoss(outB); % Calculate the total loss of generator as a weighted sum of five % losses genTotalLoss = ... selfReconLoss*lossWeights.selfReconLossWeight + ... hiddenKLLoss*lossWeights.hiddenKLLossWeight + ... cycleReconLoss*lossWeights.cycleConsisLossWeight + ... cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ... advLoss*lossWeights.advLossWeight; % Update the parameters of generator genGrad = dlgradient(genTotalLoss,gen.Learnables); % Convert the data type from dlarray to single genLoss = extractdata(genTotalLoss); images = {ImageA,ImageAB,ImageB,ImageBA}; end
computeDiscLoss функция помощника вычисляет потерю различителя. Каждая потеря различителя является суммой двух компонентов:
Различие в квадрате между вектором из единиц и предсказаниями различителя на действительных изображениях,
Различие в квадрате между нулевым вектором и предсказаниями различителя на сгенерированных изображениях,
function discLoss = computeDiscLoss(Yreal,Ytranslated) discLoss = mean(((1-Yreal).^2),"all") + ... mean(((0-Ytranslated).^2),"all"); end
computeAdvLoss функция помощника вычисляет соперничающую потерю для генератора. Соперничающая потеря является различием в квадрате между вектором из единиц и предсказаниями различителя на переведенном изображении.
function advLoss = computeAdvLoss(Ytranslated) advLoss = mean(((Ytranslated-1).^2),"all"); end
computeReconLoss функция помощника вычисляет потерю самореконструкции и потерю непротиворечивости цикла для генератора. Сам потеря реконструкции расстояние между входными изображениями и их самовосстановленными версиями. Потеря непротиворечивости цикла расстояние между входными изображениями и их восстановленными циклом версиями.
function reconLoss = computeReconLoss(Yreal,Yrecon) reconLoss = mean(abs(Yreal-Yrecon),"all"); end
computeKLLoss функция помощника вычисляет скрытую потерю KL и скрытую от цикла потерю KL для генератора. Скрытая потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock'активация для потока самореконструкции. Скрытая от цикла потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock'активация для потока реконструкции цикла.
function klLoss = computeKLLoss(hidden) klLoss = mean(abs(hidden.^2),"all"); end
[1] Лю, Мин-Юй, Томас Бреуель и Ян Коц, "Безнадзорные сети перевода от изображения к изображению". В Усовершенствованиях в Нейронных Системах обработки информации, 2017. https://arxiv.org/abs/1703.00848.
[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. "Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости". Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.
adamupdate | dlarray | dlfeval | minibatchqueue | transform | patchGANDiscriminator (Image Processing Toolbox) | unitGenerator (Image Processing Toolbox) | unitPredict (Image Processing Toolbox)