В этом примере показано, как выполнить доменный перевод между изображениями, полученными во время дневного времени и условий сумрака с помощью безнадзорной сети перевода от изображения к изображению (МОДУЛЬ).
Доменный перевод является задачей передачи стилей и характеристик от одной области изображения до другого. Этот метод может быть расширен к другим операциям изучения от изображения к изображению, таким как повышение качества изображения, колоризация изображений, дефектная генерация и медицинский анализ изображения.
МОДУЛЬ [1] является типом порождающей соперничающей сети (GAN), которая состоит из одной сети генератора и двух сетей различителя, которые вы обучаете одновременно, чтобы максимизировать общую производительность. Для получения дополнительной информации о МОДУЛЕ, смотрите Начало работы с GANs для Перевода От изображения к изображению.
Этот пример использует Набор данных CamVid [2] из Кембриджского университета для обучения. Этот набор данных является набором 701 изображения, содержащего представления уличного уровня, полученные при управлении.
Загрузите набор данных CamVid с этого URL. Время загрузки зависит от вашего интернет-соединения.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; dataDir = fullfile(tempdir,'CamVid'); downloadCamVidImageData(dataDir,imageURL); imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
Набор данных изображения CamVid включает 497 изображений, полученных в дневное время и 124 изображения, полученные в сумраке. Эффективность обученной сети UNIT ограничивается, потому что количество изображений обучения CamVid относительно мало, который ограничивает эффективность обучившего сеть. Далее, некоторые изображения принадлежат последовательности изображений и поэтому коррелируются с другими изображениями в наборе данных. Чтобы минимизировать удар этих ограничений, этот пример вручную делит данные в обучение и наборы тестовых данных способом, который максимизирует изменчивость обучающих данных.
Получите имена файлов дня и изображений сумрака для обучения и тестирования путем загрузки файла 'camvidDayDuskDatasetFileNames.mat'. Обучающие наборы данных состоят из 263-дневных изображений и 107 изображений сумрака. Наборы тестовых данных состоят из 234-дневных изображений и 17 изображений сумрака.
load('camvidDayDuskDatasetFileNames.mat');
Создайте imageDatastore
объекты, которые управляют днем и изображениями сумрака для обучения и тестирования.
imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames)); imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames)); imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames)); imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));
Предварительно просмотрите учебное изображение с обучающих наборов данных сумрака и дня.
day = preview(imdsDayTrain); dusk = preview(imdsDuskTrain); montage({day,dusk})
Задайте входной размер изображений для входных и выходных изображений.
inputSize = [256,256,3];
Увеличьте и предварительно обработайте обучающие данные при помощи transform
функция с пользовательскими операциями предварительной обработки, заданными помощником, функционирует augmentDataForDayToDusk
Эта функция присоединена к примеру как вспомогательный файл.
augmentDataForDayToDusk
функция выполняет эти операции:
Измените размер изображения к заданному входному размеру с помощью бикубической интерполяции.
Случайным образом инвертируйте изображение в горизонтальном направлении.
Масштабируйте изображение к области значений [-1, 1]. Эта область значений совпадает с областью значений итогового tanhLayer
(Deep Learning Toolbox) используется в генераторе.
imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize)); imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));
Создайте МОДУЛЬНУЮ сеть генератора использование unitGenerator
функция. Входные и выходные разделы энкодера генератора каждый состоит из двух блоков субдискретизации и пяти остаточных блоков. Разделы энкодера совместно используют два из пяти остаточных блоков. Аналогично, входные и выходные разделы декодера генератора, каждый состоит из двух блоков субдискретизации и пяти остаточных блоков и разделов декодера, совместно используют два из пяти остаточных блоков.
gen = unitGenerator(inputSize,'NumResidualBlocks',5,'NumSharedBlocks',2);
Визуализируйте сеть генератора.
analyzeNetwork(gen)
Существует две сети различителя, один для каждой из областей изображения (день и сумрак). Создайте различители для входных и выходных областей с помощью patchGANDiscriminator
функция.
discDay = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none"); discDusk = patchGANDiscriminator(inputSize,"NumDownsamplingBlocks",4,"FilterSize",3, ... "ConvolutionWeightsInitializer","narrow-normal","NormalizationLayer","none");
Визуализируйте сети различителя.
analyzeNetwork(discDay); analyzeNetwork(discDusk);
modelGradientsDisc
и modelGradientGen
функции помощника вычисляют градиенты и потери для различителей и генератора, соответственно. Эти функции заданы в разделе Supporting Functions этого примера.
Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь.
Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой пяти типов потерь: потеря самореконструкции, потеря непротиворечивости цикла, скрытая потеря KL, цикл скрытая потеря KL и adverserial потеря.
Задайте весовые коэффициенты за различные потери.
lossWeights.selfReconLossWeight = 10; lossWeights.hiddenKLLossWeight = 0.01; lossWeights.cycleConsisLossWeight = 10; lossWeights.cycleHiddenKLLossWeight = 0.01; lossWeights.advLossWeight = 1; lossWeights.discLossWeight = 0.5;
Задайте опции для оптимизации Адама. Обучите сеть в течение 35 эпох. Задайте идентичные опции для сетей различителя и генератора.
Задайте равную скорость обучения 0,0001.
Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с []
.
Используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.
Используйте регуляризацию затухания веса с фактором 0,0001.
Используйте мини-пакетный размер 1 для обучения.
learnRate = 0.0001; gradDecay = 0.5; sqGradDecay = 0.999; weightDecay = 0.0001; genAvgGradient = []; genAvgGradientSq = []; discDayAvgGradient = []; discDayAvgGradientSq = []; discDuskAvgGradient = []; discDuskAvgGradientSq = []; miniBatchSize = 1; numEpochs = 35;
Создайте minibatchqueue
Объект (Deep Learning Toolbox), который справляется с мини-пакетной обработкой наблюдений в пользовательском учебном цикле. minibatchqueue
возразите также бросает данные к dlarray
Объект (Deep Learning Toolbox), который включает автоматическое дифференцирование в применении глубокого обучения.
Задайте мини-пакетный формат экстракции данных как SSCB
(пространственный, пространственный, канал, пакет). Установите DispatchInBackground
аргумент значения имени как булевская переменная, возвращенная canUseGPU
. Если поддерживаемый графический процессор доступен для расчета, то minibatchqueue
объект предварительно обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqDayTrain = minibatchqueue(imdsDayTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU); mbqDuskTrain = minibatchqueue(imdsDuskTrain,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","DispatchInBackground",canUseGPU);
По умолчанию пример загружает предварительно обученную версию МОДУЛЬНОГО генератора для набора данных CamVid при помощи функции помощника downloadTrainedDayDuskGeneratorNet
. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде к true
. Обучите модель в пользовательском учебном цикле. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функция и modelGradientsDisc
и modelGradientGen
функции помощника.
Обновите сетевые параметры с помощью adamupdate
(Deep Learning Toolbox) функция.
Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 88 часов на Титане NVIDIA RTX.
doTraining = false; if doTraining % Create a figure to show the results figure("Units","Normalized"); for iPlot = 1:4 ax(iPlot) = subplot(2,2,iPlot); end iteration = 0; % Loop over epochs for epoch = 1:numEpochs % Shuffle data every epoch reset(mbqDayTrain); shuffle(mbqDayTrain); reset(mbqDuskTrain); shuffle(mbqDuskTrain); % Run the loop until all the images in the mini-batch queue mbqDayTrain are processed while hasdata(mbqDayTrain) iteration = iteration + 1; % Read data from the day domain imDay = next(mbqDayTrain); % Read data from the dusk domain if hasdata(mbqDuskTrain) == 0 reset(mbqDuskTrain); shuffle(mbqDuskTrain); end imDusk = next(mbqDuskTrain); % Calculate discriminator gradients and losses [discDayGrads,discDuskGrads,discDayLoss,disDuskLoss] = dlfeval(@modelGradientDisc, ... gen,discDay,discDusk,imDay,imDusk,lossWeights.discLossWeight); % Apply weight decay regularization on day discriminator gradients discDayGrads = dlupdate(@(g,w) g+weightDecay*w,discDayGrads,discDay.Learnables); % Update parameters of day discriminator [discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay,discDayGrads, ... discDayAvgGradient,discDayAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Apply weight decay regularization on dusk discriminator gradients discDuskGrads = dlupdate(@(g,w) g+weightDecay*w,discDuskGrads,discDusk.Learnables); % Update parameters of dusk discriminator [discDusk,discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk,discDuskGrads, ... discDuskAvgGradient,discDuskAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); % Calculate generator gradient and loss [genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay,discDusk,imDay,imDusk,lossWeights); % Apply weight decay regularization on generator gradients genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables); % Update parameters of generator [gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient, ... genAvgGradientSq,iteration,learnRate,gradDecay,sqGradDecay); end % Display the results updateTrainingPlotDayToDusk(ax,images{:}); end % Save the trained network modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(strcat("trainedDayDuskUNITGeneratorNet-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'gen'); else net_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip'; downloadTrainedDayDuskGeneratorNet(net_url,dataDir); load(fullfile(dataDir,'trainedDayDuskUNITGeneratorNet.mat')); end
Перевод источника к целевому изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в цели (сумрак) область от изображения в источнике (день) область.
Считайте изображение из datastore дневных тестовых изображений.
idxToTest = 1; dayTestImage = readimage(imdsDayTest,idxToTest);
Преобразуйте изображение в тип данных single
и нормируйте изображение к области значений [-1, 1].
dayTestImage = im2single(dayTestImage); dayTestImage = (dayTestImage-0.5)/0.5;
Создайте dlarray
возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray
объект.
dlDayImage = dlarray(dayTestImage,'SSCB'); if canUseGPU dlDayImage = gpuArray(dlDayImage); end
Переведите входное дневное изображение в область сумрака использование unitPredict
функция.
dlDayToDuskImage = unitPredict(gen,dlDayImage); dayToDuskImage = extractdata(gather(dlDayToDuskImage));
Последний слой сети генератора производит активации в области значений [-1, 1]. Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное дневное изображение перед отображением.
dayToDuskImage = rescale(dayToDuskImage); dayTestImage = rescale(dayTestImage);
Отобразите входное дневное изображение и его переведенную версию сумрака в монтаже.
figure montage({dayTestImage dayToDuskImage}) title(['Day Test Image ',num2str(idxToTest),' with Translated Dusk Image'])
Перевод цели к исходному изображению использует МОДУЛЬНЫЙ генератор, чтобы сгенерировать изображение в источнике (день) область от изображения в цели (сумрак) область.
Считайте изображение из datastore тестовых изображений сумрака.
idxToTest = 1; duskTestImage = readimage(imdsDuskTest,idxToTest);
Преобразуйте изображение в тип данных single
и нормируйте изображение к области значений [-1, 1].
duskTestImage = im2single(duskTestImage); duskTestImage = (duskTestImage-0.5)/0.5;
Создайте dlarray
возразите что входные данные против генератора. Если поддерживаемый графический процессор доступен для расчета, то выполните вывод на графическом процессоре путем преобразования данных в gpuArray
объект.
dlDuskImage = dlarray(duskTestImage,'SSCB'); if canUseGPU dlDuskImage = gpuArray(dlDuskImage); end
Переведите входное изображение сумрака в дневную область использование unitPredict
функция.
dlDuskToDayImage = unitPredict(gen,dlDuskImage,"OutputType","TargetToSource"); duskToDayImage = extractdata(gather(dlDuskToDayImage));
Для отображения перемасштабируйте активации к области значений [0, 1]. Кроме того, перемасштабируйте входное изображение сумрака перед отображением.
duskToDayImage = rescale(duskToDayImage); duskTestImage = rescale(duskTestImage);
Отобразите входное изображение сумрака и его переведенную дневную версию в монтаже.
montage({duskTestImage duskToDayImage}) title(['Test Dusk Image ',num2str(idxToTest),' with Translated Day Image'])
modelGradientDisc
функция помощника вычисляет градиенты и потерю для этих двух различителей.
function [discAGrads,discBGrads,discALoss,discBLoss] = modelGradientDisc(gen, ... discA,discB,ImageA,ImageB,discLossWeight) [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB); % Calculate loss of the discriminator for X_A outA = forward(discA,ImageA); outfA = forward(discA,fakeA); discALoss = discLossWeight*computeDiscLoss(outA,outfA); % Update parameters of the discriminator for X discAGrads = dlgradient(discALoss,discA.Learnables); % Calculate loss of the discriminator for X_B outB = forward(discB,ImageB); outfB = forward(discB,fakeB); discBLoss = discLossWeight*computeDiscLoss(outB,outfB); % Update parameters of the discriminator for Y discBGrads = dlgradient(discBLoss,discB.Learnables); % Convert the data type from dlarray to single discALoss = extractdata(discALoss); discBLoss = extractdata(discBLoss); end
modelGradientGen
функция помощника вычисляет градиенты и потерю для генератора.
function [genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights) [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB); hidden = forward(gen,ImageA,ImageB,'Outputs','encoderSharedBlock'); [~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB); cycle_hidden = forward(gen,ImageBA,ImageAB,'Outputs','encoderSharedBlock'); % Calculate different losses selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB); hiddenKLLoss = computeKLLoss(hidden); cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB); cycleHiddenKLLoss = computeKLLoss(cycle_hidden); outA = forward(discA,ImageBA); outB = forward(discB,ImageAB); advLoss = computeAdvLoss(outA) + computeAdvLoss(outB); % Calculate the total loss of generator as a weighted sum of five % losses genTotalLoss = ... selfReconLoss*lossWeights.selfReconLossWeight + ... hiddenKLLoss*lossWeights.hiddenKLLossWeight + ... cycleReconLoss*lossWeights.cycleConsisLossWeight + ... cycleHiddenKLLoss*lossWeights.cycleHiddenKLLossWeight + ... advLoss*lossWeights.advLossWeight; % Update the parameters of generator genGrad = dlgradient(genTotalLoss,gen.Learnables); % Convert the data type from dlarray to single genLoss = extractdata(genTotalLoss); images = {ImageA,ImageAB,ImageB,ImageBA}; end
computeDiscLoss
функция помощника вычисляет потерю различителя. Каждая потеря различителя является суммой двух компонентов:
Различие в квадрате между вектором из единиц и предсказаниями различителя на действительных изображениях,
Различие в квадрате между нулевым вектором и предсказаниями различителя на сгенерированных изображениях,
function discLoss = computeDiscLoss(Yreal,Ytranslated) discLoss = mean(((1-Yreal).^2),"all") + ... mean(((0-Ytranslated).^2),"all"); end
computeAdvLoss
функция помощника вычисляет соперничающую потерю для генератора. Соперничающая потеря является различием в квадрате между вектором из единиц и предсказаниями различителя на переведенном изображении.
function advLoss = computeAdvLoss(Ytranslated) advLoss = mean(((Ytranslated-1).^2),"all"); end
computeReconLoss
функция помощника вычисляет потерю самореконструкции и потерю непротиворечивости цикла для генератора. Сам потеря реконструкции расстояние между входными изображениями и их самовосстановленными версиями. Потеря непротиворечивости цикла расстояние между входными изображениями и их восстановленными циклом версиями.
function reconLoss = computeReconLoss(Yreal,Yrecon) reconLoss = mean(abs(Yreal-Yrecon),"all"); end
computeKLLoss
функция помощника вычисляет скрытую потерю KL и скрытую от цикла потерю KL для генератора. Скрытая потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock
'активация для потока самореконструкции. Скрытая от цикла потеря KL является различием в квадрате между нулевым вектором и 'encoderSharedBlock
'активация для потока реконструкции цикла.
function klLoss = computeKLLoss(hidden) klLoss = mean(abs(hidden.^2),"all"); end
[1] Лю, Мин-Юй, Томас Бреуель и Ян Коц, "Безнадзорные сети перевода от изображения к изображению". В Усовершенствованиях в Нейронных Системах обработки информации, 2017. https://arxiv.org/abs/1703.00848.
[2] Brostow, Габриэль Дж., Жюльен Фокер и Роберто Сиполья. "Семантические Классы объектов в Видео: База данных Основной истины Высокой четкости". Pattern Recognition Letters. Vol. 30, Issue 2, 2009, стр 88-97.
patchGANDiscriminator
| transform
| unitGenerator
| unitPredict
| adamupdate
(Deep Learning Toolbox) | dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | minibatchqueue
(Deep Learning Toolbox)