Этот пример показывает, как выполнить преобразование области между изображениями, полученными в дневные и сумеречные условия, используя неконтролируемую сеть преобразования изображения в изображение (UNIT).
Перевод домена - это задача переноса стилей и характеристик из одной области изображений в другую. Этот метод может быть распространен на другие операции обучения между изображениями, такие как улучшение изображения, колоризация изображения, генерация дефектов и анализ медицинских изображений.
UNIT [1] является типом генеративной состязательной сети (GAN), которая состоит из одной сети генератора и двух сетей дискриминатора, которые вы обучаете одновременно, чтобы максимизировать общую эффективность. Для получения дополнительной информации о МОДУЛЕ смотрите Запуск с сетями GAN для преобразования изображения в изображение ( Image Processing Toolbox).
Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных представляет собой набор 701 изображений, содержащих представления уличного уровня, полученные во время вождения.
Загрузите набор данных CamVid с этого URL-адреса. Время загрузки зависит от вашего подключения к Интернету.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidImageData(dataDir,imageURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
Набор данных изображений CamVid включает 497 изображений, полученных в дневное время, и 124 изображения, полученные в сумерках. Эффективность обученной сети UNIT ограничена, потому что количество обучающих изображений CamVid относительно мало, что ограничивает эффективность обученной сети. Кроме того, некоторые изображения относятся к последовательности изображений и поэтому коррелируют с другими изображениями в наборе данных. Чтобы минимизировать влияние этих ограничений, этот пример вручную разделяет данные на наборы обучающих и тестовых данных таким образом, чтобы максимизировать изменчивость обучающих данных.
Получите имена файлов изображений day и dusk для обучения и проверки путем загрузки файла 'camvidDayDuskDatasetFileNames.mat'. Обучающие данные состоят из 263 дневных изображений и 107 сумеречных изображений. Наборы тестовых данных состоят из 234 дневных изображений и 17 сумеречных изображений.
load('camvidDayDuskDatasetFileNames.mat');
Создание imageDatastore
объекты, которые управляют изображениями дня и сумерек для обучения и проверки.
imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames)); imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames)); imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames)); imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));
Предварительный просмотр обучающего изображения из наборов дневных и сумеречных обучающих данных.
day = preview(imdsDayTrain); dusk = preview(imdsDuskTrain); montage({day,dusk})
Задайте размер входа изображения для исходного и целевого изображений.
inputSize = [256,256,3];
Увеличение и предварительная обработка обучающих данных с помощью transform
функция с пользовательскими операциями предварительной обработки, заданными функцией helper augmentDataForDayToDusk
. Эта функция присоединена к примеру как вспомогательный файл.
The augmentDataForDayToDusk
функция выполняет следующие операции:
Измените размер изображения на заданный размер входа с помощью бикубической интерполяции.
Случайным образом разверните изображение в горизонтальном направлении.
Масштабируйте изображение до области значений [-1, 1]. Эта область значений соответствует области значений конечных tanhLayer
используется в генераторе.
imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize)); imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));
Создайте сеть генератора UNIT с помощью unitGenerator
(Image Processing Toolbox) функция. Каждая из секций источника и целевого энкодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков. Разделы энкодера совместно используют два из пяти остаточных блоков. Точно так же каждая секция исходного и целевого декодера генератора состоит из двух блоков понижающей дискретизации и пяти остаточных блоков, и секции декодера совместно используют два из пяти остаточных блоков.
gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);
Визуализируйте сеть генератора.
analyzeNetwork(gen)
Существует две сети дискриминаторов, по одной для каждого из областей изображений (день и сумерки). Создайте дискриминаторы для исходных и целевых областей, используя patchGANDiscriminator
(Image Processing Toolbox) функция.
discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none"); discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
Визуализируйте сети дискриминатора.
analyzeNetwork(discDay); analyzeNetwork(discDusk);
The modelGradientsDisc
и modelGradientGen
вспомогательные функции вычисляют градиенты и потери для дискриминаторов и генератора, соответственно. Эти функции определены в разделе Вспомогательные функции этого примера.
Цель каждого дискриминатора состоит в том, чтобы правильно различать реальные изображения (1) и переведенные изображения (0) для изображений в своей области. Каждый дискриминатор имеет одну функцию потерь.
Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые дискриминаторы классифицируют как действительные. Потеря генератора - взвешенная сумма пяти типов потерь: потеря самореконструкции, потеря последовательности цикла, скрытая потеря KL, цикл скрытая потеря KL и adverserial потеря.
Укажите весовые коэффициенты для различных потерь.
lossWeights.selfReconLossWeight = 10; lossWeights.hiddenKLLossWeight = 0.01; lossWeights.cycleConsisLossWeight = 10; lossWeights.cycleHiddenKLLossWeight = 0.01; lossWeights.advLossWeight = 1; lossWeights.discLossWeight = 0.5;
Задайте опции для оптимизации Adam. Обучите сеть 35 эпох. Задайте одинаковые опции для сетей генератора и дискриминатора.
Задайте равную скорость обучения 0,0001.
Инициализируйте конечный средний градиент и конечный средний градиент-квадратные скорости распада с []
.
Используйте коэффициент градиентного распада 0,5 и квадратный коэффициент градиентного распада 0,999.
Используйте регуляризацию распада веса с коэффициентом 0,0001.
Используйте мини-пакет размером 1 для обучения.
learnRate = 0.0001; gradDecay = 0.5; sqGradDecay = 0.999; weightDecay = 0.0001; genAvgGradient = []; genAvgGradientSq = []; discDayAvgGradient = []; discDayAvgGradientSq = []; discDuskAvgGradient = []; discDuskAvgGradientSq = []; miniBatchSize = 1; numEpochs = 35;
Создайте minibatchqueue
объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue
объект также переводит данные в dlarray
объект, который позволяет проводить автоматическую дифференциацию в применениях глубокого обучения.
Задайте формат извлечения данных пакета следующим SSCB
(пространственный, пространственный, канальный, пакетный). Установите DispatchInBackground
аргумент имя-значение как логический, возвращенный canUseGPU
. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue
объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU); mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию пример загружает предварительно обученную версию генератора UNIT для набора данных CamVid с помощью функции helper downloadTrainedDayDuskGeneratorNet
. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде, для true
. Обучите модель в пользовательском цикле обучения. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
функция.
Оцените градиенты модели с помощью dlfeval
функции и modelGradientsDisc
и modelGradientGen
вспомогательные функции.
Обновляйте параметры сети с помощью adamupdate
функция.
Отображение входных и переведенных изображений для исходных и целевых областей после каждой эпохи.
Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Обучение занимает около 88 часов на NVIDIA Titan RTX.
doTraining = false; if doTraining % Create a figure to show the results figure("Units","Normalized"); for iPlot = 1:4 ax(iPlot) = subplot(2,2,iPlot); end iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Shuffle data every epoch reset(mbqDayTrain); shuffle(mbqDayTrain); reset(mbqDuskTrain); shuffle(mbqDuskTrain); % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed while hasdata(mbqDayTrain) iteration = iteration + 1; % Read data from the day domain imDay = next(mbqDayTrain); % Read data from the dusk domain if hasdata(mbqDuskTrain) == 0 reset(mbqDuskTrain); shuffle(mbqDuskTrain); end imDusk = next(mbqDuskTrain); % Calculate discriminator gradients and losses [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ... gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight); % Apply weight decay regularization on day discriminator gradients discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables); % Update parameters of day discriminator [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ... discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Apply weight decay regularization on dusk discriminator gradients discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables); % Update parameters of dusk discriminator [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ... discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Calculate generator gradient and loss [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights); % Apply weight decay regularization on generator gradients genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables); % Update parameters of generator [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ... genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); end % Display the results updateTrainingPlotDayToDusk(ax,images{:}); end % Save the trained network modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen'); else net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip'; downloadTrainedDayDuskGeneratorNet(net_url,dataDir); load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat')); end
Преобразование изображений от источника к цели использует генератор UNIT, чтобы сгенерировать изображение в целевой области (сумерек) из изображения в исходной области (день).
Считайте изображение из тестовых изображений datastore дня.
idxToTest = 1; dayTestImage = readimage(imdsDayTest,idxToTest);
Преобразуйте изображение в тип данных single
и нормализуйте изображение в области значений [-1, 1].
dayTestImage = im2single(dayTestImage); dayTestImage = (dayTestImage-0.5)/0.5;
Создайте dlarray
объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray
объект.
dlDayImage = dlarray(dayTestImage,'SSCB'); if canUseGPU dlDayImage = gpuArray(dlDayImage); end
Переведите изображение входа дня в домен сумерек с помощью unitPredict
(Image Processing Toolbox) функция.
dlDayToDuskImage = unitPredict(gen,dlDayImage); dayToDuskImage = extractdata(gather(dlDayToDuskImage));
Конечный слой сети генератора производит активации в области значений [-1, 1]. Для отображения измените значения активации на область значений [0, 1]. Также перепрофилируйте изображение за вход день до отображения.
dayToDuskImage = rescale(dayToDuskImage); dayTestImage = rescale(dayTestImage);
Отобразите изображение входа дня и его переведенную версию сумерека в монтаже.
figure montage({dayTestImage dayToDuskImage}) title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])
Преобразование изображения «цель-источник» использует генератор UNIT, чтобы сгенерировать изображение в исходной (дневной) области из изображения в целевой (сумеречной) области.
Считайте изображение из datastore тестовых изображений сумерек.
idxToTest = 1; duskTestImage = readimage(imdsDuskTest,idxToTest);
Преобразуйте изображение в тип данных single
и нормализуйте изображение в области значений [-1, 1].
duskTestImage = im2single(duskTestImage); duskTestImage = (duskTestImage-0.5)/0.5;
Создайте dlarray
объект, который вводит данные в генератор. Если поддерживаемый графический процессор доступен для расчетов, выполните вывод на графическом процессоре, преобразовав данные в gpuArray
объект.
dlDuskImage = dlarray(duskTestImage,'SSCB'); if canUseGPU dlDuskImage = gpuArray(dlDuskImage); end
Переведите входное изображение сумерек в дневную область с помощью unitPredict
(Image Processing Toolbox) функция.
dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource"); duskToDayImage = extractdata(gather(dlDuskToDayImage));
Для отображения измените значения активации на область значений [0, 1]. Кроме того, перепрофилируйте входное изображение сумматора перед отображением.
duskToDayImage = rescale(duskToDayImage); duskTestImage = rescale(duskTestImage);
Отобразите входное изображение сумерек и его переведенную дневную версию в монтаже.
montage({duskTestImage duskToDayImage}) title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])
The modelGradientDisc
Функция helper вычисляет градиенты и потери для двух дискриминаторов.
function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ... discA,discB,ImageA,ImageB,discLossWeight) [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB); % Calculate loss of the discriminator for X_A outA = forward(discA,ImageA); outfA = forward(discA,fakeA); discALoss = discLossWeight*computeDiscLoss(outA,outfA); % Update parameters of the discriminator for X discAGrads = dlgradient(discALoss,discA.Learnables); % Calculate loss of the discriminator for X_B outB = forward(discB,ImageB); outfB = forward(discB,fakeB); discBLoss = discLossWeight*computeDiscLoss(outB,outfB); % Update parameters of the discriminator for Y discBGrads = dlgradient(discBLoss,discB.Learnables); % Convert the data type from dlarray to single discALoss = extractdata(discALoss); discBLoss = extractdata(discBLoss); end
The modelGradientGen
Функция helper вычисляет градиенты и потери для генератора.
function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights) [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB); hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock'); [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB); cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock'); % Calculate different losses selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB); hiddenKLLoss = computeKLLoss(hidden); cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB); cycleHiddenKLLoss = computeKLLoss(cycle_hidden); outA = forward(discA,ImageBA); outB = forward(discB,ImageAB); advLoss = computeAdvLoss(outA) + computeAdvLoss(outB); % Calculate the total loss of generator as a weighted sum of five % losses genTotalLoss = ... selfReconLoss*lossWeights.selfReconLossWeight + ... hiddenKLLoss*lossWeights.hiddenKLLossWeight + ... cycleReconLoss*lossWeights.cycleConsisLossWeight + ... cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ... advLoss*lossWeights.advLossWeight; % Update the parameters of generator genGrad = dlgradient(genTotalLoss,gen.Learnables); % Convert the data type from dlarray to single genLoss = extractdata(genTotalLoss); images = {ImageA,ImageAB,ImageB,ImageBA}; end
The computeDiscLoss
Функция helper вычисляет потери дискриминатора. Каждая потеря дискриминатора является суммой двух компонентов:
Квадратное различие между вектором таковых и предсказаниями дискриминатора на реальных изображениях,
Квадратное различие между нулевым вектором и предсказаниями дискриминатора на сгенерированных изображениях,
function discLoss = computeDiscLoss(Yreal,Ytranslated) discLoss = mean(((1-Yreal).^2),"all") + ... mean(((0-Ytranslated).^2),"all"); end
The computeAdvLoss
вспомогательная функция вычисляет состязательные потери для генератора. Состязательные потери - это квадратное различие между вектором таковых и предсказаниями дискриминатора на переведенном изображении.
function advLoss = computeAdvLoss(Ytranslated) advLoss = mean(((Ytranslated-1).^2),"all"); end
The computeReconLoss
Функция helper вычисляет потери самовосстановления и потери согласованности цикла для генератора. Потеря самовосстановления расстояние между входными изображениями и их самовосстановленными вариантами. Потеря непротиворечивости цикла расстояние между входными изображениями и их восстановленными циклом вариантами.
function reconLoss = computeReconLoss(Yreal,Yrecon) reconLoss = mean(abs(Yreal-Yrecon),"all"); end
The computeKLLoss
Функция helper вычисляет скрытые потери KL и скрытые циклом потери KL для генератора. Скрытые потери KL - это квадратное различие между нулевым вектором и 'encoderSharedBlock
'активация для потока самовосстановления. Скрытые циклом KL-потери - это квадратное различие между нулевым вектором и 'encoderSharedBlock
'активация для потока реконструкции цикла.
function klLoss = computeKLLoss(hidden) klLoss = mean(abs(hidden.^2),"all"); end
[1] Liu, Ming-Yu, Thomas Breuel, and Jan Kautz, «Unsupervised image-to-image translation networks». В Усовершенствования в системах нейронной обработки информации, 2017. https://arxiv.org/abs/1703.00848.
[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Распознавание Букв. Том 30, Выпуск 2, 2009, стр. 88-97.
adamupdate
| dlarray
| dlfeval
| minibatchqueue
| transform
| patchGANDiscriminator
(Image Processing Toolbox) | unitGenerator
(Набор Image Processing Toolbox) | unitPredict
(Набор Image Processing Toolbox)