Безнадзорный перевод дня к сумраку изображений Используя МОДУЛЬ

В этом примере показано, как выполнить доменный перевод между изображениями, полученными во время дневного времени и условий сумрака с помощью безнадзорной сети перевода от изображения к изображению (МОДУЛЬ).

Доменный перевод является задачей передачи стилей и характеристик от одной области изображения до другого. Этот метод может быть расширен к другим операциям изучения от изображения к изображению, таким как повышение качества изображения, колоризация изображений, дефектная генерация и медицинский анализ изображения.

МОДУЛЬ [1] является типом порождающей соперничающей сети (GAN), которая состоит из одной сети генератора и двух сетей различителя, которые вы обучаете одновременно, чтобы максимизировать общую производительность. Для получения дополнительной информации о МОДУЛЕ, смотрите Начало работы с GANs для Перевода От изображения к изображению.

Загрузите набор данных

Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором 701 изображения, содержащего представления уличного уровня, полученные при управлении.

Загрузите набор данных CamVid с этого URL. Время загрузки зависит от вашего интернет-соединения.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';

dataDir = fullfile(tempdir,'CamVid'); 
downloadCamVidImageData(dataDir,imageURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");

Загрузите данные дня и сумрака

Набор данных изображения CamVid включает 497 изображений, полученных в дневное время и 124 изображения, полученные в сумраке. Эффективность обученной сети UNIT ограничивается, потому что количество изображений обучения CamVid относительно мало, который ограничивает эффективность обучившего сеть. Далее, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе данных. Чтобы минимизировать удар этих ограничений, этот пример вручную делит данные в обучение и наборы тестовых данных способом, который максимизирует изменчивость обучающих данных.

Получите имена файлов дня и изображений сумрака для обучения и тестирования путем загрузки файла 'camvidDayDuskDatasetFileNames.mat'. Обучающие наборы данных состоят из 263-дневных изображений и 107 изображений сумрака. Наборы тестовых данных состоят из 234-дневных изображений и 17 изображений сумрака.

load('camvidDayDuskDatasetFileNames.mat');

Создайте imageDatastore объекты, которые управляют днем и изображениями сумрака для обучения и тестирования.

imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames));
imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames));
imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames));
imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));

Предварительно просмотрите учебное изображение с обучающих наборов данных сумрака и дня.

day = preview(imdsDayTrain);
dusk = preview(imdsDuskTrain);
montage({day,dusk})

Предварительно обработайте и увеличьте обучающие данные

Задайте входной размер изображений для входных и выходных изображений.

inputSize = [256,256,3];

Увеличьте и предварительно обработайте обучающие данные при помощи transform функция с пользовательскими операциями предварительной обработки, заданными помощником, функционирует augmentDataForDayToDuskЭта функция присоединена к примеру как вспомогательный файл.

augmentDataForDayToDusk функция выполняет эти операции:

  1. Измените размер изображения к заданному входному размеру с помощью бикубической интерполяции.

  2. Случайным образом инвертируйте изображение в горизонтальном направлении.

  3. Масштабируйте изображение к области значений [-1, 1]. Эта область значений совпадает с областью значений итогового tanhLayer (Deep Learning Toolbox) используется в генераторе.

imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize));
imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));

Создайте сеть генератора

Создайте МОДУЛЬНУЮ сеть генератора использование unitGenerator функция. Входные и выходные разделы энкодера генератора каждый состоит из двух блоков субдискретизации и пяти остаточных блоков. Разделы энкодера совместно используют два из пяти остаточных блоков. Аналогично, входные и выходные разделы декодера генератора, каждый состоит из двух блоков субдискретизации и пяти остаточных блоков и разделов декодера, совместно используют два из пяти остаточных блоков.

gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);

Визуализируйте сеть генератора.

analyzeNetwork(gen)

Создайте сети различителя

Существует две сети различителя, один для каждой из областей изображения (день и сумрак). Создайте различители для входных и выходных областей с помощью patchGANDiscriminator функция.

discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");

Визуализируйте сети различителя.

analyzeNetwork(discDay);
analyzeNetwork(discDusk);

Градиенты модели Define и функции потерь

modelGradientsDisc и modelGradientGen функции помощника вычисляют градиенты и потери для различителей и генератора, соответственно. Эти функции заданы в разделе Supporting Functions этого примера.

Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь.

Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой пяти типов потерь: потеря самореконструкции, потеря непротиворечивости цикла, скрытая потеря KL, цикл скрытая потеря KL и adverserial потеря.

Задайте весовые коэффициенты за различные потери.

lossWeights.selfReconLossWeight = 10;
lossWeights.hiddenKLLossWeight = 0.01;
lossWeights.cycleConsisLossWeight = 10;
lossWeights.cycleHiddenKLLossWeight = 0.01;
lossWeights.advLossWeight = 1;
lossWeights.discLossWeight = 0.5;

Задайте опции обучения

Задайте опции для оптимизации Адама. Обучите сеть в течение 35 эпох. Задайте идентичные опции для сетей различителя и генератора.

  • Задайте равную скорость обучения 0,0001.

  • Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [].

  • Используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.

  • Используйте регуляризацию затухания веса с фактором 0,0001.

  • Используйте мини-пакетный размер 1 для обучения.

learnRate = 0.0001;
gradDecay = 0.5;
sqGradDecay = 0.999;
weightDecay = 0.0001;

genAvgGradient = [];
genAvgGradientSq = [];

discDayAvgGradient = [];
discDayAvgGradientSq = [];

discDuskAvgGradient = [];
discDuskAvgGradientSq = [];

miniBatchSize = 1;
numEpochs = 35;

Пакетные обучающие данные

Создайте minibatchqueue Объект (Deep Learning Toolbox), который справляется с мини-пакетной обработкой наблюдений в пользовательском учебном цикле. minibatchqueue возразите также бросает данные к dlarray Объект (Deep Learning Toolbox), который включает автоматическое дифференцирование в применении глубокого обучения.

Задайте мини-пакетный формат экстракции данных как SSCB (пространственный, пространственный, канал, пакет). Установите DispatchInBackground аргумент значения имени как булевская переменная, возвращенная canUseGPU. Если поддерживаемый графический процессор доступен для расчета, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);

Обучите сеть

По умолчанию пример загружает предварительно обученную версию МОДУЛЬНОГО генератора для набора данных CamVid при помощи функции помощника downloadTrainedDayDuskGeneratorNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.

Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next (Deep Learning Toolbox) функция.

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) функция и modelGradientsDisc и modelGradientGen функции помощника.

  • Обновите сетевые параметры с помощью adamupdate (Deep Learning Toolbox) функция.

  • Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 88 часов на Титане NVIDIA RTX.

doTraining = false;
if doTraining
    % Create a figure to show the results
    figure("Units","Normalized");
    for iPlot = 1:4
        ax(iPlot) = subplot(2,2,iPlot);
    end
    
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Shuffle data every epoch
        reset(mbqDayTrain);
        shuffle(mbqDayTrain);
        reset(mbqDuskTrain);
        shuffle(mbqDuskTrain);
        
        % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed
        while hasdata(mbqDayTrain)
            iteration = iteration + 1;
            
            % Read data from the day domain
            imDay = next(mbqDayTrain); 
             
            % Read data from the dusk domain
            if hasdata(mbqDuskTrain) == 0
                reset(mbqDuskTrain);
                shuffle(mbqDuskTrain);
            end
            imDusk = next(mbqDuskTrain);
    
            % Calculate discriminator gradients and losses
            [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ...
                gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight);
            
            % Apply weight decay regularization on day discriminator gradients
            discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables);
            
            % Update parameters of day discriminator
            [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ...
                discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);  
            
            % Apply weight decay regularization on dusk discriminator gradients
            discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables);
            
            % Update parameters of dusk discriminator
            [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ...
                discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
            
            % Calculate generator gradient and loss
            [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights);
            
            % Apply weight decay regularization on generator gradients
            genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables);
            
            % Update parameters of generator
            [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ...
                genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
        end
        
        % Display the results
        updateTrainingPlotDayToDusk(ax,images{:});
    end
    
    % Save the trained network
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen');
    
else    
    net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip';
    downloadTrainedDayDuskGeneratorNet(net_url,dataDir);
    load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat'));
end

Оцените источник, чтобы предназначаться для перевода

Перевод источника к целевому изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в цели (сумрак) область от изображения в источнике (день) область.

Считайте изображение из datastore дневных тестовых изображений.

idxToTest = 1;
dayTestImage = readimage(imdsDayTest,idxToTest);

Преобразуйте изображение в тип данных single и нормируйте изображение к области значений [-1, 1].

dayTestImage = im2single(dayTestImage);
dayTestImage = (dayTestImage-0.5)/0.5;

Создайте dlarray возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.

dlDayImage = dlarray(dayTestImage,'SSCB');    
if canUseGPU
    dlDayImage = gpuArray(dlDayImage);
end

Переведите входное дневное изображение в область сумрака использование unitPredict функция.

dlDayToDuskImage = unitPredict(gen,dlDayImage);
dayToDuskImage = extractdata(gather(dlDayToDuskImage));

Последний слой сети генератора производит активации в области значений [-1, 1]. Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное дневное изображение перед отображением.

dayToDuskImage = rescale(dayToDuskImage);
dayTestImage = rescale(dayTestImage);

Отобразите входное дневное изображение и его переведенную версию сумрака в монтаже.

figure
montage({dayTestImage dayToDuskImage})
title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])

Оцените перевод цели к источнику

Перевод цели к исходному изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в источнике (день) область от изображения в цели (сумрак) область.

Считайте изображение из datastore тестовых изображений сумрака.

idxToTest = 1;
duskTestImage = readimage(imdsDuskTest,idxToTest);

Преобразуйте изображение в тип данных single и нормируйте изображение к области значений [-1, 1].

duskTestImage = im2single(duskTestImage);
duskTestImage = (duskTestImage-0.5)/0.5;

Создайте dlarray возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.

dlDuskImage = dlarray(duskTestImage,'SSCB');    
if canUseGPU
    dlDuskImage = gpuArray(dlDuskImage);
end

Переведите входное изображение сумрака в дневную область использование unitPredict функция.

dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource");
duskToDayImage = extractdata(gather(dlDuskToDayImage));

Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное изображение сумрака перед отображением.

duskToDayImage = rescale(duskToDayImage);
duskTestImage = rescale(duskTestImage);

Отобразите входное изображение сумрака и его переведенную дневную версию в монтаже.

montage({duskTestImage duskToDayImage})
title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])

Вспомогательные Функции

Функции градиентов модели

modelGradientDisc функция помощника вычисляет градиенты и потерю для этих двух различителей.

function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ...
    discA,discB,ImageA,ImageB,discLossWeight)

    [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB);
    
    % Calculate loss of the discriminator for X_A
    outA = forward(discA,ImageA); 
    outfA = forward(discA,fakeA);
    discALoss = discLossWeight*computeDiscLoss(outA,outfA);
    
    % Update parameters of the discriminator for X
    discAGrads = dlgradient(discALoss,discA.Learnables); 
    
    % Calculate loss of the discriminator for X_B
    outB = forward(discB,ImageB); 
    outfB = forward(discB,fakeB);
    discBLoss = discLossWeight*computeDiscLoss(outB,outfB);
    
    % Update parameters of the discriminator for Y
    discBGrads = dlgradient(discBLoss,discB.Learnables);
    
    % Convert the data type from dlarray to single
    discALoss = extractdata(discALoss);
    discBLoss = extractdata(discBLoss);
end

modelGradientGen функция помощника вычисляет градиенты и потерю для генератора.

function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights)
    
    [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB);
    hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock');
    
    [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB);
    cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock');
    
    % Calculate different losses
    selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB);
    hiddenKLLoss = computeKLLoss(hidden);
    cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB);
    cycleHiddenKLLoss = computeKLLoss(cycle_hidden);
    
    outA = forward(discA,ImageBA);
    outB = forward(discB,ImageAB);
    advLoss = computeAdvLoss(outA) + computeAdvLoss(outB);
    
    % Calculate the total loss of generator as a weighted sum of five
    % losses
    genTotalLoss = ...
        selfReconLoss*lossWeights.selfReconLossWeight + ...
        hiddenKLLoss*lossWeights.hiddenKLLossWeight + ...
        cycleReconLoss*lossWeights.cycleConsisLossWeight + ...
        cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ...
        advLoss*lossWeights.advLossWeight;
    
    % Update the parameters of generator
    genGrad = dlgradient(genTotalLoss,gen.Learnables); 
    
    % Convert the data type from dlarray to single
    genLoss = extractdata(genTotalLoss);
    images = {ImageA,ImageAB,ImageB,ImageBA};
end

Функции потерь

computeDiscLoss функция помощника вычисляет потерю различителя. Каждая потеря различителя является суммой двух компонентов:

  • Различие в квадрате между вектором из единиц и предсказаниями различителя на действительных изображениях, Yreal

  • Различие в квадрате между нулевым вектором и предсказаниями различителя на сгенерированных изображениях, Yˆtranslated

discriminatorLoss=(1-Yreal)2+(0-Yˆtranslated)2

function discLoss = computeDiscLoss(Yreal,Ytranslated)
    discLoss = mean(((1-Yreal).^2),"all") + ...
               mean(((0-Ytranslated).^2),"all");
end

computeAdvLoss функция помощника вычисляет соперничающую потерю для генератора. Соперничающая потеря является различием в квадрате между вектором из единиц и предсказаниями различителя на переведенном изображении.

adversarialLoss=(1-Yˆtranslated)2

function advLoss = computeAdvLoss(Ytranslated)
    advLoss = mean(((Ytranslated-1).^2),"all");
end

computeReconLoss функция помощника вычисляет потерю самореконструкции и потерю непротиворечивости цикла для генератора. Сам потеря реконструкции L1 расстояние между входными изображениями и их самовосстановленными версиями. Потеря непротиворечивости цикла L1 расстояние между входными изображениями и их восстановленными циклом версиями.

selfReconstructionLoss=(Yreal-Yself-reconstructed)1

cycleConsistencyLoss=(Yreal-Ycycle-reconstructed)1

function reconLoss = computeReconLoss(Yreal,Yrecon)
    reconLoss = mean(abs(Yreal-Yrecon),"all");
end

computeKLLoss функция помощника вычисляет скрытую потерю KL и скрытую от цикла потерю KL для генератора. Скрытая потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock'активация для потока самореконструкции. Скрытая от цикла потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock'активация для потока реконструкции цикла.

hiddenKLLoss=(0-YencoderSharedBlockActivation)2

cycleHiddenKLLoss=(0-YencoderSharedBlockActivation)2

function klLoss = computeKLLoss(hidden)
    klLoss = mean(abs(hidden.^2),"all");
end

Ссылки

[1] Лю, Мин-Юй, Томас Бреуель и Ян Коц, "Безнадзорные сети перевода от изображения к изображению". В Усовершенствованиях в Нейронных Системах обработки информации, 2017. https://arxiv.org/abs/1703.00848.

[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. "Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости". Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.

Смотрите также

| | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы