Неконтролируемый перевод изображений день-в-сумерки с использованием МОДУЛЬ

Этот пример показывает, как выполнить преобразование области между изображениями, полученными в дневные и сумеречные условия, используя неконтролируемую сеть преобразования изображения в изображение (UNIT).

Перевод домена - это задача переноса стилей и характеристик из одной области изображений в другую. Этот метод может быть распространен на другие операции обучения между изображениями, такие как улучшение изображения, колоризация изображения, генерация дефектов и анализ медицинских изображений.

UNIT [1] является типом генеративной состязательной сети (GAN), которая состоит из одной сети генератора и двух сетей дискриминатора, которые вы обучаете одновременно, чтобы максимизировать общую эффективность. Для получения дополнительной информации о МОДУЛЕ смотрите Запуск с сетями GAN для преобразования изображения в изображение ( Image Processing Toolbox).

Загрузить набор данных

Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных представляет собой набор 701 изображений, содержащих представления уличного уровня, полученные во время вождения.

Загрузите набор данных CamVid с этого URL-адреса. Время загрузки зависит от вашего подключения к Интернету.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';

dataDir = fullfile(tempdir,'CamVid'); 
downloadCamVidImageData(dataDir,imageURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");

Загрузка данных о дне и сумереках

Набор данных изображений CamVid включает 497 изображений, полученных в дневное время, и 124 изображения, полученные в сумерках. Эффективность обученной сети UNIT ограничена, потому что количество обучающих изображений CamVid относительно мало, что ограничивает эффективность обученной сети. Кроме того, некоторые изображения относятся к последовательности изображений и поэтому коррелируют с другими изображениями в наборе данных. Чтобы минимизировать влияние этих ограничений, этот пример вручную разделяет данные на наборы обучающих и тестовых данных таким образом, чтобы максимизировать изменчивость обучающих данных.

Получите имена файлов изображений day и dusk для обучения и проверки путем загрузки файла 'camvidDayDuskDatasetFileNames.mat'. Обучающие данные состоят из 263 дневных изображений и 107 сумеречных изображений. Наборы тестовых данных состоят из 234 дневных изображений и 17 сумеречных изображений.

load('camvidDayDuskDatasetFileNames.mat');

Создание imageDatastore объекты, которые управляют изображениями дня и сумерек для обучения и проверки.

imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames));
imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames));
imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames));
imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));

Предварительный просмотр обучающего изображения из наборов дневных и сумеречных обучающих данных.

day = preview(imdsDayTrain);
dusk = preview(imdsDuskTrain);
montage({day,dusk})

Предварительная обработка и увеличение обучающих данных

Задайте размер входа изображения для исходного и целевого изображений.

inputSize = [256,256,3];

Увеличение и предварительная обработка обучающих данных с помощью transform функция с пользовательскими операциями предварительной обработки, заданными функцией helper augmentDataForDayToDusk. Эта функция присоединена к примеру как вспомогательный файл.

The augmentDataForDayToDusk функция выполняет следующие операции:

  1. Измените размер изображения на заданный размер входа с помощью бикубической интерполяции.

  2. Случайным образом разверните изображение в горизонтальном направлении.

  3. Масштабируйте изображение до области значений [-1, 1]. Эта область значений соответствует области значений конечных tanhLayer используется в генераторе.

imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize));
imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));

Создайте сеть генератора

Создайте сеть генератора UNIT с помощью unitGenerator (Image Processing Toolbox) функция. Каждая из секций источника и целевого энкодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков. Разделы энкодера совместно используют два из пяти остаточных блоков. Точно так же каждая секция исходного и целевого декодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков, и секции декодера совместно используют два из пяти остаточных блоков.

gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);

Визуализируйте сеть генератора.

analyzeNetwork(gen)

Создайте сети дискриминатора

Существует две сети дискриминаторов, по одной для каждого из областей изображений (день и сумерки). Создайте дискриминаторы для исходных и целевых областей, используя patchGANDiscriminator (Image Processing Toolbox) функция.

discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");

Визуализируйте сети дискриминатора.

analyzeNetwork(discDay);
analyzeNetwork(discDusk);

Задайте градиенты модели и функции потерь

The modelGradientsDisc и modelGradientGen вспомогательные функции вычисляют градиенты и потери для дискриминаторов и генератора, соответственно. Эти функции определены в разделе Вспомогательные функции этого примера.

Цель каждого дискриминатора состоит в том, чтобы правильно различать реальные изображения (1) и переведенные изображения (0) для изображений в своей области. Каждый дискриминатор имеет одну функцию потерь.

Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые дискриминаторы классифицируют как действительные. Потеря генератора - взвешенная сумма пяти типов потерь: потеря самореконструкции, потеря последовательности цикла, скрытая потеря KL, цикл скрытая потеря KL и adverserial потеря.

Укажите весовые коэффициенты для различных потерь.

lossWeights.selfReconLossWeight = 10;
lossWeights.hiddenKLLossWeight = 0.01;
lossWeights.cycleConsisLossWeight = 10;
lossWeights.cycleHiddenKLLossWeight = 0.01;
lossWeights.advLossWeight = 1;
lossWeights.discLossWeight = 0.5;

Настройка опций обучения

Задайте опции для оптимизации Adam. Обучите сеть 35 эпох. Задайте одинаковые опции для сетей генератора и дискриминатора.

  • Задайте равную скорость обучения 0,0001.

  • Инициализируйте конечный средний градиент и конечный средний градиент-квадратные скорости распада с [].

  • Используйте коэффициент градиентного распада 0,5 и квадратный коэффициент градиентного распада 0,999.

  • Используйте регуляризацию распада веса с коэффициентом 0,0001.

  • Используйте мини-пакет размером 1 для обучения.

learnRate = 0.0001;
gradDecay = 0.5;
sqGradDecay = 0.999;
weightDecay = 0.0001;

genAvgGradient = [];
genAvgGradientSq = [];

discDayAvgGradient = [];
discDayAvgGradientSq = [];

discDuskAvgGradient = [];
discDuskAvgGradientSq = [];

miniBatchSize = 1;
numEpochs = 35;

Пакетные обучающие данные

Создайте minibatchqueue объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue объект также переводит данные в dlarray объект, который позволяет проводить автоматическую дифференциацию в применениях глубокого обучения.

Задайте формат извлечения данных пакета следующим SSCB (пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground аргумент имя-значение как логический, возвращенный canUseGPU. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);

Обучите сеть

По умолчанию пример загружает предварительно обученную версию генератора UNIT для набора данных CamVid с помощью функции helper downloadTrainedDayDuskGeneratorNet. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.

Чтобы обучить сеть, установите doTraining переменная в следующем коде, для true. Обучите модель в пользовательском цикле обучения. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval функции и modelGradientsDisc и modelGradientGen вспомогательные функции.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • Отображение входных и переведенных изображений для исходных и целевых областей после каждой эпохи.

Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Обучение занимает около 88 часов на NVIDIA Titan RTX.

doTraining = false;
if doTraining
    % Create a figure to show the results
    figure("Units","Normalized");
    for iPlot = 1:4
        ax(iPlot) = subplot(2,2,iPlot);
    end
    
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Shuffle data every epoch
        reset(mbqDayTrain);
        shuffle(mbqDayTrain);
        reset(mbqDuskTrain);
        shuffle(mbqDuskTrain);
        
        % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed
        while hasdata(mbqDayTrain)
            iteration = iteration + 1;
            
            % Read data from the day domain
            imDay = next(mbqDayTrain); 
             
            % Read data from the dusk domain
            if hasdata(mbqDuskTrain) == 0
                reset(mbqDuskTrain);
                shuffle(mbqDuskTrain);
            end
            imDusk = next(mbqDuskTrain);
    
            % Calculate discriminator gradients and losses
            [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ...
                gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight);
            
            % Apply weight decay regularization on day discriminator gradients
            discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables);
            
            % Update parameters of day discriminator
            [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ...
                discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);  
            
            % Apply weight decay regularization on dusk discriminator gradients
            discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables);
            
            % Update parameters of dusk discriminator
            [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ...
                discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
            
            % Calculate generator gradient and loss
            [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights);
            
            % Apply weight decay regularization on generator gradients
            genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables);
            
            % Update parameters of generator
            [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ...
                genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
        end
        
        % Display the results
        updateTrainingPlotDayToDusk(ax,images{:});
    end
    
    % Save the trained network
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen');
    
else    
    net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip';
    downloadTrainedDayDuskGeneratorNet(net_url,dataDir);
    load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat'));
end

Вычисление преобразования от источника к целевому

Преобразование изображений от источника к цели использует генератор UNIT, чтобы сгенерировать изображение в целевой области (сумерек) из изображения в исходной области (день).

Считайте изображение из тестовых изображений datastore дня.

idxToTest = 1;
dayTestImage = readimage(imdsDayTest,idxToTest);

Преобразуйте изображение в тип данных single и нормализуйте изображение в области значений [-1, 1].

dayTestImage = im2single(dayTestImage);
dayTestImage = (dayTestImage-0.5)/0.5;

Создайте dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray объект.

dlDayImage = dlarray(dayTestImage,'SSCB');    
if canUseGPU
    dlDayImage = gpuArray(dlDayImage);
end

Переведите изображение входа дня в домен сумерек с помощью unitPredict (Image Processing Toolbox) функция.

dlDayToDuskImage = unitPredict(gen,dlDayImage);
dayToDuskImage = extractdata(gather(dlDayToDuskImage));

Конечный слой сети генератора производит активации в области значений [-1, 1]. Для отображения измените значения активации на область значений [0, 1]. Также перепрофилируйте изображение за вход день до отображения.

dayToDuskImage = rescale(dayToDuskImage);
dayTestImage = rescale(dayTestImage);

Отобразите изображение входа дня и его переведенную версию сумерека в монтаже.

figure
montage({dayTestImage dayToDuskImage})
title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])

Вычисление преобразования «цель-источник»

Преобразование изображения «цель-источник» использует генератор UNIT, чтобы сгенерировать изображение в исходной (дневной) области из изображения в целевой (сумеречной) области.

Считайте изображение из datastore тестовых изображений сумерек.

idxToTest = 1;
duskTestImage = readimage(imdsDuskTest,idxToTest);

Преобразуйте изображение в тип данных single и нормализуйте изображение в области значений [-1, 1].

duskTestImage = im2single(duskTestImage);
duskTestImage = (duskTestImage-0.5)/0.5;

Создайте dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray объект.

dlDuskImage = dlarray(duskTestImage,'SSCB');    
if canUseGPU
    dlDuskImage = gpuArray(dlDuskImage);
end

Переведите входное изображение сумерек в дневную область с помощью unitPredict (Image Processing Toolbox) функция.

dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource");
duskToDayImage = extractdata(gather(dlDuskToDayImage));

Для отображения измените значения активации на область значений [0, 1]. Кроме того, перепрофилируйте входное изображение сумматора перед отображением.

duskToDayImage = rescale(duskToDayImage);
duskTestImage = rescale(duskTestImage);

Отобразите входное изображение сумерек и его переведенную дневную версию в монтаже.

montage({duskTestImage duskToDayImage})
title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])

Вспомогательные функции

Моделируйте функции градиентов

The modelGradientDisc Функция helper вычисляет градиенты и потери для двух дискриминаторов.

function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ...
    discA,discB,ImageA,ImageB,discLossWeight)

    [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB);
    
    % Calculate loss of the discriminator for X_A
    outA = forward(discA,ImageA); 
    outfA = forward(discA,fakeA);
    discALoss = discLossWeight*computeDiscLoss(outA,outfA);
    
    % Update parameters of the discriminator for X
    discAGrads = dlgradient(discALoss,discA.Learnables); 
    
    % Calculate loss of the discriminator for X_B
    outB = forward(discB,ImageB); 
    outfB = forward(discB,fakeB);
    discBLoss = discLossWeight*computeDiscLoss(outB,outfB);
    
    % Update parameters of the discriminator for Y
    discBGrads = dlgradient(discBLoss,discB.Learnables);
    
    % Convert the data type from dlarray to single
    discALoss = extractdata(discALoss);
    discBLoss = extractdata(discBLoss);
end

The modelGradientGen Функция helper вычисляет градиенты и потери для генератора.

function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights)
    
    [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB);
    hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock');
    
    [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB);
    cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock');
    
    % Calculate different losses
    selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB);
    hiddenKLLoss = computeKLLoss(hidden);
    cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB);
    cycleHiddenKLLoss = computeKLLoss(cycle_hidden);
    
    outA = forward(discA,ImageBA);
    outB = forward(discB,ImageAB);
    advLoss = computeAdvLoss(outA) + computeAdvLoss(outB);
    
    % Calculate the total loss of generator as a weighted sum of five
    % losses
    genTotalLoss = ...
        selfReconLoss*lossWeights.selfReconLossWeight + ...
        hiddenKLLoss*lossWeights.hiddenKLLossWeight + ...
        cycleReconLoss*lossWeights.cycleConsisLossWeight + ...
        cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ...
        advLoss*lossWeights.advLossWeight;
    
    % Update the parameters of generator
    genGrad = dlgradient(genTotalLoss,gen.Learnables); 
    
    % Convert the data type from dlarray to single
    genLoss = extractdata(genTotalLoss);
    images = {ImageA,ImageAB,ImageB,ImageBA};
end

Функции потерь

The computeDiscLoss Функция helper вычисляет потери дискриминатора. Каждая потеря дискриминатора является суммой двух компонентов:

  • Квадратное различие между вектором таковых и предсказаниями дискриминатора на реальных изображениях, Yreal

  • Квадратное различие между нулевым вектором и предсказаниями дискриминатора на сгенерированных изображениях, Yˆtranslated

discriminatorLoss=(1-Yreal)2+(0-Yˆtranslated)2

function discLoss = computeDiscLoss(Yreal,Ytranslated)
    discLoss = mean(((1-Yreal).^2),"all") + ...
               mean(((0-Ytranslated).^2),"all");
end

The computeAdvLoss вспомогательная функция вычисляет состязательные потери для генератора. Состязательные потери - это квадратное различие между вектором таковых и предсказаниями дискриминатора на переведенном изображении.

adversarialLoss=(1-Yˆtranslated)2

function advLoss = computeAdvLoss(Ytranslated)
    advLoss = mean(((Ytranslated-1).^2),"all");
end

The computeReconLoss Функция helper вычисляет потери самовосстановления и потери согласованности цикла для генератора. Потеря самовосстановления L1 расстояние между входными изображениями и их самовосстановленными вариантами. Потеря непротиворечивости цикла L1 расстояние между входными изображениями и их восстановленными циклом вариантами.

selfReconstructionLoss=(Yreal-Yself-reconstructed)1

cycleConsistencyLoss=(Yreal-Ycycle-reconstructed)1

function reconLoss = computeReconLoss(Yreal,Yrecon)
    reconLoss = mean(abs(Yreal-Yrecon),"all");
end

The computeKLLoss Функция helper вычисляет скрытые потери KL и скрытые циклом потери KL для генератора. Скрытые потери KL - это квадратное различие между нулевым вектором и 'encoderSharedBlock'активация для потока самовосстановления. Скрытые циклом KL-потери - это квадратное различие между нулевым вектором и 'encoderSharedBlock'активация для потока реконструкции цикла.

hiddenKLLoss=(0-YencoderSharedBlockActivation)2

cycleHiddenKLLoss=(0-YencoderSharedBlockActivation)2

function klLoss = computeKLLoss(hidden)
    klLoss = mean(abs(hidden.^2),"all");
end

Ссылки

[1] Liu, Ming-Yu, Thomas Breuel, and Jan Kautz, «Unsupervised image-to-image translation networks». В Усовершенствования в системах нейронной обработки информации, 2017. https://arxiv.org/abs/1703.00848.

[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Распознавание Букв. Том 30, Выпуск 2, 2009, стр. 88-97.

См. также

| | | | | (Image Processing Toolbox) | (Набор Image Processing Toolbox) | (Набор Image Processing Toolbox)

Похожие темы