exponenta event banner

Перевод изображений без присмотра в сумерках с помощью UNIT

В этом примере показано, как выполнять преобразование домена между изображениями, полученными в дневное время, и в сумерках с использованием неподконтрольной сети преобразования изображения в изображение (UNIT).

Перевод домена - задача переноса стилей и характеристик из одного домена изображения в другой. Этот метод может быть расширен для других операций обучения изображению-изображению, таких как улучшение изображения, окрашивание изображения, генерация дефектов и медицинский анализ изображения.

UNIT [1] - тип генеративной состязательной сети (GAN), состоящий из одной генераторной сети и двух дискриминаторных сетей, которые обучаются одновременно для максимизации общей производительности. Дополнительные сведения о UNIT см. в разделе Начало работы с GAN для преобразования образа в образ.

Загрузить набор данных

В этом примере для Набор данных CamVid[2]обучения используется Кембриджского университета. Этот набор данных представляет собой набор из 701 изображений, содержащих виды на уровне улиц, полученные во время вождения.

Загрузите набор данных CamVid с этого URL-адреса. Время загрузки зависит от вашего подключения к Интернету.

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';

dataDir = fullfile(tempdir,'CamVid'); 
downloadCamVidImageData(dataDir,imageURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");

Загрузка данных о дне и сумерках

Набор данных изображения CamVid включает в себя 497 изображений, полученных в дневное время, и 124 изображения, полученных в сумерках. Производительность обучаемой сети UNIT ограничена, поскольку количество обучающих изображений CamVid относительно невелико, что ограничивает производительность обучаемой сети. Кроме того, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе данных. Чтобы минимизировать влияние этих ограничений, в этом примере данные вручную разбиваются на наборы обучающих и тестовых данных таким образом, чтобы максимизировать изменчивость обучающих данных.

Получите имена файлов дневных и сумерковых изображений для обучения и тестирования, загрузив файл camvidDayDuskDatasetFileNames.mat. Наборы обучающих данных состоят из 263 дневных изображений и 107 сумерков. Наборы тестовых данных состоят из 234 дневных изображений и 17 сумерков.

load('camvidDayDuskDatasetFileNames.mat');

Создать imageDatastore объекты, управляющие изображениями дня и сумерков для обучения и тестирования.

imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames));
imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames));
imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames));
imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));

Предварительный просмотр обучающего изображения из наборов данных обучения в дневные и сумерки.

day = preview(imdsDayTrain);
dusk = preview(imdsDuskTrain);
montage({day,dusk})

Предварительная обработка и дополнительные учебные данные

Укажите размер входного изображения для исходного и целевого изображений.

inputSize = [256,256,3];

Увеличение и предварительная обработка данных обучения с помощью transform функция с пользовательскими операциями предварительной обработки, заданными функцией помощника augmentDataForDayToDusk. Эта функция присоединена к примеру как вспомогательный файл.

augmentDataForDayToDusk функция выполняет следующие операции:

  1. Измените размер изображения до указанного входного размера с помощью бикубической интерполяции.

  2. Случайное разворот изображения в горизонтальном направлении.

  3. Масштабируйте изображение до диапазона [-1, 1]. Этот диапазон соответствует диапазону конечного tanhLayer (Deep Learning Toolbox) используется в генераторе.

imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize));
imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));

Создание сети генератора

Создание сети генератора UNIT с помощью unitGenerator функция. Каждая из секций источника и целевого кодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков. Секции кодера совместно используют два из пяти остаточных блоков. Аналогично, каждая секция исходного и целевого декодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков, и секции декодера совместно используют два из пяти остаточных блоков.

gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);

Визуализация сети генератора.

analyzeNetwork(gen)

Создание дискриминаторных сетей

Существует две дискриминаторные сети, по одной для каждого из доменов изображения (день и сумерки). Создайте дискриминаторы для исходного и целевого доменов с помощью patchGANDiscriminator функция.

discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ...
    "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");

Визуализация дискриминаторных сетей.

analyzeNetwork(discDay);
analyzeNetwork(discDusk);

Определение градиентов модели и функций потерь

modelGradientsDisc и modelGradientGen вспомогательные функции вычисляют градиенты и потери для дискриминаторов и генератора соответственно. Эти функции определены в разделе «Вспомогательные функции» данного примера.

Задача каждого дискриминатора состоит в том, чтобы правильно различать реальные изображения (1) и преобразованные изображения (0) для изображений в своей области. Каждый дискриминатор имеет одну функцию потерь.

Целью генератора является формирование преобразованных изображений, которые дискриминаторы классифицируют как реальные. Потери генератора представляют собой взвешенную сумму пяти типов потерь: потери самовосстановления, потери согласованности цикла, скрытые потери KL, потери KL цикла и потери противника.

Укажите весовые коэффициенты для различных потерь.

lossWeights.selfReconLossWeight = 10;
lossWeights.hiddenKLLossWeight = 0.01;
lossWeights.cycleConsisLossWeight = 10;
lossWeights.cycleHiddenKLLossWeight = 0.01;
lossWeights.advLossWeight = 1;
lossWeights.discLossWeight = 0.5;

Укажите параметры обучения

Укажите параметры оптимизации Adam. Тренировать сеть в течение 35 эпох. Укажите идентичные опции для сетей генератора и дискриминатора.

  • Укажите равную скорость обучения 0,0001.

  • Инициализируйте среднюю градиентную скорость и среднюю градиентно-квадратную скорость затухания с помощью [].

  • Используйте градиентный коэффициент затухания 0,5 и квадрат градиентного коэффициента затухания 0,999.

  • Используйте регуляризацию распада веса с коэффициентом 0,0001.

  • Для обучения используйте мини-пакет размером 1.

learnRate = 0.0001;
gradDecay = 0.5;
sqGradDecay = 0.999;
weightDecay = 0.0001;

genAvgGradient = [];
genAvgGradientSq = [];

discDayAvgGradient = [];
discDayAvgGradientSq = [];

discDuskAvgGradient = [];
discDuskAvgGradientSq = [];

miniBatchSize = 1;
numEpochs = 35;

Данные пакетного обучения

Создать minibatchqueue Объект (Deep Learning Toolbox), который управляет мини-пакетами наблюдений в пользовательском цикле обучения. minibatchqueue объект также передает данные в dlarray Объект (Deep Learning Toolbox), обеспечивающий автоматическую дифференциацию в приложениях для глубокого обучения.

Укажите формат извлечения данных мини-партии как SSCB (пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground аргумент name-value как логическое значение, возвращенное canUseGPU. Если поддерживаемый графический процессор доступен для вычисления, то minibatchqueue объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ...
    "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);

Обучение сети

По умолчанию в примере загружается предварительно подготовленная версия генератора UNIT для набора данных CamVid с помощью вспомогательной функции. downloadTrainedDayDuskGeneratorNet. Вспомогательная функция прикрепляется к примеру как вспомогательный файл. Предварительно обученная сеть позволяет выполнять весь пример без ожидания завершения обучения.

Для обучения сети установите doTraining переменная в следующем коде true. Обучение модели в индивидуальном цикле обучения. Для каждой итерации:

  • Считывание данных для текущего мини-пакета с помощью next (Deep Learning Toolbox).

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) и modelGradientsDisc и modelGradientGen вспомогательные функции.

  • Обновление параметров сети с помощью adamupdate (Deep Learning Toolbox).

  • Отображение входных и преобразованных изображений для исходного и целевого доменов после каждой эпохи.

Обучение на GPU, если он доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Дополнительные сведения см. в разделе Поддержка графического процессора по выпуску (Панель инструментов параллельных вычислений). Обучение занимает около 88 часов на устройстве NVIDIA Titan RTX.

doTraining = false;
if doTraining
    % Create a figure to show the results
    figure("Units","Normalized");
    for iPlot = 1:4
        ax(iPlot) = subplot(2,2,iPlot);
    end
    
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Shuffle data every epoch
        reset(mbqDayTrain);
        shuffle(mbqDayTrain);
        reset(mbqDuskTrain);
        shuffle(mbqDuskTrain);
        
        % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed
        while hasdata(mbqDayTrain)
            iteration = iteration + 1;
            
            % Read data from the day domain
            imDay = next(mbqDayTrain); 
             
            % Read data from the dusk domain
            if hasdata(mbqDuskTrain) == 0
                reset(mbqDuskTrain);
                shuffle(mbqDuskTrain);
            end
            imDusk = next(mbqDuskTrain);
    
            % Calculate discriminator gradients and losses
            [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ...
                gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight);
            
            % Apply weight decay regularization on day discriminator gradients
            discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables);
            
            % Update parameters of day discriminator
            [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ...
                discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);  
            
            % Apply weight decay regularization on dusk discriminator gradients
            discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables);
            
            % Update parameters of dusk discriminator
            [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ...
                discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
            
            % Calculate generator gradient and loss
            [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights);
            
            % Apply weight decay regularization on generator gradients
            genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables);
            
            % Update parameters of generator
            [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ...
                genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay);
        end
        
        % Display the results
        updateTrainingPlotDayToDusk(ax,images{:});
    end
    
    % Save the trained network
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen');
    
else    
    net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip';
    downloadTrainedDayDuskGeneratorNet(net_url,dataDir);
    load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat'));
end

Вычислить преобразование «источник-цель»

Преобразование «источник-цель» использует генератор UNIT для генерации изображения в целевом (сумерках) домене из изображения в исходном (дневном) домене.

Считывание изображения из хранилища данных дневных тестовых изображений.

idxToTest = 1;
dayTestImage = readimage(imdsDayTest,idxToTest);

Преобразование изображения в тип данных single и нормализовать изображение к диапазону [-1, 1].

dayTestImage = im2single(dayTestImage);
dayTestImage = (dayTestImage-0.5)/0.5;

Создать dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для вычисления, выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.

dlDayImage = dlarray(dayTestImage,'SSCB');    
if canUseGPU
    dlDayImage = gpuArray(dlDayImage);
end

Перевести изображение дня ввода в область сумерков с помощью unitPredict функция.

dlDayToDuskImage = unitPredict(gen,dlDayImage);
dayToDuskImage = extractdata(gather(dlDayToDuskImage));

Конечный слой генераторной сети создает активизации в диапазоне [-1, 1]. Для отображения измените масштаб активизаций на диапазон [0, 1]. Кроме того, перед отображением измените масштаб входного дневного изображения.

dayToDuskImage = rescale(dayToDuskImage);
dayTestImage = rescale(dayTestImage);

Отображение изображения дня ввода и его переведенной версии в сумерках в монтаже.

figure
montage({dayTestImage dayToDuskImage})
title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])

Оценка преобразования «цель-источник»

При преобразовании целевого изображения в исходное используется генератор UNIT для генерации изображения в исходном (дневном) домене из изображения в целевом (в сумерках) домене.

Считывание изображения из хранилища данных тестовых изображений в сумерках.

idxToTest = 1;
duskTestImage = readimage(imdsDuskTest,idxToTest);

Преобразование изображения в тип данных single и нормализовать изображение к диапазону [-1, 1].

duskTestImage = im2single(duskTestImage);
duskTestImage = (duskTestImage-0.5)/0.5;

Создать dlarray объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для вычисления, выполните вывод на графическом процессоре путем преобразования данных в gpuArray объект.

dlDuskImage = dlarray(duskTestImage,'SSCB');    
if canUseGPU
    dlDuskImage = gpuArray(dlDuskImage);
end

Перевести входное изображение в сумерках в дневной домен с помощью unitPredict функция.

dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource");
duskToDayImage = extractdata(gather(dlDuskToDayImage));

Для отображения измените масштаб активизаций на диапазон [0, 1]. Также перед отображением измените масштаб входного изображения в сумерках.

duskToDayImage = rescale(duskToDayImage);
duskTestImage = rescale(duskTestImage);

Отображение входного изображения в сумерках и его переведенной дневной версии в монтаже.

montage({duskTestImage duskToDayImage})
title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])

Вспомогательные функции

Функции градиентов модели

modelGradientDisc вспомогательная функция вычисляет градиенты и потери для двух дискриминаторов.

function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ...
    discA,discB,ImageA,ImageB,discLossWeight)

    [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB);
    
    % Calculate loss of the discriminator for X_A
    outA = forward(discA,ImageA); 
    outfA = forward(discA,fakeA);
    discALoss = discLossWeight*computeDiscLoss(outA,outfA);
    
    % Update parameters of the discriminator for X
    discAGrads = dlgradient(discALoss,discA.Learnables); 
    
    % Calculate loss of the discriminator for X_B
    outB = forward(discB,ImageB); 
    outfB = forward(discB,fakeB);
    discBLoss = discLossWeight*computeDiscLoss(outB,outfB);
    
    % Update parameters of the discriminator for Y
    discBGrads = dlgradient(discBLoss,discB.Learnables);
    
    % Convert the data type from dlarray to single
    discALoss = extractdata(discALoss);
    discBLoss = extractdata(discBLoss);
end

modelGradientGen вспомогательная функция вычисляет градиенты и потери для генератора.

function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights)
    
    [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB);
    hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock');
    
    [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB);
    cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock');
    
    % Calculate different losses
    selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB);
    hiddenKLLoss = computeKLLoss(hidden);
    cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB);
    cycleHiddenKLLoss = computeKLLoss(cycle_hidden);
    
    outA = forward(discA,ImageBA);
    outB = forward(discB,ImageAB);
    advLoss = computeAdvLoss(outA) + computeAdvLoss(outB);
    
    % Calculate the total loss of generator as a weighted sum of five
    % losses
    genTotalLoss = ...
        selfReconLoss*lossWeights.selfReconLossWeight + ...
        hiddenKLLoss*lossWeights.hiddenKLLossWeight + ...
        cycleReconLoss*lossWeights.cycleConsisLossWeight + ...
        cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ...
        advLoss*lossWeights.advLossWeight;
    
    % Update the parameters of generator
    genGrad = dlgradient(genTotalLoss,gen.Learnables); 
    
    % Convert the data type from dlarray to single
    genLoss = extractdata(genTotalLoss);
    images = {ImageA,ImageAB,ImageB,ImageBA};
end

Функции потерь

computeDiscLoss вспомогательная функция вычисляет потерю дискриминатора. Каждая потеря дискриминатора представляет собой сумму двух компонентов:

  • Квадрат разности между вектором единиц и предсказаниями дискриминатора на реальных изображениях, Yreal

  • Квадрат разности между вектором нулей и предсказаниями дискриминатора на сгенерированных изображениях, Yˆtranslated

дискриминаторПотеря = (1-Yreal) 2 + (0-Yˆtranslated) 2

function discLoss = computeDiscLoss(Yreal,Ytranslated)
    discLoss = mean(((1-Yreal).^2),"all") + ...
               mean(((0-Ytranslated).^2),"all");
end

computeAdvLoss вспомогательная функция вычисляет состязательные потери для генератора. Состязательная потеря - это квадрат разности между вектором единиц и предсказаниями дискриминатора на преобразованном изображении.

adversarityLoss = (1-Yˆtranslated) 2

function advLoss = computeAdvLoss(Ytranslated)
    advLoss = mean(((Ytranslated-1).^2),"all");
end

computeReconLoss вспомогательная функция вычисляет потери самовосстановления и потери последовательности циклов для генератора. Потеря самовосстановления - это L1 расстояние между входными изображениями и их самовосстановленными версиями. Потеря согласованности циклов - это L1 расстояние между входными изображениями и их восстановленными версиями.

selfLoss = (Yreal-Yself-реконструированный) ‖ 1

cycleConsistencyLoss = (Yreal-Ycycle - реконструирован) ‖ 1

function reconLoss = computeReconLoss(Yreal,Yrecon)
    reconLoss = mean(abs(Yreal-Yrecon),"all");
end

computeKLLoss вспомогательная функция вычисляет скрытые потери KL и циклические потери KL для генератора. Скрытая потеря KL - это квадрат разности между вектором нулей и 'encoderSharedBlock"активация для потока самовосстановления. Скрытая в цикле потеря KL - это квадрат разности между вектором нулей и 'encoderSharedBlock"активация для потока реконструкции цикла.

hiddKLLoss = (0-YencoderSharedBlockActivation) 2

cycleHiddenKLLoss = (0-YencoderSharedBlockActivation) 2

function klLoss = computeKLLoss(hidden)
    klLoss = mean(abs(hidden.^2),"all");
end

Ссылки

[1] Лю, Мин-Ю, Томас Бреуэль и Ян Каутц, «Неподконтрольные сети перевода изображений на изображения». In Advances in Neural Information Processing Systems, 2017. https://arxiv.org/abs/1703.00848.

[2] Бростоу, Габриэль Дж., Жюльен Фокер и Роберто Чиполла. «Классы семантических объектов в видео: база данных истинности земли высокой четкости». Буквы распознавания образов. Том 30, выпуск 2, 2009, стр. 88-97.

См. также

| | | | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения)

Связанные темы