В этом примере показано, как преобразовать необработанные данные камеры в эстетически приятное цветное изображение с помощью U-Net.
DSLR и многие современные телефонные камеры предлагают возможность сохранять данные, собранные непосредственно с датчика камеры, в виде файла RAW. Каждый пиксель данных RAW непосредственно соответствует количеству света, захваченного соответствующим фотодатчиком камеры. Данные зависят от фиксированных характеристик аппаратных средств камеры, таких как чувствительность каждого фотодатчика к определенному диапазону длин волн электромагнитного спектра. Данные также зависят от настроек захвата камеры, таких как время экспозиции, и факторов сцены, таких как источник света.
Демосейсинг является единственной операцией, необходимой для преобразования одноканальных RAW-данных в трехканальный RGB-образ. Однако без дополнительных операций обработки изображения результирующее изображение RGB имеет субъективно низкое качество изображения.
Традиционный конвейер обработки изображений выполняет комбинацию дополнительных операций, включая деноизирование, линеаризацию, балансировку белого, коррекцию цвета, регулировку яркости и регулировку контраста [1]. Задача проектирования конвейера заключается в уточнении алгоритмов для оптимизации субъективного внешнего вида конечного изображения RGB независимо от вариаций в сцене и настройках обнаружения.

Методы глубокого обучения обеспечивают прямое преобразование RAW в RGB без необходимости разработки традиционного конвейера обработки. Например, один метод компенсирует неполноту при преобразовании RAW-изображений в RGB [2]. В этом примере показано, как преобразовать изображения RAW из камеры телефона нижнего конца в изображения RGB, которые приближаются к качеству камеры DSLR верхнего конца [3].
В этом примере используется набор данных Zurich RAW to RGB [3]. Размер набора данных составляет 22 ГБ. Набор данных содержит 48 043 пространственно зарегистрированных пары исправлений обучающих изображений RAW и RGB размером 448 на 448. Набор данных содержит два отдельных тестовых набора. Один тестовый набор состоит из 1 204 пространственно зарегистрированных пар исправлений RAW и RGB изображений размером 448 на 448. Другой набор тестов состоит из незарегистрированных изображений RAW и RGB с полным разрешением.
Создайте каталог для хранения набора данных.
imageDir = fullfile(tempdir,'ZurichRAWToRGB'); if ~exist(imageDir,'dir') mkdir(imageDir); end
Чтобы загрузить набор данных, запросите доступ с помощью формы набора данных Zurich RAW to RGB. Извлеките данные в каталог, указанный imageDir переменная. При успешном извлечении imageDir содержит три каталога с именем full_resolution, test, и train.
Создание imageDatastore считывает целевые исправления обучающего образа RGB, полученные с помощью высококлассного Canon DSLR.
trainImageDir = fullfile(imageDir,'train'); dsTrainRGB = imageDatastore(fullfile(trainImageDir,'canon'),'ReadSize',16);
Предварительный просмотр исправления учебного образа RGB.
groundTruthPatch = preview(dsTrainRGB); imshow(groundTruthPatch)

Создание imageDatastore считывает входные исправления обучающего изображения RAW, полученные с помощью камеры телефона Huawei. Изображения RAW захватываются с 10-битовой точностью и представляются как 8-битные и 16-битные файлы PNG. 8-битные файлы обеспечивают компактное представление исправлений с данными в диапазоне [0, 255]. Ни для одного из данных RAW масштабирование не выполнялось.
dsTrainRAW = imageDatastore(fullfile(trainImageDir,'huawei_raw'),'ReadSize',16);
Предварительный просмотр введенного фрагмента обучающего изображения RAW. Хранилище данных считывает это исправление как 8-разрядное uint8 поскольку количество датчиков находится в диапазоне [0, 255]. Чтобы смоделировать 10-битный динамический диапазон обучающих данных, разделите значения интенсивности изображения на 4. При увеличении изображения можно увидеть узор RGGB Bayer.
inputPatch = preview(dsTrainRAW); inputPatchRAW = inputPatch/4; imshow(inputPatchRAW)

Чтобы смоделировать минимальный традиционный конвейер обработки, продемонстрируйте шаблон RGGB Bayer данных RAW с использованием demosaic(Панель инструментов обработки изображений). Отобразить обработанное изображение и осветить дисплей. По сравнению с целевым RGB-изображением, минимально обработанное RGB-изображение темно и имеет несбалансированные цвета и заметные артефакты. Обученная сеть RAW-RGB выполняет операции предварительной обработки, так что выходное изображение RGB напоминает целевое изображение.
inputPatchRGB = demosaic(inputPatch,'rggb');
imshow(rescale(inputPatchRGB))
Данные теста содержат исправления RAW и RGB-изображений и полноразмерные изображения. В этом примере исправления тестового изображения разделяются на набор проверки и набор тестов. В примере используются полноразмерные тестовые изображения только для качественного тестирования. См. раздел Оценка конвейера обработки обученных изображений на полноразмерных изображениях.
Создание хранилищ данных образов, считывающих исправления тестовых образов RAW и RGB.
testImageDir = fullfile(imageDir,'test'); dsTestRAW = imageDatastore(fullfile(testImageDir,'huawei_raw'),'ReadSize',16); dsTestRGB = imageDatastore(fullfile(testImageDir,'canon'),'ReadSize',16);
Случайное разделение данных теста на два набора для проверки и обучения. Набор данных проверки содержит 200 изображений. Тестовый набор содержит оставшиеся изображения.
numTestImages = dsTestRAW.numpartitions; numValImages = 200; testIdx = randperm(numTestImages); validationIdx = testIdx(1:numValImages); testIdx = testIdx(numValImages+1:numTestImages); dsValRAW = subset(dsTestRAW,validationIdx); dsValRGB = subset(dsTestRGB,validationIdx); dsTestRAW = subset(dsTestRAW,testIdx); dsTestRGB = subset(dsTestRGB,testIdx);
Датчик получает цветовые данные в повторяющемся рисунке Байера, который включает в себя один красный, два зеленых и один синий фотодатчик. Предварительная обработка данных в четырехканальный образ, ожидаемый от сети, с помощью transform функция. transform обрабатывает данные с помощью операций, указанных в preprocessRAWDataForRAWToRGB функция помощника. Вспомогательная функция прикрепляется к примеру как вспомогательный файл.
preprocessRAWDataForRAWToRGB функция помощника преобразовывает СЫРОЕ изображение В на Ш на 1 в H/2-by-W/2-by-4 многоканальное изображение, состоящее из одного красного, двух зеленых, и один синий канал.

Функция также приводит данные к типу данных. single масштабировано до диапазона [0, 1].
dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB); dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB); dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);
Целевые образы RGB хранятся на диске как неподписанные 8-битные данные. Чтобы сделать вычисление метрик и проектирование сети более удобными, выполните предварительную обработку целевых учебных изображений RGB с помощью transform функции и preprocessRGBDataForRAWToRGB функция помощника. Вспомогательная функция прикрепляется к примеру как вспомогательный файл.
preprocessRGBDataForRAWToRGB вспомогательная функция приводит изображения к типу данных single масштабировано до диапазона [0, 1].
dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB); dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);
Объединение входных данных RAW и целевых данных RGB для наборов обучающих, валидационных и тестовых изображений с помощью combine функция.
dsTrain = combine(dsTrainRAW,dsTrainRGB); dsVal = combine(dsValRAW,dsValRGB); dsTest = combine(dsTestRAW,dsTestRGB);
Случайное увеличение данных обучения с использованием transform функции и augmentDataForRAWToRGB функция помощника. Вспомогательная функция прикрепляется к примеру как вспомогательный файл.
augmentDataForRAWToRGB вспомогательная функция случайным образом применяет поворот на 90 градусов и горизонтальное отражение к парам входных учебных изображений RAW и целевого RGB.
dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);
Предварительный просмотр дополненных данных обучения.
exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
Отображение сетевого входного и целевого изображения в монтаже. Сетевой вход имеет четыре канала, поэтому первый канал отображается в диапазоне [0, 1]. Входные RAW и целевые RGB-образы имеют одинаковое увеличение.
exampleInput = exampleAug{1,1};
exampleOutput = exampleAug{1,2};
montage({rescale(exampleInput(:,:,1)),exampleOutput})
В этом примере используется настраиваемый цикл обучения. minibatchqueue полезен для управления мини-пакетами наблюдений в пользовательских учебных циклах. minibatchqueue объект также передает данные в dlarray объект, обеспечивающий автоматическую дифференциацию в приложениях глубокого обучения.
miniBatchSize = 12; valBatchSize = 10; trainingQueue = minibatchqueue(dsTrainAug,'MiniBatchSize',miniBatchSize,'PartialMiniBatch','discard','MiniBatchFormat','SSCB'); validationQueue = minibatchqueue(dsVal,'MiniBatchSize',valBatchSize,'MiniBatchFormat','SSCB');
next функция minibatchqueue выдает следующий мини-пакет данных. Предварительный просмотр выходных данных одного вызова next функция. Выходные данные имеют тип данных dlarray. Данные уже приведены к gpuArray, на ГПУ, и готов к обучению.
[inputRAW,targetRGB] = next(trainingQueue); whos inputRAW whos targetRGB
Name Size Bytes Class Attributes targetRGB 448x448x3x12 28901384 dlarray
В этом примере используется вариант сети U-Net. В U-Net начальные серии сверточных слоев перемежаются с максимальными слоями объединения, последовательно уменьшая разрешение входного изображения. За этими слоями следует ряд сверточных слоев, перемежающихся операторами повышающей дискретизации, последовательно повышая разрешающую способность входного изображения. Название U-Net происходит от того, что сеть может быть нарисована симметричной формой наподобие буквы U.
В этом примере используется простая архитектура U-Net с двумя модификациями. Во-первых, сеть заменяет конечную транспонированную операцию свертки на пользовательскую операцию повышения дискретизации пикселя (также известную как операция «глубина-пространство»). Во-вторых, сеть использует пользовательский слой активации гиперболических касательных в качестве конечного слоя в сети.
Свертка, за которой следует повышающая дискретизация пикселя, может определять субпиксельную свертку для приложений суперразрешения. Субпиксельная свертка предотвращает артефакты контрольной доски, которые могут возникнуть из транспонированной свертки [6]. Поскольку модель должна сопоставлять H/2-by-W/2-by-4 входные данные RAW с W-by-H-by-3 выходами RGB, заключительную стадию повышающей дискретизации модели можно рассматривать аналогично сверхразрешению, где число пространственных выборок растет от входа к выходу.
Данные показывают, как пиксельная повышающая дискретизация перетасовки работает на вход 2 на 2 на 4. Первые два размера являются пространственными размерами, а третий размер является размером канала. В общем случае, повышающая дискретизация пикселя на коэффициент S принимает входной сигнал H-by-W-by-C и выдает выходной сигнал S * H-by-S * W-by-CS2.

Функция тасования пикселей увеличивает пространственные размеры выходного сигнала путем отображения информации из размеров канала в заданном пространственном местоположении в пространственные блоки S-на-S на выходе, в которых каждый канал вносит вклад в последовательное пространственное положение относительно своих соседей во время повышающей дискретизации.
Слой активации гиперболической касательной применяет tanh функция на входах слоев. В этом примере используется масштабированная и измененная версия tanh функция, которая поощряет, но не строго обеспечивает, чтобы сетевые выходы RGB находились в диапазоне [0, 1] [6].
) + 0,5

Использовать tall вычислять среднее снижение по каналу в наборе обучающих данных. Входной уровень сети выполняет среднюю центровку входных данных во время обучения и тестирования с использованием средней статистики.
dsIn = copy(dsTrainRAW); dsIn.UnderlyingDatastore.ReadSize = 1; t = tall(dsIn); perChannelMean = gather(mean(t,[1 2]));
Создайте уровни начальной подсети, указав среднее значение для каждого канала.
inputSize = [256 256 4]; initialLayer = imageInputLayer(inputSize,"Normalization","zerocenter","Mean",perChannelMean, ... "Name","ImageInputLayer");
Добавление уровней первой кодирующей подсети. Первый кодер имеет 32 сверточных фильтра.
numEncoderStages = 4; numFiltersFirstEncoder = 32; encoderNamePrefix = "Encoder-Stage-"; encoderLayers = [ convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-1")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-1")) convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-2")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2], ... "Name",strcat(encoderNamePrefix,"1-MaxPool")); ];
Добавление уровней дополнительных подсетей кодирования. Эти подсети добавляют нормализацию экземпляра канала после каждого сверточного уровня с использованием groupNormalizationLayer. Каждая подсеть кодера имеет вдвое большее количество фильтров, чем предыдущая подсеть кодера.
cnIdx = 1; for stage = 2:numEncoderStages numFilters = numFiltersFirstEncoder*2^(stage-1); layerNamePrefix = strcat(encoderNamePrefix,num2str(stage)); encoderLayers = [ encoderLayers convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2],"Name",strcat(layerNamePrefix,"-MaxPool")) ]; cnIdx = cnIdx + 2; end
Добавление слоев моста. Мостовая подсеть имеет вдвое больше фильтров, чем конечная подсеть кодера и первая подсеть декодера.
numFilters = numFiltersFirstEncoder*2^numEncoderStages;
bridgeLayers = [
convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
"Name","Bridge-Conv-1")
groupNormalizationLayer("channel-wise","Name","cn7")
leakyReluLayer(0.2,"Name","Bridge-ReLU-1")
convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
"Name","Bridge-Conv-2")
groupNormalizationLayer("channel-wise","Name","cn8")
leakyReluLayer(0.2,"Name","Bridge-ReLU-2")];Добавление уровней первых трех подсетей декодера.
numDecoderStages = 4; cnIdx = 9; decoderNamePrefix = "Decoder-Stage-"; decoderLayers = []; for stage = 1:numDecoderStages-1 numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage); layerNamePrefix = strcat(decoderNamePrefix,num2str(stage)); decoderLayers = [ decoderLayers transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-UpConv")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU")) depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) ]; cnIdx = cnIdx + 2; end
Добавление уровней последней подсети декодера. Эта подсеть исключает нормализацию экземпляра канала, выполняемую другими подсетями декодера. Каждая подсеть декодера имеет половину числа фильтров по сравнению с предыдущей подсетью.
numFilters = numFiltersFirstEncoder;
layerNamePrefix = strcat(decoderNamePrefix,num2str(stage+1));
decoderLayers = [
decoderLayers
transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ...
"Name",strcat(layerNamePrefix,"-UpConv"))
leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU"))
depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation"))
convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
"Name",strcat(layerNamePrefix,"-Conv-1"))
leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
"Name",strcat(layerNamePrefix,"-Conv-2"))
leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))];Добавьте конечные слои U-Net. Слой тасования пикселей перемещается от H/2-by-W/2-by-12 размера канала активизаций от конечной свертки до активизаций канала H-by-W-by-3 с использованием повышающей дискретизации тасования пикселей. Конечный уровень стимулирует вывод в желаемый диапазон [0, 1] с использованием гиперболической касательной функции.
finalLayers = [
convolution2dLayer([3 3],12,"Padding","same","WeightsInitializer","narrow-normal", ...
"Name","Decoder-Stage-4-Conv-3")
pixelShuffleLayer("pixelShuffle",2)
tanhScaledAndShiftedLayer("tanhActivation")];
layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers];
lgraph = layerGraph(layers);Соедините уровни подсетей кодирования и декодирования.
lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2","Decoder-Stage-4-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2","Decoder-Stage-3-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2","Decoder-Stage-2-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2","Decoder-Stage-1-DepthConcatenation/in2"); net = dlnetwork(lgraph);
Визуализация сетевой архитектуры с помощью приложения Deep Network Designer.
deepNetworkDesigner(lgraph)
Эта функция модифицирует заранее подготовленную VGG-16 глубокую нейронную сеть для извлечения признаков изображения на различных уровнях. Эти многослойные функции используются для вычисления потери содержимого.
Чтобы получить предварительно обученную сеть VGG-16, установите vgg16. Если необходимый пакет поддержки не установлен, программа предоставляет ссылку для загрузки.
vggNet = vgg16;
Чтобы сделать сеть VGG-16 подходящей для извлечения элементов, используйте слои до 'relu5 _ 3 '.
vggNet = vggNet.Layers(1:31); vggNet = dlnetwork(layerGraph(vggNet));
Вспомогательная функция modelGradients вычисляет градиенты и общие потери для партий учебных данных. Эта функция определена в разделе «Вспомогательные функции» данного примера.
Общая потеря представляет собой взвешенную сумму двух потерь: среднее значение потери при абсолютной ошибке (MAE) и потери содержимого. Потеря содержимого взвешивается таким образом, что потеря MAE и потеря содержимого вносят приблизительно равный вклад в общую потерю:
Потеря MAE штрафует за расстояние между выборками сетевых прогнозов и выборками целевого изображения. часто лучше, чем для приложений обработки изображений, потому что это может помочь уменьшить размытие артефактов [4]. Эта потеря реализуется с помощью maeLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера.
Потеря содержимого помогает сети изучать как высокоуровневое структурное содержимое, так и низкоуровневую краевую и цветовую информацию. Функция потерь вычисляет взвешенную сумму среднеквадратической ошибки (MSE) между прогнозами и целями для каждого уровня активации. Эта потеря реализуется с помощью contentLoss вспомогательная функция, определенная в разделе «Вспомогательные функции» данного примера.
modelGradients вспомогательная функция требует в качестве входного аргумента коэффициент потери содержимого. Вычислите весовой коэффициент для партии образцов учебных данных таким образом, чтобы потери MAE были равны взвешенным потерям содержимого.
Предварительный просмотр пакета обучающих данных, который состоит из пар сетевых входов RAW и целевых выходов RGB.
trainingBatch = preview(dsTrainAug);
networkInput = dlarray((trainingBatch{1,1}),'SSC');
targetOutput = dlarray((trainingBatch{1,2}),'SSC');Прогнозирование ответа необученной сети U-Net с помощью forward функция.
predictedOutput = forward(net,networkInput);
Вычислите MAE и потери содержимого между прогнозируемым и целевым RGB-изображениями.
sampleMAELoss = maeLoss(predictedOutput,targetOutput); sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);
Рассчитайте весовой коэффициент.
weightContent = sampleMAELoss/sampleContentLoss;
Определите параметры обучения, которые используются в пользовательском цикле обучения для управления аспектами оптимизации Adam. Поезд на 20 эпох.
learnRate = 5e-5; numEpochs = 20;
По умолчанию в примере загружается предварительно подготовленная версия сети RAW-RGB с помощью функции помощника. downloadTrainedRAWToRGBNet. Вспомогательная функция прикрепляется к примеру как вспомогательный файл. Предварительно обученная сеть позволяет выполнять весь пример без ожидания завершения обучения.
Для обучения сети установите doTraining переменная в следующем коде true. Обучение модели в индивидуальном цикле обучения. Для каждой итерации:
Считывание данных для текущего мини-пакета с помощью next функция.
Оцените градиенты модели с помощью dlfeval функции и modelGradients функция помощника.
Обновление параметров сети с помощью adamupdate и информацию о градиенте.
Обновите график хода обучения для каждой итерации и просмотрите различные вычисленные потери.
Обучение на GPU, если он доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Дополнительные сведения см. в разделе Поддержка графического процессора по выпуску (Панель инструментов параллельных вычислений). Обучение занимает около 88 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от оборудования графического процессора.
doTraining = false; if doTraining % Create a directory to store checkpoints checkpointDir = fullfile(imageDir,'checkpoints',filesep); if ~exist(checkpointDir,'dir') mkdir(checkpointDir); end % Initialize training plot [hFig,batchLine,validationLine] = initializeTrainingPlotRAWPipeline; % Initialize Adam solver state [averageGrad,averageSqGrad] = deal([]); iteration = 0; start = tic; for epoch = 1:numEpochs reset(trainingQueue); shuffle(trainingQueue); while hasdata(trainingQueue) [inputRAW,targetRGB] = next(trainingQueue); [grad,loss] = dlfeval(@modelGradients,net,vggNet,inputRAW,targetRGB,weightContent); iteration = iteration + 1; [net,averageGrad,averageSqGrad] = adamupdate(net, ... grad,averageGrad,averageSqGrad,iteration,learnRate); updateTrainingPlotRAWPipeline(batchLine,validationLine,iteration,loss,start,epoch,... validationQueue,valSetSize,valBatchSize,net,vggNet,weightContent); end % Save checkpoint of network state save(checkpointDir + "epoch" + epoch,'net'); end % Save the final network state save(checkpointDir + "trainedRAWToRGBNet.mat",'net'); else trainedRAWToRGBNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedRAWToRGBNet.mat'; downloadTrainedRAWToRGBNet(trainedRAWToRGBNet_url,imageDir); load(fullfile(imageDir,'trainedRAWToRGBNet.mat')); end
Эталонные показатели качества, такие как MSIM или PSNR, позволяют количественно измерять качество изображения. Можно вычислить MSIM и PSNR исправленных тестовых изображений, поскольку они пространственно зарегистрированы и имеют одинаковый размер.
Итерация через тестовый набор исправленных изображений с помощью minibatchqueue объект.
patchTestSet = combine(dsTestRAW,dsTestRGB); testPatchQueue = minibatchqueue(patchTestSet,'MiniBatchSize',16,'MiniBatchFormat','SSCB');
Выполните итерацию через тестовый набор и вычислите MSIM и PSNR для каждого тестового образа с помощью multissim(Панель инструментов обработки изображений) и psnr(Панель инструментов обработки изображений). Хотя функции принимают изображения RGB, метрики для изображений RGB определены недостаточно. Поэтому аппроксимируйте MSIM и PSNR путем раздельного вычисления метрик цветовых каналов. Вы можете использовать calculateRAWToRGBQualityMetrics вспомогательная функция для вычисления канальных метрик. Эта функция присоединена к примеру как вспомогательный файл.
totalMSSIM = 0; totalPSNR = 0; while hasdata(testPatchQueue) [inputRAW,targetRGB] = next(testPatchQueue); outputRGB = forward(net,inputRAW); [mssimOut,psnrOut] = calculateRAWToRGBQualityMetrics(outputRGB,targetRGB); totalMSSIM = totalMSSIM + mssimOut; totalPSNR = totalPSNR + psnrOut; end
Рассчитайте среднее значение MSIM и среднее значение PSNR по тестовому набору. Этот результат согласуется с аналогичным подходом U-Net из [3] для среднего MSIM и конкурирует с подходом PyNet в [3] для среднего PSNR. Различия в функциях потерь и использовании повышающей дискретизации пикселя по сравнению с [3], вероятно, являются причиной этих различий.
numObservations = dsTestRGB.numpartitions; meanMSSIM = totalMSSIM / numObservations
meanMSSIM = single
0.8534
meanPSNR = totalPSNR / numObservations
meanPSNR = 21.2956
Из-за различий в датчиках между камерой телефона и DSLR, используемыми для получения тестовых изображений с полным разрешением, сцены не регистрируются и не имеют одинакового размера. Эталонное сравнение изображений с полным разрешением из сети и интернет-провайдера DSLR затруднено. Однако качественное сравнение изображений полезно, поскольку целью обработки изображения является создание эстетически приятного изображения.
Создайте хранилище данных изображений, содержащее полноразмерные RAW-изображения, полученные камерой телефона.
testImageDir = fullfile(imageDir,'test'); testImageDirRAW = "huawei_full_resolution"; dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));
Получение имен файлов изображений в полноразмерном наборе тестов RAW.
targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files),fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");Предварительная обработка данных RAW путем преобразования данных в форму, ожидаемую сетью, с помощью transform функция. transform обрабатывает данные с помощью операций, указанных в preprocessRAWDataForRAWToRGB функция помощника. Вспомогательная функция прикрепляется к примеру как вспомогательный файл.
dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);
Создайте хранилище данных образа, содержащее полноразмерные тестовые образы RGB, полученные из DSLR высшего класса. Набор данных Zurich RAW-RGB содержит больше полноразмерных RGB-изображений, чем RAW-изображений, поэтому включает только RGB-изображения с соответствующим RAW-изображением.
dsTestFullRGB = imageDatastore(fullfile(imageDir,'full_resolution','canon')); dsTestFullRGB.Files = dsTestFullRGB.Files(contains(dsTestFullRGB.Files,targetFilesToInclude));
Чтение целевых изображений RGB. Почувствуйте общий результат, посмотрев на вид монтажа.
targetRGB = readall(dsTestFullRGB); montage(targetRGB,"Size",[5 2],"Interpolation","bilinear")

Итерация через тестовый набор полноразмерных изображений с помощью minibatchqueue объект. Если у вас есть графический процессор с достаточным объемом памяти для обработки изображений с полным разрешением, можно запустить предварительное обучение на графическом процессоре, указав среду вывода как "gpu".
testQueue = minibatchqueue(dsTestFullRAW,"MiniBatchSize",1, ... "MiniBatchFormat","SSCB","OutputEnvironment","cpu");
Для каждого полноразмерного тестового образа RAW спрогнозируйте выходной RGB-образ путем вызова forward в сети.
outputSize = 2*size(preview(dsTestFullRAW),[1 2]); outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],'uint8'); idx = 1; while hasdata(testQueue) inputRAW = next(testQueue); rgbOut = forward(net,inputRAW); rgbOut = gather(extractdata(rgbOut)); outputImages(:,:,:,idx) = im2uint8(rgbOut); idx = idx+1; end
Почувствуйте общий результат, посмотрев на вид монтажа. Сеть производит изображения, которые эстетически приятны, с аналогичными характеристиками.
montage(outputImages,"Size",[5 2],"Interpolation","bilinear")

Сравните один целевой RGB-образ с соответствующим изображением, предсказанным сетью. Сеть создает более насыщенные цвета, чем целевые DSLR-изображения. Несмотря на то, что цвета из простой архитектуры U-Net не совпадают с целевыми объектами DSLR, во многих случаях изображения все еще качественно приятны.
imgIdx = 1;
imTarget = targetRGB{imgIdx};
imPredicted = outputImages(:,:,:,imgIdx);
figure
montage({imTarget,imPredicted},"Interpolation","bilinear")
Для повышения производительности сети RAW-RGB сетевая архитектура будет изучать подробные локализованные пространственные особенности с использованием нескольких масштабов из глобальных функций, описывающих цвет и контраст [3].
modelGradients вспомогательная функция вычисляет градиенты и общие потери. Информация о градиенте возвращается в виде таблицы, включающей слой, имя параметра и значение для каждого обучаемого параметра в модели.
function [gradients,loss] = modelGradients(dlnet,vggNet,Xpatch,Target,weightContent) Y = forward(dlnet,Xpatch); lossMAE = maeLoss(Y,Target); lossContent = contentLoss(vggNet,Y,Target); loss = lossMAE + weightContent.*lossContent; gradients = dlgradient(loss,dlnet.Learnables); end
Вспомогательная функция maeLoss вычисляет среднюю абсолютную ошибку между предсказаниями сети, Yи целевые изображения, T.
function loss = maeLoss(Y,T) loss = mean(abs(Y-T),'all'); end
Вспомогательная функция contentLoss вычисляет взвешенную сумму MSE между предсказаниями сети, Yи целевые изображения, T, для каждого слоя активации. contentLoss вспомогательная функция вычисляет MSE для каждого уровня активации с помощью mseLoss функция помощника. Веса выбирают таким образом, чтобы потери от каждого слоя активации вносили примерно равный вклад в общую потерю содержимого.
function loss = contentLoss(net,Y,T) layers = ["relu1_1","relu1_2","relu2_1","relu2_2","relu3_1","relu3_2","relu3_3","relu4_1"]; [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,'Outputs',layers); [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,'Outputs',layers); l1 = mseLoss(X1,T1); l2 = mseLoss(X2,T2); l3 = mseLoss(X3,T3); l4 = mseLoss(X4,T4); l5 = mseLoss(X5,T5); l6 = mseLoss(X6,T6); l7 = mseLoss(X7,T7); l8 = mseLoss(X8,T8); layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8]; weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04]; loss = sum(layerLosses.*weights); end
Вспомогательная функция mseLoss вычисляет MSE между предсказаниями сети, Yи целевые изображения, T.
function loss = mseLoss(Y,T) loss = mean((Y-T).^2,'all'); end
1) Самнер, Роб. «Обработка изображений RAW в MATLAB». 19 мая 2014 года. https://rcsumner.net/raw_guide/RAWguide.pdf
2) Чэнь, Чэнь, Цифэн Чэнь, Цзя Сюй, и Владлен Колтун. «Учимся видеть в темноте.» ArXiv:1805.01934 [Cs], 4 мая 2018 года. http://arxiv.org/abs/1805.01934.
3) Игнатов, Андрей, Люк Ван Гул, и Раду Тимофте. «Замена интернет-провайдера мобильной камеры единой моделью глубокого обучения». ArXiv:2002.05509 [Cs, Eess], 13 февраля 2020 года. http://arxiv.org/abs/2002.05509. Веб-сайт проекта.
4) Чжао, Ханг, Орацио Галло, Иури Фрозио, и Ян Каутц. «Функции потери нейронных сетей для обработки изображений». ArXiv:1511.08861 [Cs], 20 апреля 2018 года. http://arxiv.org/abs/1511.08861.
5) Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. «Потери восприятия для передачи стиля в реальном времени и суперразрешения». ArXiv:1603.08155 [Cs], 26 марта 2016 г. http://arxiv.org/abs/1603.08155.
6) Ши, Вэньчжэ, Хосе Кабальеро, Ференц Хусар, Йоханнес Тоц, Эндрю П. Эйткен, Роб Бишоп, Даниэль Рюккерт, и Зехан Ван. «Сверхразрешающее изображение и видео в реальном времени с использованием эффективной свёрточной нейронной сети под-пикселов». ArXiv:1609.05158 [Cs, Stat], 23 сентября 2016 г. http://arxiv.org/abs/1609.05158.
combine | imageDatastore | trainingOptions | trainNetwork | transform