В этом примере показано, как преобразовать необработанные данные камеры в эстетически приятное цветное изображение с помощью U-Net.
DSLR и многие современные камеры телефона предлагают возможность сохранять данные, собранные непосредственно с датчика камеры, в качестве RAW- файла. Каждый пиксель данных RAW соответствует непосредственно количеству света, захваченному соответствующим фотосенсором камеры. Данные зависят от фиксированных характеристик оборудования камеры, таких как чувствительность к каждому фотосенсору к конкретной области значений длин волн электромагнитного спектра. Данные также зависят от настроек захвата камеры, таких как время экспозиции, и факторов сцены, таких как источник света.
Демозаицирование является единственной необходимой операцией для преобразования одноканальных данных RAW в трехканальное изображение RGB. Однако без дополнительных операций обработки изображений полученное изображение RGB имеет субъективно низкое качество зрения.
Традиционный трубопровод обработки изображений выполняет комбинацию дополнительных операций, включая шумоподавление, линеаризацию, балансировку белого, коррекцию цвета, регулировку яркости и регулировку контрастности [1]. Задача разработки трубопровода заключается в уточнении алгоритмов, чтобы оптимизировать субъективный внешний вид окончательного изображения RGB независимо от изменений в сцене и настройках сбора.
Глубокие методы глубокого обучения позволяют прямое преобразование RAW в RGB без необходимости разработки традиционного конвейера обработки. Для образца один метод компенсирует недооценку при преобразовании изображений RAW в RGB [2]. В этом примере показано, как преобразовать изображения RAW из нижней конечной камеры телефона в изображения RGB, которые аппроксимируют качество камеры DSLR более высокого уровня [3].
В этом примере используется набор данных Zurich RAW to RGB [3]. Размер набора данных составляет 22 ГБ. Набор данных содержит 48 043 пространственно зарегистрированных пар закрашенных фигур обучающих изображений RAW и RGB размера 448 на 448. Набор данных содержит два отдельных тестовых набора. Один тестовый набор состоит из 1 204 пространственно зарегистрированных пар изображений RAW и RGB закрашенных фигур размера 448 на 448. Другой тестовый набор состоит из незарегистрированных изображений RAW и RGB с полным разрешением.
Создайте директорию для хранения набора данных.
imageDir = fullfile(tempdir,'ZurichRAWToRGB'); if ~exist(imageDir,'dir') mkdir(imageDir); end
Чтобы загрузить набор данных, запросите доступ с помощью формы Zurich RAW to RGB dataset. Извлеките данные в директорию, заданную imageDir
переменная. При успешном извлечении imageDir
содержит три директории с именем full_resolution
, test
, и train
.
Создайте imageDatastore
который считывает целевые закрашенные фигуры обучающего изображения RGB, полученные с помощью DSLR Canon высшего класса.
trainImageDir = fullfile(imageDir,'train'); dsTrainRGB = imageDatastore(fullfile(trainImageDir,'canon'),'ReadSize',16);
Предварительный просмотр закрашенной фигуры обучающего изображения RGB.
groundTruthPatch = preview(dsTrainRGB); imshow(groundTruthPatch)
Создайте imageDatastore
который считывает входные обучающие закрашенные фигуры изображение, полученные с помощью камеры телефона Huawei. Изображения RAW получаются с 10-битной точностью и представлены как 8-битными, так и 16-битными файлами PNG. 8-битные файлы обеспечивают компактное представление закрашенных фигур с данными в области значений [0, 255]. Масштабирование не выполнено ни на одном из RAW- данных.
dsTrainRAW = imageDatastore(fullfile(trainImageDir,'huawei_raw'),'ReadSize',16);
Предварительный просмотр входной закрашенной фигуры обучающего изображения RAW. datastore читает эту закрашенную фигуру как 8-битный uint8
изображение, поскольку счетчики датчиков находятся в области значений [0, 255]. Чтобы симулировать 10-битную динамическую область значений обучающих данных, разделите значения интенсивности изображения на 4. Если вы увеличиваете изображение, то вы можете увидеть шаблон RGGB Bayer.
inputPatch = preview(dsTrainRAW); inputPatchRAW = inputPatch/4; imshow(inputPatchRAW)
Чтобы симулировать минимальный традиционный трубопровод обработки, демосапирируйте шаблон RGGB Bayer данных RAW с помощью demosaic
функция. Отобразите обработанное изображение и осветлите отображение. По сравнению с целевым изображением RGB минимально обработанное изображение RGB темное и имеет несбалансированные цвета и заметные программные продукты. Обученная сеть RAW-to-RGB выполняет операции предварительной обработки так, чтобы выходное изображение RGB напоминало целевое изображение.
inputPatchRGB = demosaic(inputPatch,'rggb');
imshow(rescale(inputPatchRGB))
Тестовые данные содержат изображения RAW и RGB закрашенных фигур а также полноразмерные изображения. Этот пример разделяет закрашенные фигуры тестового изображения на наборы валидации и тестовые наборы. Пример использует полноразмерные тестовые изображения только для качественной проверки. См. «Оценка обученного конвейера обработки изображений на полноразмерных изображениях».
Создайте хранилища изображений, которые считывают тестовые закрашенные фигуры RAW и RGB.
testImageDir = fullfile(imageDir,'test'); dsTestRAW = imageDatastore(fullfile(testImageDir,'huawei_raw'),'ReadSize',16); dsTestRGB = imageDatastore(fullfile(testImageDir,'canon'),'ReadSize',16);
Случайным образом разделите тестовые данные на два набора для валидации и обучения. Набор данных валидации содержит 200 изображений. Тестовый набор содержит оставшиеся изображения.
numTestImages = dsTestRAW.numpartitions; numValImages = 200; testIdx = randperm(numTestImages); validationIdx = testIdx(1:numValImages); testIdx = testIdx(numValImages+1:numTestImages); dsValRAW = subset(dsTestRAW,validationIdx); dsValRGB = subset(dsTestRGB,validationIdx); dsTestRAW = subset(dsTestRAW,testIdx); dsTestRGB = subset(dsTestRGB,testIdx);
Датчик получает данные о цвете в повторяющемся шаблоне Байера, который включает в себя один красный, два зеленых и один синий фотосенсор. Предварительно обработайте данные в четырехканальное изображение, ожидаемое от сети, используя transform
функция. The transform
функция обрабатывает данные с помощью операций, заданных в preprocessRAWDataForRAWToRGB
вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.
The preprocessRAWDataForRAWToRGB
Функция helper преобразует H-by-W-by-1 изображение RAW в H/2-by-W/2-by-4 многоканальное изображение, состоящее из одного красного, двух зеленых и одного синего канала.
Функция также приводит данные к типу данных single
масштабируется до области значений [0, 1].
dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB); dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB); dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);
Целевые изображения RGB хранятся на диске в виде неподписанных 8-битных данных. Чтобы сделать расчет метрик и проект сети более удобным, предварительно обработайте целевые обучающие изображения RGB с помощью transform
функции и preprocessRGBDataForRAWToRGB
вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.
The preprocessRGBDataForRAWToRGB
функция helper приводит изображения к типу данных single
масштабируется до области значений [0, 1].
dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB); dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);
Объедините входные RAW и целевые данные RGB для наборов изображений для обучения, валидации и тестирования с помощью combine
функция.
dsTrain = combine(dsTrainRAW,dsTrainRGB); dsVal = combine(dsValRAW,dsValRGB); dsTest = combine(dsTestRAW,dsTestRGB);
Случайным образом увеличьте обучающие данные, используя transform
функции и augmentDataForRAWToRGB
вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.
The augmentDataForRAWToRGB
Функция helper случайным образом применяет вращение степени 90 и горизонтальное отражение к парам входа RAW и целевых обучающих изображений RGB.
dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);
Предварительный просмотр дополненных обучающих данных.
exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
Отобразите входное и целевое изображения сети в монтаже. Вход сети имеет четыре канала, поэтому отобразите первый канал, переведенный в область значений [0, 1]. Входные RAW и целевые изображения RGB имеют идентичное увеличение.
exampleInput = exampleAug{1,1}; exampleOutput = exampleAug{1,2}; montage({rescale(exampleInput(:,:,1)),exampleOutput})
Этот пример использует пользовательский цикл обучения. The minibatchqueue
Объект (Deep Learning Toolbox) полезен для управления мини-пакетированием наблюдений в пользовательских циклах обучения. The minibatchqueue
объект также переводит данные в dlarray
(Deep Learning Toolbox) объект, который включает автоматическую дифференциацию в приложениях глубокого обучения.
miniBatchSize = 12; valBatchSize = 10; trainingQueue = minibatchqueue(dsTrainAug,'MiniBatchSize',miniBatchSize,'PartialMiniBatch','discard','MiniBatchFormat','SSCB'); validationQueue = minibatchqueue(dsVal,'MiniBatchSize',valBatchSize,'MiniBatchFormat','SSCB');
The next
(Deep Learning Toolbox) функция minibatchqueue
приводит к следующему мини-пакету данных. Предварительный просмотр выходов одного вызова в next
функция. Выходы имеют тип данных dlarray
. Данные уже переведены в gpuArray
, на графический процессор, и готов к обучению.
[inputRAW,targetRGB] = next(trainingQueue); whos inputRAW whos targetRGB
Name Size Bytes Class Attributes targetRGB 448x448x3x12 28901384 dlarray
Этот пример использует изменение сети U-Net. В U-Net начальная серия сверточных слоев перемежается с максимальными слоями объединения, последовательно уменьшая разрешение входного изображения. Эти слои сопровождаются серией сверточных слоев, чередующихся с операторами повышающей дискретизации, последовательно увеличивая разрешение входного изображения. Имя U-Net происходит от того, что сеть может быть нарисована с симметричной формой, такой как буква U.
Этот пример использует простую архитектуру U-Net с двумя модификациями. Во-первых, сеть заменяет окончательную операцию транспонированной свертки пользовательской операцией тасования пикселей с повышенной дискретизацией (также известной как операция «глубина в пространство»). Во-вторых, сеть использует пользовательский слой активации гиперболического тангенса в качестве последнего слоя в сети.
Свертка с последующим повышением дискретизации пикселей может задать субпиксельную свертку для приложений супер разрешения. Субпиксельная свертка предотвращает программные продукты контроля, которые могут возникнуть из-за транспонированной свертки [6]. Поскольку модель должна сопоставить H/2-by-W/2-by-4 входы RAW с W-by-H-by-3 выходами RGB, конечный этап увеличения дискретизации модели может быть рассмотрен аналогично супер- разрешение, где количество пространственных выборок увеличивается от входных до выходных.
Данные показывают, как пиксельная повышающая дискретизация перетасовки работает на вход 2 на 2 на 4. Первые две размерности являются пространственными размерностями, а 3-я размерность является размерностью канала. В целом, перемещение пикселей с повышенной дискретизацией в множителе S принимает вход H-на-W-на-C и приводит к S * H-на-S * W-by- выход.
Функция тасования пикселей увеличивает пространственные размерности выходного сигнала путем отображения информации от размеров канала в заданном пространственном местоположении в пространственные блоки S на S в выходе, в котором каждый канал вносит вклад в согласованное пространственное положение относительно его соседей во время увеличения дискретизации.
Слой активации гиперболического тангенса применяет tanh
функция на входах слоя. Этот пример использует масштабированную и shfited версию tanh
функция, которая поощряет, но не строго следит за тем, чтобы выходы сети RGB находились в области значений [0, 1] [6].
Использование tall
вычислить среднее сокращение по каналам для обучающих данных набора. Уровень входа сети выполняет среднее центрирование входов во время обучения и проверки с помощью средней статистики.
dsIn = copy(dsTrainRAW); dsIn.UnderlyingDatastore.ReadSize = 1; t = tall(dsIn); perChannelMean = gather(mean(t,[1 2]));
Создайте слои начальной подсети, задав среднее значение по каналам.
inputSize = [256 256 4]; initialLayer = imageInputLayer(inputSize,"Normalization","zerocenter","Mean",perChannelMean, ... "Name","ImageInputLayer");
Добавьте слои первой подсети кодирования. Первый энкодер имеет 32 сверточных фильтра.
numEncoderStages = 4; numFiltersFirstEncoder = 32; encoderNamePrefix = "Encoder-Stage-"; encoderLayers = [ convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-1")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-1")) convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-2")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2], ... "Name",strcat(encoderNamePrefix,"1-MaxPool")); ];
Добавьте слои дополнительных подсетей кодирования. Эти подсети добавляют нормализацию образца по каналу после каждого сверточного слоя с помощью groupNormalizationLayer
. Каждая подсеть энкодера имеет в два раза больше фильтров, чем предыдущая подсеть энкодера.
cnIdx = 1; for stage = 2:numEncoderStages numFilters = numFiltersFirstEncoder*2^(stage-1); layerNamePrefix = strcat(encoderNamePrefix,num2str(stage)); encoderLayers = [ encoderLayers convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2],"Name",strcat(layerNamePrefix,"-MaxPool")) ]; cnIdx = cnIdx + 2; end
Добавьте слои моста. Подсеть моста имеет вдвое больше фильтров, чем подсеть конечного энкодера и подсеть первого декодера.
numFilters = numFiltersFirstEncoder*2^numEncoderStages; bridgeLayers = [ convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Bridge-Conv-1") groupNormalizationLayer("channel-wise","Name","cn7") leakyReluLayer(0.2,"Name","Bridge-ReLU-1") convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Bridge-Conv-2") groupNormalizationLayer("channel-wise","Name","cn8") leakyReluLayer(0.2,"Name","Bridge-ReLU-2")];
Добавьте слои первых трех подсетей декодера.
numDecoderStages = 4; cnIdx = 9; decoderNamePrefix = "Decoder-Stage-"; decoderLayers = []; for stage = 1:numDecoderStages-1 numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage); layerNamePrefix = strcat(decoderNamePrefix,num2str(stage)); decoderLayers = [ decoderLayers transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-UpConv")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU")) depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) ]; cnIdx = cnIdx + 2; end
Добавьте слои последней подсети декодера. Эта подсеть исключает нормализацию образца по каналу, выполняемую другими подсетями декодера. Каждая подсеть декодера имеет половину количества фильтров, как предыдущая подсеть.
numFilters = numFiltersFirstEncoder; layerNamePrefix = strcat(decoderNamePrefix,num2str(stage+1)); decoderLayers = [ decoderLayers transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-UpConv")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU")) depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))];
Добавьте конечные слои U-Net. Слой тасования пикселей переходит от размера H/2-by-W/2-by-12 канала активаций от конечной свертки к активациям канала H-на-W-на 3 с помощью увеличения дискретизации пикселей. Конечный слой поощряет выходы в желаемую область значений [0, 1], используя гиперболическую функцию тангенса.
finalLayers = [ convolution2dLayer([3 3],12,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Decoder-Stage-4-Conv-3") pixelShuffleLayer("pixelShuffle",2) tanhScaledAndShiftedLayer("tanhActivation")]; layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers]; lgraph = layerGraph(layers);
Соединяет слои подсетей кодирования и декодирования.
lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2","Decoder-Stage-4-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2","Decoder-Stage-3-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2","Decoder-Stage-2-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2","Decoder-Stage-1-DepthConcatenation/in2"); net = dlnetwork(lgraph);
Визуализируйте сетевую архитектуру с помощью приложения Deep Network Designer (Deep Learning Toolbox).
deepNetworkDesigner(lgraph)
Эта функция изменяет предварительно обученную VGG-16 глубокую нейронную сеть, чтобы извлечь функции изображения в различных слоях. Эти многослойные функции используются для вычисления потерь содержимого.
Чтобы получить предварительно обученную VGG-16 сеть, установите vgg16
(Deep Learning Toolbox). Если у вас нет установленного необходимого пакета поддержки, то программное обеспечение предоставляет ссылку на загрузку.
vggNet = vgg16;
Чтобы сделать VGG-16 сеть подходящей для редукции данных, используйте слои до 'relu5 _ 3 '.
vggNet = vggNet.Layers(1:31); vggNet = dlnetwork(layerGraph(vggNet));
Функция помощника modelGradients
вычисляет градиенты и общие потери для пакетов обучающих данных. Эта функция определяется в разделе Вспомогательные функции этого примера.
Общая потеря представляет собой взвешенную сумму двух потерь: средней абсолютной ошибки (MAE) и потери содержимого. Потери содержимого взвешены таким образом, что потери MAE и содержимого способствуют примерно равным образом общим потерям:
Потеря MAE наказывает расстояние между выборками сетевых предсказаний и выборками целевого изображения. часто является лучшим выбором, чем для приложений обработки изображений, потому что это может помочь уменьшить размытие программных продуктов [4]. Эта потеря реализована с помощью maeLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.
Потеря содержимого помогает сети узнать как высокоуровневое структурное содержание, так и низкоуровневое ребро и цветовую информацию. Функция потерь вычисляет взвешенную сумму средней квадратной ошибки (MSE) между предсказаниями и целями для каждого слоя активации. Эта потеря реализована с помощью contentLoss
вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.
The modelGradients
Функция helper требует, чтобы коэффициент веса потери содержимого был входным параметром. Вычислите весовой коэффициент для выборки пакета обучающих данных таким образом, чтобы потеря MAE равнялась потере взвешенного содержимого.
Предварительный просмотр пакета обучающих данных, который состоит из пар входов RAW-сети и целевых выходов RGB.
trainingBatch = preview(dsTrainAug); networkInput = dlarray((trainingBatch{1,1}),'SSC'); targetOutput = dlarray((trainingBatch{1,2}),'SSC');
Спрогнозируйте ответ нетренированной сети U-Net с помощью forward
(Deep Learning Toolbox) функция.
predictedOutput = forward(net,networkInput);
Вычислите MAE и потери содержимого между предсказанным и целевым изображениями RGB.
sampleMAELoss = maeLoss(predictedOutput,targetOutput); sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);
Вычислите весовой коэффициент.
weightContent = sampleMAELoss/sampleContentLoss;
Задайте опции обучения, которые используются в пользовательском цикле обучения для управления аспектами оптимизации Адама. Обучайте на 20 эпох.
learnRate = 5e-5; numEpochs = 20;
По умолчанию пример загружает предварительно обученную версию сети RAW-to-RGB с помощью функции helper downloadTrainedRAWToRGBNet
. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде, для true
. Обучите модель в пользовательском цикле обучения. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функцию и modelGradients
вспомогательная функция.
Обновляйте параметры сети с помощью adamupdate
(Deep Learning Toolbox) функцию и информацию о градиенте.
Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.
Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Обучение занимает около 88 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от оборудования графического процессора.
doTraining = false; if doTraining % Create a directory to store checkpoints checkpointDir = fullfile(imageDir,'checkpoints',filesep); if ~exist(checkpointDir,'dir') mkdir(checkpointDir); end % Initialize training plot [hFig,batchLine,validationLine] = initializeTrainingPlotRAWPipeline; % Initialize Adam solver state [averageGrad,averageSqGrad] = deal([]); iteration = 0; start = tic; for epoch = 1:numEpochs reset(trainingQueue); shuffle(trainingQueue); while hasdata(trainingQueue) [inputRAW,targetRGB] = next(trainingQueue); [grad,loss] = dlfeval(@modelGradients,net,vggNet,inputRAW,targetRGB,weightContent); iteration = iteration + 1; [net,averageGrad,averageSqGrad] = adamupdate(net, ... grad,averageGrad,averageSqGrad,iteration,learnRate); updateTrainingPlotRAWPipeline(batchLine,validationLine,iteration,loss,start,epoch,... validationQueue,valSetSize,valBatchSize,net,vggNet,weightContent); end % Save checkpoint of network state save(checkpointDir + "epoch" + epoch,'net'); end % Save the final network state save(checkpointDir + "trainedRAWToRGBNet.mat",'net'); else trainedRAWToRGBNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedRAWToRGBNet.mat'; downloadTrainedRAWToRGBNet(trainedRAWToRGBNet_url,imageDir); load(fullfile(imageDir,'trainedRAWToRGBNet.mat')); end
Основанные на ссылках метрики качества, такие как MSSIM или PSNR, обеспечивают количественную меру качества изображения. Можно вычислить MSSIM и PSNR исправленных тестовых изображений, поскольку они пространственно зарегистрированы и имеют одинаковый размер.
Выполните итерацию тестового набора исправленных изображений с помощью minibatchqueue
объект.
patchTestSet = combine(dsTestRAW,dsTestRGB); testPatchQueue = minibatchqueue(patchTestSet,'MiniBatchSize',16,'MiniBatchFormat','SSCB');
Выполните итерацию тестового набора и вычислите MSSIM и PSNR для каждого тестового изображения с помощью multissim
и psnr
функций. Хотя функции принимают изображения RGB, метрики не четко определены для изображений RGB. Поэтому аппроксимируйте MSSIM и PSNR путем вычисления метрики цветовых каналов отдельно. Можно использовать calculateRAWToRGBQualityMetrics
вспомогательная функция для вычисления метрик по каналам. Эта функция присоединена к примеру как вспомогательный файл.
totalMSSIM = 0; totalPSNR = 0; while hasdata(testPatchQueue) [inputRAW,targetRGB] = next(testPatchQueue); outputRGB = forward(net,inputRAW); [mssimOut,psnrOut] = calculateRAWToRGBQualityMetrics(outputRGB,targetRGB); totalMSSIM = totalMSSIM + mssimOut; totalPSNR = totalPSNR + psnrOut; end
Вычислите среднее значение MSSIM и среднее значение PSNR по тестовому набору. Этот результат согласуется с аналогичным подходом U-Net от [3] для среднего MSSIM и конкурентоспособен с подходом PyNet в [3] для среднего PSNR. Различия в функциях потерь и использовании пиксельного тасования с повышенной дискретизацией по сравнению с [3], вероятно, объясняют эти различия.
numObservations = dsTestRGB.numpartitions; meanMSSIM = totalMSSIM / numObservations
meanMSSIM = single
0.8534
meanPSNR = totalPSNR / numObservations
meanPSNR = 21.2956
Из-за различий между камерой телефона и DSLR, используемой для получения тестовых изображений с полным разрешением, сцены не регистрируются и не имеют одинакового размера. Основанное на ссылках сравнение изображений полного разрешения из сети и DSLR ISP трудно. Однако качественное сравнение изображений полезно, потому что целью обработки изображений является создание эстетически приятного изображения.
Создайте изображение datastore, который содержит полноразмерные RAW- изображений, полученные телефонной камерой.
testImageDir = fullfile(imageDir,'test'); testImageDirRAW = "huawei_full_resolution"; dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));
Получите имена файлов изображений в полноразмерном тестовом наборе RAW.
targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files),fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");
Предварительно обработайте данные RAW путем преобразования данных в форму, ожидаемую сетью с помощью transform
функция. The transform
функция обрабатывает данные с помощью операций, заданных в preprocessRAWDataForRAWToRGB
вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.
dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);
Создайте image datastore, который содержит полноразмерные тестовые изображения RGB, полученные из DSLR высшего класса. Набор данных Zurich RAW to RGB содержит более полноразмерные изображения RGB, чем изображения RAW, поэтому включают только изображения RGB с соответствующим изображением RAW.
dsTestFullRGB = imageDatastore(fullfile(imageDir,'full_resolution','canon')); dsTestFullRGB.Files = dsTestFullRGB.Files(contains(dsTestFullRGB.Files,targetFilesToInclude));
Чтение в целевых изображениях RGB. Почувствуйте общий выход, посмотрев на вид монтажа.
targetRGB = readall(dsTestFullRGB); montage(targetRGB,"Size",[5 2],"Interpolation","bilinear")
Выполните итерацию тестового набора полноразмерных изображений с помощью minibatchqueue
объект. Если у вас есть устройство GPU с достаточной памятью для обработки изображений с полным разрешением, то можно запустить предварительное обучение на графическом процессоре, задав выход окружения следующим образом "gpu"
.
testQueue = minibatchqueue(dsTestFullRAW,"MiniBatchSize",1, ... "MiniBatchFormat","SSCB","OutputEnvironment","cpu");
Для каждого полноразмерного тестового изображения RAW спрогнозируйте выходное изображение RGB по вызову forward
(Deep Learning Toolbox) в сети.
outputSize = 2*size(preview(dsTestFullRAW),[1 2]); outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],'uint8'); idx = 1; while hasdata(testQueue) inputRAW = next(testQueue); rgbOut = forward(net,inputRAW); rgbOut = gather(extractdata(rgbOut)); outputImages(:,:,:,idx) = im2uint8(rgbOut); idx = idx+1; end
Почувствуйте общий выход, посмотрев на вид монтажа. Сеть производит изображения, которые эстетически приятны, со схожими характеристиками.
montage(outputImages,"Size",[5 2],"Interpolation","bilinear")
Сравните одно целевое изображение RGB с соответствующим изображением, предсказанным сетью. Сеть производит цвета, которые более насыщены, чем целевые изображения DSLR. Несмотря на то, что цвета от простой архитектуры U-Net не совпадают с целями DSLR, изображения во многих случаях по-прежнему качественно радуют.
imgIdx = 1; imTarget = targetRGB{imgIdx}; imPredicted = outputImages(:,:,:,imgIdx); figure montage({imTarget,imPredicted},"Interpolation","bilinear")
Для повышения эффективности сети RAW в RGB сетевая архитектура изучит подробные локализованные пространственные функции с использованием нескольких шкал от глобальных функций, которые описывают цвет и контрастность [3].
The modelGradients
Функция helper вычисляет градиенты и общие потери. Информация о градиенте возвращается как таблица, которая включает слой, имя параметра и значение для каждого настраиваемого параметра в модели.
function [gradients,loss] = modelGradients(dlnet,vggNet,Xpatch,Target,weightContent) Y = forward(dlnet,Xpatch); lossMAE = maeLoss(Y,Target); lossContent = contentLoss(vggNet,Y,Target); loss = lossMAE + weightContent.*lossContent; gradients = dlgradient(loss,dlnet.Learnables); end
Функция помощника maeLoss
вычисляет среднюю абсолютную ошибку между сетевыми предсказаниями, Y
, и целевые изображения, T
.
function loss = maeLoss(Y,T) loss = mean(abs(Y-T),'all'); end
Функция помощника contentLoss
вычисляет взвешенную сумму MSE между сетевыми предсказаниями, Y
, и целевые изображения, T
, для каждого слоя активации. The contentLoss
функция helper вычисляет MSE для каждого слоя активации с помощью mseLoss
вспомогательная функция. Веса выбираются такими, чтобы потери от каждых слоев активации вносили примерно одинаковый вклад в общую потерю содержимого.
function loss = contentLoss(net,Y,T) layers = ["relu1_1","relu1_2","relu2_1","relu2_2","relu3_1","relu3_2","relu3_3","relu4_1"]; [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,'Outputs',layers); [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,'Outputs',layers); l1 = mseLoss(X1,T1); l2 = mseLoss(X2,T2); l3 = mseLoss(X3,T3); l4 = mseLoss(X4,T4); l5 = mseLoss(X5,T5); l6 = mseLoss(X6,T6); l7 = mseLoss(X7,T7); l8 = mseLoss(X8,T8); layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8]; weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04]; loss = sum(layerLosses.*weights); end
Функция помощника mseLoss
вычисляет MSE между сетевыми предсказаниями, Y
, и целевые изображения, T
.
function loss = mseLoss(Y,T) loss = mean((Y-T).^2,'all'); end
1) Самнер, Роб. «Обработка RAW- Изображений в MATLAB». 19 мая 2014 года. https://rcsumner.net/raw_guide/RAWguide.pdf
2) Чэнь, Чэнь, Цифэн Чэнь, Цзи Сюй и Владлен Колтун. Учимся видеть в темноте. ArXiv:1805.01934 [Cs], 4 мая 2018 года. http://arxiv.org/abs/1805.01934.
3) Игнатов, Андрей, Люк Ван Голь, и Раду Тимофте. «Замена Mobile Camera ISP на одну модель глубокого обучения». ArXiv:2002.05509 [Cs, Eess], 13 февраля 2020 года. http://arxiv.org/abs/2002.05509. Веб-сайт проекта.
4) Чжао, Ханг, Орацио Галло, Иури Фрозио, и Ян Каутц. Функции потерь для нейронных сетей для обработки изображений. ArXiv:1511.08861 [Cs], 20 апреля 2018 года. http://arxiv.org/abs/1511.08861.
5) Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. «Восприятие потерь для передачи стиля в реальном времени и суперразрешение». ArXiv:1603.08155 [Cs], 26 марта 2016 года. http://arxiv.org/abs/1603.08155.
6) Ши, Вэньчжэ, Хосе Кабальеро, Ференц Хуссар, Йоханнес Тоц, Эндрю П. Эйткен, Роб Бишоп, Даниэль Рюкерт, и Зехан Ван. «Одно изображение и видео суперразрешение в реальном времени с использованием эффективной субпиксельной сверточной нейронной сети». ArXiv:1609.05158 [Cs, Stat], 23 сентября 2016 года. http://arxiv.org/abs/1609.05158.
combine
| imageDatastore
| transform
| trainingOptions
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox)