Разрабатывайте трубопровод обработки необработанных камер с помощью глубокого обучения

В этом примере показано, как преобразовать необработанные данные камеры в эстетически приятное цветное изображение с помощью U-Net.

DSLR и многие современные камеры телефона предлагают возможность сохранять данные, собранные непосредственно с датчика камеры, в качестве RAW- файла. Каждый пиксель данных RAW соответствует непосредственно количеству света, захваченному соответствующим фотосенсором камеры. Данные зависят от фиксированных характеристик оборудования камеры, таких как чувствительность к каждому фотосенсору к конкретной области значений длин волн электромагнитного спектра. Данные также зависят от настроек захвата камеры, таких как время экспозиции, и факторов сцены, таких как источник света.

Демозаицирование является единственной необходимой операцией для преобразования одноканальных данных RAW в трехканальное изображение RGB. Однако без дополнительных операций обработки изображений полученное изображение RGB имеет субъективно низкое качество зрения.

Традиционный трубопровод обработки изображений выполняет комбинацию дополнительных операций, включая шумоподавление, линеаризацию, балансировку белого, коррекцию цвета, регулировку яркости и регулировку контрастности [1]. Задача разработки трубопровода заключается в уточнении алгоритмов, чтобы оптимизировать субъективный внешний вид окончательного изображения RGB независимо от изменений в сцене и настройках сбора.

Глубокие методы глубокого обучения позволяют прямое преобразование RAW в RGB без необходимости разработки традиционного конвейера обработки. Для образца один метод компенсирует недооценку при преобразовании изображений RAW в RGB [2]. В этом примере показано, как преобразовать изображения RAW из нижней конечной камеры телефона в изображения RGB, которые аппроксимируют качество камеры DSLR более высокого уровня [3].

Загрузить Zurich RAW в набор данных RGB

Этот пример использует набор данных Zurich RAW to RGB [3]. Размер набора данных составляет 22 ГБ. Набор данных содержит 48 043 пространственно зарегистрированных пар закрашенных фигур обучающих изображений RAW и RGB размера 448 на 448. Набор данных содержит два отдельных тестовых набора. Один тестовый набор состоит из 1 204 пространственно зарегистрированных пар изображений RAW и RGB закрашенных фигур размера 448 на 448. Другой тестовый набор состоит из незарегистрированных изображений RAW и RGB с полным разрешением.

Создайте директорию для хранения набора данных.

imageDir = fullfile(tempdir,'ZurichRAWToRGB');
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end

Чтобы загрузить набор данных, запросите доступ с помощью формы Zurich RAW to RGB dataset. Извлеките данные в директорию, заданную imageDir переменная. При успешном извлечении imageDir содержит три директории с именем full_resolution, test, и train.

Создайте хранилища данных для обучения, валидации и проверки

Создайте Datastore для обучающих данных закрашенных фигур для изображений RGB

Создайте imageDatastore который считывает целевые закрашенные фигуры обучающего изображения RGB, полученные с помощью DSLR Canon высшего класса.

trainImageDir = fullfile(imageDir,'train');
dsTrainRGB = imageDatastore(fullfile(trainImageDir,'canon'),'ReadSize',16);

Предварительный просмотр закрашенной фигуры обучающего изображения RGB.

groundTruthPatch = preview(dsTrainRGB);
imshow(groundTruthPatch)

Создайте Datastore для RAW- Изображения Закрашенной фигуры обучающих данных

Создайте imageDatastore который считывает входные обучающие закрашенные фигуры изображение, полученные с помощью камеры телефона Huawei. Изображения RAW получаются с 10-битной точностью и представлены как 8-битными, так и 16-битными файлами PNG. 8-битные файлы обеспечивают компактное представление закрашенных фигур с данными в области значений [0, 255]. Масштабирование не выполнено ни на одном из RAW- данных.

dsTrainRAW = imageDatastore(fullfile(trainImageDir,'huawei_raw'),'ReadSize',16);

Предварительный просмотр входной закрашенной фигуры обучающего изображения RAW. datastore читает эту закрашенную фигуру как 8-битный uint8 изображение, поскольку счетчики датчиков находятся в области значений [0, 255]. Чтобы симулировать 10-битную динамическую область значений обучающих данных, разделите значения интенсивности изображения на 4. Если вы увеличиваете изображение, то вы можете увидеть шаблон RGGB Bayer.

inputPatch = preview(dsTrainRAW);
inputPatchRAW = inputPatch/4;
imshow(inputPatchRAW)

Чтобы симулировать минимальный традиционный трубопровод обработки, демосапирируйте шаблон RGGB Bayer данных RAW с помощью demosaic (Image Processing Toolbox) функция. Отобразите обработанное изображение и осветлите отображение. По сравнению с целевым изображением RGB минимально обработанное изображение RGB темное и имеет несбалансированные цвета и заметные программные продукты. Обученная сеть RAW-to-RGB выполняет операции предварительной обработки так, чтобы выходное изображение RGB напоминало целевое изображение.

inputPatchRGB = demosaic(inputPatch,'rggb');
imshow(rescale(inputPatchRGB))

Разбиение тестовых изображений на наборы для валидации и тестирования

Тестовые данные содержат изображения RAW и RGB закрашенных фигур а также полноразмерные изображения. Этот пример разделяет закрашенные фигуры тестового изображения на наборы валидации и тестовые наборы. Пример использует полноразмерные тестовые изображения только для качественной проверки. См. «Оценка обученного конвейера обработки изображений на полноразмерных изображениях».

Создайте хранилища изображений, которые считывают тестовые закрашенные фигуры RAW и RGB.

testImageDir = fullfile(imageDir,'test');
dsTestRAW = imageDatastore(fullfile(testImageDir,'huawei_raw'),'ReadSize',16);
dsTestRGB = imageDatastore(fullfile(testImageDir,'canon'),'ReadSize',16);

Случайным образом разделите тестовые данные на два набора для валидации и обучения. Набор данных валидации содержит 200 изображений. Тестовый набор содержит оставшиеся изображения.

numTestImages = dsTestRAW.numpartitions;
numValImages = 200;

testIdx = randperm(numTestImages);
validationIdx = testIdx(1:numValImages);
testIdx = testIdx(numValImages+1:numTestImages);

dsValRAW = subset(dsTestRAW,validationIdx);
dsValRGB = subset(dsTestRGB,validationIdx);

dsTestRAW = subset(dsTestRAW,testIdx);
dsTestRGB = subset(dsTestRGB,testIdx);

Предварительная обработка и увеличение данных

Датчик получает данные о цвете в повторяющемся шаблоне Байера, который включает в себя один красный, два зеленых и один синий фотосенсор. Предварительно обработайте данные в четырехканальное изображение, ожидаемое от сети, используя transform функция. The transform функция обрабатывает данные с помощью операций, заданных в preprocessRAWDataForRAWToRGB вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.

The preprocessRAWDataForRAWToRGB Функция helper преобразует H-by-W-by-1 изображение RAW в H/2-by-W/2-by-4 многоканальное изображение, состоящее из одного красного, двух зеленых и одного синего канала.

Функция также приводит данные к типу данных single масштабируется до области значений [0, 1].

dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB);
dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB);
dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);

Целевые изображения RGB хранятся на диске в виде неподписанных 8-битных данных. Чтобы сделать расчет метрик и проект сети более удобным, предварительно обработайте целевые обучающие изображения RGB с помощью transform функции и preprocessRGBDataForRAWToRGB вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.

The preprocessRGBDataForRAWToRGB функция helper приводит изображения к типу данных single масштабируется до области значений [0, 1].

dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB);
dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);

Объедините входные RAW и целевые данные RGB для наборов изображений для обучения, валидации и тестирования с помощью combine функция.

dsTrain = combine(dsTrainRAW,dsTrainRGB);
dsVal = combine(dsValRAW,dsValRGB);
dsTest = combine(dsTestRAW,dsTestRGB);

Случайным образом увеличьте обучающие данные, используя transform функции и augmentDataForRAWToRGB вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.

The augmentDataForRAWToRGB Функция helper случайным образом применяет вращение степени 90 и горизонтальное отражение к парам входа RAW и целевых обучающих изображений RGB.

dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);

Предварительный просмотр дополненных обучающих данных.

exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}

Отобразите входное и целевое изображения сети в монтаже. Вход сети имеет четыре канала, поэтому отобразите первый канал, переведенный в область значений [0, 1]. Входные RAW и целевые изображения RGB имеют идентичное увеличение.

exampleInput = exampleAug{1,1};
exampleOutput = exampleAug{1,2};
montage({rescale(exampleInput(:,:,1)),exampleOutput})

Пакетное обучение и данные валидации во время обучения

Этот пример использует пользовательский цикл обучения. The minibatchqueue объект полезен для управления мини-пакетированием наблюдений в пользовательских циклах обучения. The minibatchqueue объект также переводит данные в dlarray объект, который позволяет проводить автоматическую дифференциацию в применениях глубокого обучения.

miniBatchSize = 12;
valBatchSize = 10;
trainingQueue = minibatchqueue(dsTrainAug,'MiniBatchSize',miniBatchSize,'PartialMiniBatch','discard','MiniBatchFormat','SSCB');
validationQueue = minibatchqueue(dsVal,'MiniBatchSize',valBatchSize,'MiniBatchFormat','SSCB');

The next функция minibatchqueue приводит к следующему мини-пакету данных. Предварительный просмотр выходов одного вызова в next функция. Выходы имеют тип данных dlarray. Данные уже переведены в gpuArray, на графический процессор, и готов к обучению.

[inputRAW,targetRGB] = next(trainingQueue);
whos inputRAW
whos targetRGB
  Name             Size                     Bytes  Class      Attributes

  targetRGB      448x448x3x12            28901384  dlarray              

Настройка слоев сети U-Net

Этот пример использует изменение сети U-Net. В U-Net начальная серия сверточных слоев перемежается с максимальными слоями объединения, последовательно уменьшая разрешение входного изображения. Эти слои сопровождаются серией сверточных слоев, чередующихся с операторами повышающей дискретизации, последовательно увеличивая разрешение входного изображения. Имя U-Net происходит от того, что сеть может быть нарисована с симметричной формой, такой как буква U.

Этот пример использует простую архитектуру U-Net с двумя модификациями. Во-первых, сеть заменяет окончательную операцию транспонированной свертки пользовательской операцией тасования пикселей с повышенной дискретизацией (также известной как операция «глубина в пространство»). Во-вторых, сеть использует пользовательский слой активации гиперболического тангенса в качестве последнего слоя в сети.

Pixel Shuffle Увеличение дискретизации

Свертка с последующим повышением дискретизации пикселей может задать субпиксельную свертку для приложений супер разрешения. Субпиксельная свертка предотвращает программные продукты контроля, которые могут возникнуть из-за транспонированной свертки [6]. Поскольку модель должна сопоставить H/2-by-W/2-by-4 входы RAW с W-by-H-by-3 выходами RGB, конечный этап увеличения дискретизации модели может быть рассмотрен аналогично супер- разрешение, где количество пространственных выборок увеличивается от входных до выходных.

Данные показывают, как пиксельная повышающая дискретизация перетасовки работает на вход 2 на 2 на 4. Первые две размерности являются пространственными размерностями, а 3-я размерность является размерностью канала. В целом, перемещение пикселей с повышенной дискретизацией в множителе S принимает вход H-на-W-на-C и приводит к S * H-на-S * W-by-CS2 выход.

Функция тасования пикселей увеличивает пространственные размерности выходного сигнала путем отображения информации от размеров канала в заданном пространственном местоположении в пространственные блоки S на S в выходе, в котором каждый канал вносит вклад в согласованное пространственное положение относительно его соседей во время увеличения дискретизации.

Масштабная и гиперболическая активация тангенса

Слой активации гиперболического тангенса применяет tanh функция на входах слоя. Этот пример использует масштабированную и shfited версию tanh функция, которая поощряет, но не строго следит за тем, чтобы выходы сети RGB находились в области значений [0, 1] [6].

f(x)=0.58*tanh(x)+0.5

Вычислите Набор обучающих данных статистику для нормализации Входа

Использование tall вычислить среднее сокращение по каналам для обучающих данных набора. Уровень входа сети выполняет среднее центрирование входов во время обучения и проверки с помощью средней статистики.

dsIn = copy(dsTrainRAW);
dsIn.UnderlyingDatastore.ReadSize = 1;
t = tall(dsIn);
perChannelMean = gather(mean(t,[1 2]));

Создайте U-Net

Создайте слои начальной подсети, задав среднее значение по каналам.

inputSize = [256 256 4];
initialLayer = imageInputLayer(inputSize,"Normalization","zerocenter","Mean",perChannelMean, ...
    "Name","ImageInputLayer");

Добавьте слои первой подсети кодирования. Первый энкодер имеет 32 сверточных фильтра.

numEncoderStages = 4;
numFiltersFirstEncoder = 32;
encoderNamePrefix = "Encoder-Stage-";

encoderLayers = [
    convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(encoderNamePrefix,"1-Conv-1"))
    leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-1"))
    convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(encoderNamePrefix,"1-Conv-2"))
    leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-2"))
    maxPooling2dLayer([2 2],"Stride",[2 2], ...
        "Name",strcat(encoderNamePrefix,"1-MaxPool"));  
    ];

Добавьте слои дополнительных подсетей кодирования. Эти подсети добавляют нормализацию образца по каналу после каждого сверточного слоя с помощью groupNormalizationLayer. Каждая подсеть энкодера имеет в два раза больше фильтров, чем предыдущая подсеть энкодера.

cnIdx = 1;
for stage = 2:numEncoderStages
    
    numFilters = numFiltersFirstEncoder*2^(stage-1);
    layerNamePrefix = strcat(encoderNamePrefix,num2str(stage));
    
    encoderLayers = [
        encoderLayers
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-1"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-2"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))
        maxPooling2dLayer([2 2],"Stride",[2 2],"Name",strcat(layerNamePrefix,"-MaxPool"))
        ];     
    
    cnIdx = cnIdx + 2;
end

Добавьте слои моста. Подсеть моста имеет вдвое больше фильтров, чем подсеть конечного энкодера и подсеть первого декодера.

numFilters = numFiltersFirstEncoder*2^numEncoderStages;
bridgeLayers = [
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name","Bridge-Conv-1")
    groupNormalizationLayer("channel-wise","Name","cn7")
    leakyReluLayer(0.2,"Name","Bridge-ReLU-1")
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name","Bridge-Conv-2")
    groupNormalizationLayer("channel-wise","Name","cn8")
    leakyReluLayer(0.2,"Name","Bridge-ReLU-2")];

Добавьте слои первых трех подсетей декодера.

numDecoderStages = 4;
cnIdx = 9;
decoderNamePrefix = "Decoder-Stage-";

decoderLayers = [];
for stage = 1:numDecoderStages-1
    
    numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage);
    layerNamePrefix = strcat(decoderNamePrefix,num2str(stage));  
    
    decoderLayers = [
        decoderLayers
        transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-UpConv"))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU"))
        depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-1"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-2"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))
        ];        
    
    cnIdx = cnIdx + 2;    
end

Добавьте слои последней подсети декодера. Эта подсеть исключает нормализацию образца по каналу, выполняемую другими подсетями декодера. Каждая подсеть декодера имеет половину количества фильтров, как предыдущая подсеть.

numFilters = numFiltersFirstEncoder;
layerNamePrefix = strcat(decoderNamePrefix,num2str(stage+1)); 

decoderLayers = [
    decoderLayers
    transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ...
       "Name",strcat(layerNamePrefix,"-UpConv"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU"))
    depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation"))
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(layerNamePrefix,"-Conv-1"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(layerNamePrefix,"-Conv-2"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))];

Добавьте конечные слои U-Net. Слой тасования пикселей переходит от размера H/2-by-W/2-by-12 канала активаций от конечной свертки к активациям канала H-на-W-на 3 с помощью увеличения дискретизации пикселей. Конечный слой поощряет выходы в желаемую область значений [0, 1], используя гиперболическую функцию тангенса.

finalLayers = [
    convolution2dLayer([3 3],12,"Padding","same","WeightsInitializer","narrow-normal", ...
       "Name","Decoder-Stage-4-Conv-3")
    pixelShuffleLayer("pixelShuffle",2)
    tanhScaledAndShiftedLayer("tanhActivation")];

layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers];
lgraph = layerGraph(layers);

Соединяет слои подсетей кодирования и декодирования.

lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2","Decoder-Stage-4-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2","Decoder-Stage-3-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2","Decoder-Stage-2-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2","Decoder-Stage-1-DepthConcatenation/in2");
net = dlnetwork(lgraph);

Визуализируйте сетевую архитектуру с помощью приложения Deep Network Designer.

deepNetworkDesigner(lgraph)

Загрузка сети редукции данных

Эта функция изменяет предварительно обученную VGG-16 глубокую нейронную сеть, чтобы извлечь функции изображения в различных слоях. Эти многослойные функции используются для вычисления потерь содержимого.

Чтобы получить предварительно обученную VGG-16 сеть, установите vgg16. Если у вас нет установленного необходимого пакета поддержки, то программное обеспечение предоставляет ссылку на загрузку.

vggNet = vgg16;

Чтобы сделать VGG-16 сеть подходящей для редукции данных, используйте слои до 'relu5 _ 3 '.

vggNet = vggNet.Layers(1:31);
vggNet = dlnetwork(layerGraph(vggNet));

Задайте градиенты модели и функции потерь

Функция помощника modelGradients вычисляет градиенты и общие потери для пакетов обучающих данных. Эта функция определяется в разделе Вспомогательные функции этого примера.

Общая потеря представляет собой взвешенную сумму двух потерь: средней абсолютной ошибки (MAE) и потери содержимого. Потери содержимого взвешены таким образом, что потери MAE и содержимого способствуют примерно равным образом общим потерям:

lossOverall=lossMAE+weightFactor*lossContent

Потеря MAE наказывает L1 расстояние между выборками сетевых предсказаний и выборками целевого изображения. L1 часто является лучшим выбором, чем L2 для приложений обработки изображений, потому что это может помочь уменьшить размытие программных продуктов [4]. Эта потеря реализована с помощью maeLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.

Потеря содержимого помогает сети узнать как высокоуровневое структурное содержание, так и низкоуровневое ребро и цветовую информацию. Функция потерь вычисляет взвешенную сумму средней квадратной ошибки (MSE) между предсказаниями и целями для каждого слоя активации. Эта потеря реализована с помощью contentLoss вспомогательная функция, заданная в разделе Вспомогательные функции этого примера.

Вычислите коэффициент веса потери содержимого

The modelGradients Функция helper требует, чтобы коэффициент веса потери содержимого был входным параметром. Вычислите весовой коэффициент для выборки пакета обучающих данных таким образом, чтобы потеря MAE равнялась потере взвешенного содержимого.

Предварительный просмотр пакета обучающих данных, который состоит из пар входов RAW-сети и целевых выходов RGB.

trainingBatch = preview(dsTrainAug);
networkInput = dlarray((trainingBatch{1,1}),'SSC');
targetOutput = dlarray((trainingBatch{1,2}),'SSC');

Спрогнозируйте ответ нетренированной сети U-Net с помощью forward функция.

predictedOutput = forward(net,networkInput);

Вычислите MAE и потери содержимого между предсказанным и целевым изображениями RGB.

sampleMAELoss = maeLoss(predictedOutput,targetOutput);
sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);

Вычислите весовой коэффициент.

weightContent = sampleMAELoss/sampleContentLoss;

Настройка опций обучения

Задайте опции обучения, которые используются в пользовательском цикле обучения для управления аспектами оптимизации Адама. Обучайте на 20 эпох.

learnRate = 5e-5;
numEpochs = 20;

Обучите сеть

По умолчанию пример загружает предварительно обученную версию сети RAW-to-RGB с помощью функции helper downloadTrainedRAWToRGBNet. Функция helper присоединена к примеру как вспомогательный файл. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.

Чтобы обучить сеть, установите doTraining переменная в следующем коде, для true. Обучите модель в пользовательском цикле обучения. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval функции и modelGradients вспомогательная функция.

  • Обновляйте параметры сети с помощью adamupdate функция и информация о градиенте.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Обучение занимает около 88 часов на NVIDIA™ Titan RTX и может занять еще больше времени в зависимости от оборудования графического процессора.

doTraining = false;
if doTraining
    
    % Create a directory to store checkpoints
    checkpointDir = fullfile(imageDir,'checkpoints',filesep);
    if ~exist(checkpointDir,'dir')
        mkdir(checkpointDir);
    end
    
    % Initialize training plot
    [hFig,batchLine,validationLine] = initializeTrainingPlotRAWPipeline;
    
    % Initialize Adam solver state
    [averageGrad,averageSqGrad] = deal([]);
    iteration = 0;
    
    start = tic;
    for epoch = 1:numEpochs
        reset(trainingQueue);
        shuffle(trainingQueue);
        while hasdata(trainingQueue)
            [inputRAW,targetRGB] = next(trainingQueue);  
            
            [grad,loss] = dlfeval(@modelGradients,net,vggNet,inputRAW,targetRGB,weightContent);
            
            iteration = iteration + 1;
            
            [net,averageGrad,averageSqGrad] = adamupdate(net, ...
                grad,averageGrad,averageSqGrad,iteration,learnRate);
              
            updateTrainingPlotRAWPipeline(batchLine,validationLine,iteration,loss,start,epoch,...
                validationQueue,valSetSize,valBatchSize,net,vggNet,weightContent);
        end
        % Save checkpoint of network state
        save(checkpointDir + "epoch" + epoch,'net');
    end
    % Save the final network state
    save(checkpointDir + "trainedRAWToRGBNet.mat",'net');
    
else
    trainedRAWToRGBNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedRAWToRGBNet.mat';
    downloadTrainedRAWToRGBNet(trainedRAWToRGBNet_url,imageDir);
    load(fullfile(imageDir,'trainedRAWToRGBNet.mat'));
end

Вычислите метрики качества изображений

Основанные на ссылках метрики качества, такие как MSSIM или PSNR, обеспечивают количественную меру качества изображения. Можно вычислить MSSIM и PSNR исправленных тестовых изображений, поскольку они пространственно зарегистрированы и имеют одинаковый размер.

Выполните итерацию тестового набора исправленных изображений с помощью minibatchqueue объект.

patchTestSet = combine(dsTestRAW,dsTestRGB);
testPatchQueue = minibatchqueue(patchTestSet,'MiniBatchSize',16,'MiniBatchFormat','SSCB');

Выполните итерацию тестового набора и вычислите MSSIM и PSNR для каждого тестового изображения с помощью multissim (Image Processing Toolbox) и psnr (Image Processing Toolbox) функции. Хотя функции принимают изображения RGB, метрики не четко определены для изображений RGB. Поэтому аппроксимируйте MSSIM и PSNR путем вычисления метрики цветовых каналов отдельно. Можно использовать calculateRAWToRGBQualityMetrics вспомогательная функция для вычисления метрик по каналам. Эта функция присоединена к примеру как вспомогательный файл.

totalMSSIM = 0;
totalPSNR = 0;
while hasdata(testPatchQueue)
    [inputRAW,targetRGB] = next(testPatchQueue);
    outputRGB = forward(net,inputRAW);
    [mssimOut,psnrOut] = calculateRAWToRGBQualityMetrics(outputRGB,targetRGB);
    totalMSSIM = totalMSSIM + mssimOut;
    totalPSNR = totalPSNR + psnrOut;
end

Вычислите среднее значение MSSIM и среднее значение PSNR по тестовому набору. Этот результат согласуется с аналогичным подходом U-Net от [3] для среднего MSSIM и конкурентоспособен с подходом PyNet в [3] для среднего PSNR. Различия в функциях потерь и использовании пиксельного тасования с повышенной дискретизацией по сравнению с [3], вероятно, объясняют эти различия.

numObservations = dsTestRGB.numpartitions;
meanMSSIM = totalMSSIM / numObservations
meanMSSIM = single
    0.8534
meanPSNR = totalPSNR / numObservations
meanPSNR = 21.2956

Оценка обученного трубопровода обработки изображений на полноразмерных изображениях

Из-за различий между камерой телефона и DSLR, используемой для получения тестовых изображений с полным разрешением, сцены не регистрируются и не имеют одинакового размера. Основанное на ссылках сравнение изображений полного разрешения из сети и DSLR ISP трудно. Однако качественное сравнение изображений полезно, потому что целью обработки изображений является создание эстетически приятного изображения.

Создайте изображение datastore, который содержит полноразмерные RAW- изображений, полученные телефонной камерой.

testImageDir = fullfile(imageDir,'test');
testImageDirRAW = "huawei_full_resolution";
dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));

Получите имена файлов изображений в полноразмерном тестовом наборе RAW.

targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files),fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");

Предварительно обработайте данные RAW путем преобразования данных в форму, ожидаемую сетью с помощью transform функция. The transform функция обрабатывает данные с помощью операций, заданных в preprocessRAWDataForRAWToRGB вспомогательная функция. Функция helper присоединена к примеру как вспомогательный файл.

dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);

Создайте image datastore, который содержит полноразмерные тестовые изображения RGB, полученные из DSLR высшего класса. Набор данных Zurich RAW to RGB содержит более полноразмерные изображения RGB, чем изображения RAW, поэтому включают только изображения RGB с соответствующим изображением RAW.

dsTestFullRGB = imageDatastore(fullfile(imageDir,'full_resolution','canon'));
dsTestFullRGB.Files = dsTestFullRGB.Files(contains(dsTestFullRGB.Files,targetFilesToInclude));

Чтение в целевых изображениях RGB. Почувствуйте общий выход, посмотрев на вид монтажа.

targetRGB = readall(dsTestFullRGB);
montage(targetRGB,"Size",[5 2],"Interpolation","bilinear")

Выполните итерацию тестового набора полноразмерных изображений с помощью minibatchqueue объект. Если у вас есть устройство GPU с достаточной памятью для обработки изображений с полным разрешением, то можно запустить предварительное обучение на графическом процессоре, задав выход окружения следующим образом "gpu".

testQueue = minibatchqueue(dsTestFullRAW,"MiniBatchSize",1, ...
    "MiniBatchFormat","SSCB","OutputEnvironment","cpu");

Для каждого полноразмерного тестового изображения RAW спрогнозируйте выходное изображение RGB по вызову forward в сети.

outputSize = 2*size(preview(dsTestFullRAW),[1 2]);
outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],'uint8');

idx = 1;
while hasdata(testQueue)
    inputRAW = next(testQueue);
    rgbOut = forward(net,inputRAW);
    rgbOut = gather(extractdata(rgbOut));    
    outputImages(:,:,:,idx) = im2uint8(rgbOut);
    idx = idx+1;
end

Почувствуйте общий выход, посмотрев на вид монтажа. Сеть производит изображения, которые эстетически приятны, со схожими характеристиками.

montage(outputImages,"Size",[5 2],"Interpolation","bilinear")

Сравните одно целевое изображение RGB с соответствующим изображением, предсказанным сетью. Сеть производит цвета, которые более насыщены, чем целевые изображения DSLR. Несмотря на то, что цвета от простой архитектуры U-Net не совпадают с целями DSLR, изображения во многих случаях по-прежнему качественно радуют.

imgIdx = 1;
imTarget = targetRGB{imgIdx};
imPredicted = outputImages(:,:,:,imgIdx);
figure
montage({imTarget,imPredicted},"Interpolation","bilinear")

Для повышения эффективности сети RAW в RGB сетевая архитектура изучит подробные локализованные пространственные функции с использованием нескольких шкал от глобальных функций, которые описывают цвет и контрастность [3].

Вспомогательные функции

Функция градиентов модели

The modelGradients Функция helper вычисляет градиенты и общие потери. Информация о градиенте возвращается как таблица, которая включает слой, имя параметра и значение для каждого настраиваемого параметра в модели.

function [gradients,loss] = modelGradients(dlnet,vggNet,Xpatch,Target,weightContent)
    Y = forward(dlnet,Xpatch);
    lossMAE = maeLoss(Y,Target);
    lossContent = contentLoss(vggNet,Y,Target);
    loss = lossMAE + weightContent.*lossContent;
    gradients = dlgradient(loss,dlnet.Learnables);
end

Средняя функция абсолютной потери ошибок

Функция помощника maeLoss вычисляет среднюю абсолютную ошибку между сетевыми предсказаниями, Y, и целевые изображения, T.

function loss = maeLoss(Y,T)
    loss = mean(abs(Y-T),'all');
end

Функция потери содержимого

Функция помощника contentLoss вычисляет взвешенную сумму MSE между сетевыми предсказаниями, Y, и целевые изображения, T, для каждого слоя активации. The contentLoss функция helper вычисляет MSE для каждого слоя активации с помощью mseLoss вспомогательная функция. Веса выбираются такими, чтобы потери от каждых слоев активации вносили примерно одинаковый вклад в общую потерю содержимого.

function loss = contentLoss(net,Y,T)

    layers = ["relu1_1","relu1_2","relu2_1","relu2_2","relu3_1","relu3_2","relu3_3","relu4_1"];
    [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,'Outputs',layers);
    [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,'Outputs',layers);
    
    l1 = mseLoss(X1,T1);
    l2 = mseLoss(X2,T2);
    l3 = mseLoss(X3,T3);
    l4 = mseLoss(X4,T4);
    l5 = mseLoss(X5,T5);
    l6 = mseLoss(X6,T6);
    l7 = mseLoss(X7,T7);
    l8 = mseLoss(X8,T8);
    
    layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8];
    weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04];
    loss = sum(layerLosses.*weights);  
end

Функция потери среднеквадратичной ошибки

Функция помощника mseLoss вычисляет MSE между сетевыми предсказаниями, Y, и целевые изображения, T.

function loss = mseLoss(Y,T)
    loss = mean((Y-T).^2,'all');
end

Ссылки

1) Самнер, Роб. «Обработка RAW- Изображений в MATLAB». 19 мая 2014 года. https://rcsumner.net/raw_guide/RAWguide.pdf

2) Чэнь, Чэнь, Цифэн Чэнь, Цзи Сюй и Владлен Колтун. Учимся видеть в темноте. ArXiv:1805.01934 [Cs], 4 мая 2018 года. http://arxiv.org/abs/1805.01934.

3) Игнатов, Андрей, Люк Ван Голь, и Раду Тимофте. «Замена Mobile Camera ISP на одну модель глубокого обучения». ArXiv:2002.05509 [Cs, Eess], 13 февраля 2020 года. http://arxiv.org/abs/2002.05509. Веб-сайт проекта.

4) Чжао, Ханг, Орацио Галло, Иури Фрозио, и Ян Каутц. Функции потерь для нейронных сетей для обработки изображений. ArXiv:1511.08861 [Cs], 20 апреля 2018 года. http://arxiv.org/abs/1511.08861.

5) Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. «Восприятие потерь для передачи стиля в реальном времени и суперразрешение». ArXiv:1603.08155 [Cs], 26 марта 2016 года. http://arxiv.org/abs/1603.08155.

6) Ши, Вэньчжэ, Хосе Кабальеро, Ференц Хуссар, Йоханнес Тоц, Эндрю П. Эйткен, Роб Бишоп, Даниэль Рюкерт, и Зехан Ван. «Одно изображение и видео суперразрешение в реальном времени с использованием эффективной субпиксельной сверточной нейронной сети». ArXiv:1609.05158 [Cs, Stat], 23 сентября 2016 года. http://arxiv.org/abs/1609.05158.

См. также

| | | |

Похожие темы