Разработайте НЕОБРАБОТАННЫЙ конвейер обработки камеры Используя глубокое обучение

В этом примере показано, как преобразовать НЕОБРАБОТАННЫЕ данные о камере в эстетически приятное цветное изображение с помощью U-Net.

Цифровые однообъективные зеркальные фотоаппараты и много современных телефонных камер предлагают способность сохранить данные, собранные непосредственно от датчика камеры как НЕОБРАБОТАННЫЙ файл. Каждый пиксель Необработанных данных соответствует непосредственно на сумму света, полученного соответствующим фотодатчиком камеры. Данные зависят от фиксированных характеристик оборудования камеры, таких как чувствительность каждого фотодатчика к конкретной области значений длин волн электромагнитного спектра. Данные также зависят от настроек захвата камеры, таких как выдержка и факторы сцены, такие как источник света.

Demosaicing является единственной необходимой операцией, чтобы преобразовать одноканальные Необработанные данные в изображение RGB с тремя каналами. Однако без дополнительных операций обработки изображений, получившееся изображение RGB имеет субъективно плохое визуальное качество.

Традиционный трубопровод обработки изображений выполняет комбинацию дополнительных операций включая шумоподавление, линеаризацию, балансировку белого, коррекцию цвета, настройку яркости и контрастную корректировку [1]. Проблема разработки трубопровода находится в совершенствовании алгоритмов, чтобы оптимизировать субъективный внешний вид итогового изображения RGB независимо от изменений настроек захвата и сцене.

Методы глубокого обучения включают прямые СЫРЫЕ ДАННЫЕ к преобразованию RGB без необходимости разработки традиционного конвейера обработки. Например, один метод компенсирует недоэкспонирование при преобразовании НЕОБРАБОТАННЫХ изображений в RGB [2]. В этом примере показано, как преобразовать НЕОБРАБОТАННЫЕ изображения от низкокачественной телефонной камеры до изображений RGB, которые аппроксимируют качество камеры DSLR более высокого качества.

Загрузите Цюрихские СЫРЫЕ ДАННЫЕ на набор данных RGB

Этот пример использует Цюрихские СЫРЫЕ ДАННЫЕ для набора данных RGB [3]. Размер набора данных составляет 22 Гбайт. Набор данных содержит 48 043 пространственно зарегистрированных пары СЫРЫХ ДАННЫХ и закрашенные фигуры обучения RGB изображений размера 448 448. Набор данных содержит два отдельных набора тестов. Один набор тестов состоит из 1 204 пространственно зарегистрированных пар СЫРЫХ ДАННЫХ и закрашенных фигур RGB изображений размера 448 448. Другой набор тестов состоит из незарегистрированных СЫРЫХ ДАННЫХ полного разрешения и изображений RGB.

Создайте директорию, чтобы сохранить набор данных.

imageDir = fullfile(tempdir,'ZurichRAWToRGB');
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end

Чтобы загрузить набор данных, запросите доступ с помощью Цюрихских СЫРЫХ ДАННЫХ для формы набора данных RGB. Извлеките данные в директорию, заданную imageDir переменная. Когда извлечено успешно, imageDir содержит три директории под названием full_resolutionТест, и train.

Создайте хранилища данных для обучения, валидации и тестирования

Создайте Datastore для обучающих данных закрашенной фигуры RGB изображений

Создайте imageDatastore это читает, целевые закрашенные фигуры обучения RGB изображений получили использование высокопроизводительного Canon DSLR.

trainImageDir = fullfile(imageDir,'train');
dsTrainRGB = imageDatastore(fullfile(trainImageDir,'canon'),'ReadSize',16);

Предварительно просмотрите закрашенную фигуру обучения RGB изображений.

groundTruthPatch = preview(dsTrainRGB);
imshow(groundTruthPatch)

Создайте Datastore для НЕОБРАБОТАННЫХ обучающих данных закрашенной фигуры изображений

Создайте imageDatastore это читает, закрашенные фигуры обучения входа RAW изображений получили использование камеры телефона Huawei. НЕОБРАБОТАННЫЕ изображения получены с 10-битной точностью и представлены и как 8-битные и как 16-битные файлы PNG. 8-битные файлы предоставляют компактному представлению закрашенных фигур с данными в области значений [0, 255]. Никакое масштабирование не было сделано ни на одних из Необработанных данных.

dsTrainRAW = imageDatastore(fullfile(trainImageDir,'huawei_raw'),'ReadSize',16);

Предварительно просмотрите закрашенную фигуру обучения входа RAW изображений. Datastore читает эту закрашенную фигуру как 8-битный uint8 отобразите, потому что количества датчика находятся в области значений [0, 255]. Чтобы симулировать 10-битный динамический диапазон обучающих данных, разделите значения интенсивности изображений на 4. Если вы увеличиваете масштаб изображения, то вы видите шаблон Байера RGGB.

inputPatch = preview(dsTrainRAW);
inputPatchRAW = inputPatch/4;
imshow(inputPatchRAW)

Симулировать минимальный традиционный конвейер обработки, demosaic шаблон Байера RGGB Необработанных данных с помощью demosaic функция. Отобразите обработанное изображение и украсьте отображение. По сравнению с целевым изображением RGB минимально обработанное изображение RGB является темным и имеет неустойчивые цвета и значимые артефакты. Обученная сеть RAW-to-RGB выполняет операции предварительной обработки так, чтобы изображение выхода RGB напомнило целевое изображение.

inputPatchRGB = demosaic(inputPatch,'rggb');
imshow(rescale(inputPatchRGB))

Тестовые изображения раздела в валидацию и наборы тестов

Тестовые данные содержат СЫРЫЕ ДАННЫЕ и закрашенные фигуры RGB изображений и полноразмерные изображения. Этот пример делит тестовые закрашенные фигуры изображений в набор валидации и набор тестов. Пример использует полноразмерные тестовые изображения для качественного тестирования только. Смотрите Оценивают Обученный Трубопровод Обработки изображений на Полноразмерных Изображениях.

Создайте хранилища данных изображений, которые читают СЫРЫЕ ДАННЫЕ и тестовые закрашенные фигуры RGB изображений.

testImageDir = fullfile(imageDir,'test');
dsTestRAW = imageDatastore(fullfile(testImageDir,'huawei_raw'),'ReadSize',16);
dsTestRGB = imageDatastore(fullfile(testImageDir,'canon'),'ReadSize',16);

Случайным образом разделите тестовые данные в два набора для валидации и обучения. Набор данных валидации содержит 200 изображений. Набор тестов содержит остающиеся изображения.

numTestImages = dsTestRAW.numpartitions;
numValImages = 200;

testIdx = randperm(numTestImages);
validationIdx = testIdx(1:numValImages);
testIdx = testIdx(numValImages+1:numTestImages);

dsValRAW = subset(dsTestRAW,validationIdx);
dsValRGB = subset(dsTestRGB,validationIdx);

dsTestRAW = subset(dsTestRAW,testIdx);
dsTestRGB = subset(dsTestRGB,testIdx);

Предварительно обработайте и увеличьте данные

Датчик получает цветные данные в повторяющемся шаблоне Байера, который включает один красный, два зеленых, и один синий фотодатчик. Предварительно обработайте данные в изображение с четырьмя каналами, ожидаемое сети с помощью transform функция. transform функциональные процессы данные с помощью операций заданы в preprocessRAWDataForRAWToRGB функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.

preprocessRAWDataForRAWToRGB функция помощника преобразует H W 1 НЕОБРАБОТАННЫМ изображением к H/2-by-W/2-by-4 многоканальному изображению, состоящему из одного красного, двух зеленых, и один синий канал.

Функция также бросает данные к типу данных single масштабируемый к области значений [0, 1].

dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB);
dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB);
dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);

Целевые изображения RGB хранятся на диске как 8-битные данные без знака. Чтобы сделать расчет метрик и проектирование сети более удобными, предварительно обработайте целевые изображения обучения RGB с помощью transform функционируйте и preprocessRGBDataForRAWToRGB функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.

preprocessRGBDataForRAWToRGB функция помощника бросает изображения к типу данных single масштабируемый к области значений [0, 1].

dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB);
dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);

Объедините вход RAW и предназначайтесь для данных о RGB для обучения, валидации, и протестируйте наборы изображений при помощи combine функция.

dsTrain = combine(dsTrainRAW,dsTrainRGB);
dsVal = combine(dsValRAW,dsValRGB);
dsTest = combine(dsTestRAW,dsTestRGB);

Случайным образом увеличьте обучающие данные с помощью transform функционируйте и augmentDataForRAWToRGB функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.

augmentDataForRAWToRGB функция помощника случайным образом применяет 90 вращений степени и горизонтальное отражение к парам входа RAW и целевых изображений обучения RGB.

dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);

Предварительно просмотрите увеличенные обучающие данные.

exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}

Отобразите сетевой вход и целевое изображение в монтаже. Сетевой вход имеет четыре канала, так отобразите первый канал, перемасштабированный к области значений [0, 1]. Вход RAW и целевые изображения RGB имеют идентичное увеличение.

exampleInput = exampleAug{1,1};
exampleOutput = exampleAug{1,2};
montage({rescale(exampleInput(:,:,1)),exampleOutput})

Обработайте данные об обучении и валидации в пакетном режиме во время обучения

Этот пример использует пользовательский учебный цикл. minibatchqueue Объект (Deep Learning Toolbox) полезен для управления мини-пакетная обработка наблюдений в пользовательских учебных циклах. minibatchqueue возразите также бросает данные к dlarray Объект (Deep Learning Toolbox), который включает автоматическое дифференцирование в применении глубокого обучения.

miniBatchSize = 12;
valBatchSize = 10;
trainingQueue = minibatchqueue(dsTrainAug,'MiniBatchSize',miniBatchSize,'PartialMiniBatch','discard','MiniBatchFormat','SSCB');
validationQueue = minibatchqueue(dsVal,'MiniBatchSize',valBatchSize,'MiniBatchFormat','SSCB');

next (Deep Learning Toolbox) функция minibatchqueue дает к следующему мини-пакету данных. Предварительно просмотрите выходные параметры от одного вызова до next функция. Выходные параметры имеют тип данных dlarray. Данные уже брошены к gpuArray, на графическом процессоре, и готовый к обучению.

[inputRAW,targetRGB] = next(trainingQueue);
whos inputRAW
  Name            Size                    Bytes  Class      Attributes

  inputRAW      224x224x4x12            9633800  dlarray              
whos targetRGB
  Name             Size                     Bytes  Class      Attributes

  targetRGB      448x448x3x12            28901384  dlarray              

Настройте слоя сети U-Net

Этот пример использует изменение сети U-Net. В U-Net начальные серии сверточных слоев вкраплены макс. слоями объединения, последовательно уменьшив разрешение входного изображения. Эти слои сопровождаются серией сверточных слоев, вкрапленных повышающей дискретизацией операторов, последовательно увеличивая разрешение входного изображения. U-Net имени прибывает из того, что сеть может чертиться с симметричной формой как буква U.

Этот пример использует простую архитектуру U-Net с двумя модификациями. Во-первых, сеть заменяет транспонированную операцию свертки финала на пользовательскую пиксельную повышающую дискретизацию перестановки (также известный как глубину к пробелу) операция. Во-вторых, сеть использует пользовательский гиперболический слой активации касательной в качестве последнего слоя в сети.

Пиксельная повышающая дискретизация перестановки

Свертка, сопровождаемая пиксельной повышающей дискретизацией перестановки, может задать субпиксельную свертку для супер приложений разрешения. Субпиксельная свертка предотвращает checkboard артефакты, которые могут явиться результатом транспонированной свертки [6]. Поскольку модель должна сопоставить H/2-by-W/2-by-4 НЕОБРАБОТАННЫЕ входные параметры с W H 3 RGB выходные параметры, итоговый этап повышающей дискретизации модели может думаться так же к супер разрешению, где количество пространственных выборок растет от входа до выхода.

Рисунок показывает, как пиксельная повышающая дискретизация перестановки работает на 2 2 4 входами. Первые две размерности являются пространственными размерностями, и третья размерность является размерностью канала. В общем случае пиксельная повышающая дискретизация перестановки на коэффициент S берет вход H by W by C и дает к S*H-by-S*W-by-CS2 вывод .

Пиксельная функция перестановки выращивает пространственные размерности выхода путем отображения информации от размерностей канала в данном пространственном местоположении в S-by-S пространственные блоки в выходе, в котором каждый канал способствует сопоставимому пространственному положению относительно своих соседей во время повышающей дискретизации.

Масштабируемая и гиперболическая активация касательной

Гиперболический слой активации касательной применяет tanh функция на входных параметрах слоя. Этот пример использует масштабированную и переключенную версию tanh функция, которая поощряет, но строго не осуществляет это сеть RGB выходные параметры, находится в области значений [0, 1].

f(x)=0.58*tanh(x)+0.5

Вычислите статистику набора обучающих данных для входной нормализации

Используйте tall вычислить среднее сокращение на канал через обучающий набор данных. Входной слой сети выполняет среднее центрирование входных параметров во время обучения и тестирования использования средней статистики.

dsIn = copy(dsTrainRAW);
dsIn.UnderlyingDatastore.ReadSize = 1;
t = tall(dsIn);
perChannelMean = gather(mean(t,[1 2]));
Evaluating tall expression using the Parallel Pool 'LocalProfile12':
- Pass 1 of 1: 0% complete
Evaluation 0% complete

- Pass 1 of 1: Completed in 1 min 1 sec
Evaluation completed in 1 min 2 sec

Создайте U-Net

Создайте слои начальной подсети, задав среднее значение на канал.

inputSize = [256 256 4];
initialLayer = imageInputLayer(inputSize,"Normalization","zerocenter","Mean",perChannelMean, ...
    "Name","ImageInputLayer");

Добавьте слои первой подсети кодирования. Первый энкодер имеет 32 сверточных фильтра.

numEncoderStages = 4;
numFiltersFirstEncoder = 32;
encoderNamePrefix = "Encoder-Stage-";

encoderLayers = [
    convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(encoderNamePrefix,"1-Conv-1"))
    leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-1"))
    convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(encoderNamePrefix,"1-Conv-2"))
    leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-2"))
    maxPooling2dLayer([2 2],"Stride",[2 2], ...
        "Name",strcat(encoderNamePrefix,"1-MaxPool"));  
    ];

Добавьте слои дополнительных подсетей кодирования. Эти подсети добавляют мудрую каналом нормализацию экземпляра после каждого сверточного слоя с помощью groupNormalizationLayer. Каждая подсеть энкодера имеет дважды количество фильтров как предыдущая подсеть энкодера.

cnIdx = 1;
for stage = 2:numEncoderStages
    
    numFilters = numFiltersFirstEncoder*2^(stage-1);
    layerNamePrefix = strcat(encoderNamePrefix,num2str(stage));
    
    encoderLayers = [
        encoderLayers
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-1"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-2"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))
        maxPooling2dLayer([2 2],"Stride",[2 2],"Name",strcat(layerNamePrefix,"-MaxPool"))
        ];     
    
    cnIdx = cnIdx + 2;
end

Добавьте мостоукладчики. Подсеть моста имеет дважды количество фильтров как итоговая подсеть энкодера и первая подсеть декодера.

numFilters = numFiltersFirstEncoder*2^numEncoderStages;
bridgeLayers = [
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name","Bridge-Conv-1")
    groupNormalizationLayer("channel-wise","Name","cn7")
    leakyReluLayer(0.2,"Name","Bridge-ReLU-1")
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name","Bridge-Conv-2")
    groupNormalizationLayer("channel-wise","Name","cn8")
    leakyReluLayer(0.2,"Name","Bridge-ReLU-2")];

Добавьте слои первых трех подсетей декодера.

numDecoderStages = 4;
cnIdx = 9;
decoderNamePrefix = "Decoder-Stage-";

decoderLayers = [];
for stage = 1:numDecoderStages-1
    
    numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage);
    layerNamePrefix = strcat(decoderNamePrefix,num2str(stage));  
    
    decoderLayers = [
        decoderLayers
        transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-UpConv"))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU"))
        depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-1"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
        convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
            "Name",strcat(layerNamePrefix,"-Conv-2"))
        groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1)))
        leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))
        ];        
    
    cnIdx = cnIdx + 2;    
end

Добавьте слои последней подсети декодера. Эта подсеть исключает мудрую каналом нормализацию экземпляра, выполняемую другими подсетями декодера. Каждая подсеть декодера имеет половину количества фильтров как предыдущая подсеть.

numFilters = numFiltersFirstEncoder;
layerNamePrefix = strcat(decoderNamePrefix,num2str(stage+1)); 

decoderLayers = [
    decoderLayers
    transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ...
       "Name",strcat(layerNamePrefix,"-UpConv"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU"))
    depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation"))
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(layerNamePrefix,"-Conv-1"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1"))
    convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ...
        "Name",strcat(layerNamePrefix,"-Conv-2"))
    leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))];

Добавьте последние слои U-Net. Пиксельный слой перестановки перемещается от H/2-by-W/2-by-12 размер канала активаций от итоговой свертки до H W 3 активациями канала с помощью пиксельной повышающей дискретизации перестановки. Последний слой призывает выходные параметры к желаемой области значений [0, 1] использование гиперболической функции тангенса.

finalLayers = [
    convolution2dLayer([3 3],12,"Padding","same","WeightsInitializer","narrow-normal", ...
       "Name","Decoder-Stage-4-Conv-3")
    pixelShuffleLayer("pixelShuffle",2)
    tanhScaledAndShiftedLayer("tanhActivation")];

layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers];
lgraph = layerGraph(layers);

Соедините слои кодирования и декодирования подсетей.

lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2","Decoder-Stage-4-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2","Decoder-Stage-3-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2","Decoder-Stage-2-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2","Decoder-Stage-1-DepthConcatenation/in2");
net = dlnetwork(lgraph);

Визуализируйте сетевую архитектуру с помощью Deep Network Designer (Deep Learning Toolbox) приложение.

%deepNetworkDesigner(lgraph)

Загрузите сеть извлечения признаков

Эта функция изменяет предварительно обученную глубокую нейронную сеть VGG-16, чтобы извлечь функции изображений на различных слоях. Эти многоуровневые функции используются для расчета потеря содержимого.

Чтобы получить предварительно обученную сеть VGG-16, установите vgg16 (Deep Learning Toolbox). Если вам не установили необходимый пакет поддержки, то программное обеспечение обеспечивает ссылку на загрузку.

vggNet = load('vgg16');
vggNet = vggNet.net;
%vggNet = vgg16;

Чтобы сделать сеть VGG-16 подходящей для извлечения признаков, используйте слои до 'relu5_3'.

vggNet = vggNet.Layers(1:31);
vggNet = dlnetwork(layerGraph(vggNet));

Градиенты модели Define и функции потерь

Функция помощника modelGradients вычисляет градиенты и полную потерю для пакетов обучающих данных. Эта функция задана в разделе Supporting Functions этого примера.

Полная потеря является взвешенной суммой двух потерь: среднее значение абсолютной погрешности (MAE) потеря и потеря содержимого. Потеря содержимого взвешивается таким образом, что потеря MAE и потеря содержимого способствуют приблизительно одинаково полной потере:

lossOverall=lossMAE+weightFactor*lossContent

Потеря MAE штрафует L1 расстояние между выборками сетевых предсказаний и выборками целевого изображения. L1 часто лучший выбор, чем L2 для приложений для обработки изображений, потому что это может помочь уменьшать размывающиеся артефакты [4]. Эта потеря реализована с помощью maeLoss функция, определяемая помощника в разделе Supporting Functions этого примера.

Потеря содержимого помогает сети изучить и высокоуровневое структурное содержимое и низкоуровневое ребро и информацию о цвете. Функция потерь вычисляет взвешенную сумму среднеквадратичной погрешности (MSE) между предсказаниями и целями для каждого слоя активации. Эта потеря реализована с помощью contentLoss функция, определяемая помощника в разделе Supporting Functions этого примера.

Вычислите весовой коэффициент содержимого потерь

modelGradients функция помощника требует весового коэффициента содержимого потерь как входного параметра. Вычислите весовой коэффициент для демонстрационного пакета обучающих данных, таким образом, что потеря MAE равна взвешенной потере содержимого.

Предварительно просмотрите пакет обучающих данных, который состоит из пар входных параметров сети RAW, и RGB предназначаются для выходных параметров.

trainingBatch = preview(dsTrainAug);
networkInput = dlarray((trainingBatch{1,1}),'SSC');
targetOutput = dlarray((trainingBatch{1,2}),'SSC');

Предскажите ответ нетренированной сети U-Net с помощью forward (Deep Learning Toolbox) функция.

predictedOutput = forward(net,networkInput);

Вычислите MAE и потери содержимого между предсказанными и целевыми изображениями RGB.

sampleMAELoss = maeLoss(predictedOutput,targetOutput);
sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);

Вычислите весовой коэффициент.

weightContent = sampleMAELoss/sampleContentLoss;

Задайте опции обучения

Задайте опции обучения, которые используются в пользовательском учебном цикле, чтобы управлять аспектами оптимизации Адама. Обучайтесь в течение 20 эпох.

learnRate = 5e-5;
numEpochs = 20;

Обучение сети

По умолчанию пример загружает предварительно обученную версию сети RAW-to-RGB при помощи функции помощника downloadTrainedRAWToRGBNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.

Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next (Deep Learning Toolbox) функция.

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) функция и modelGradients функция помощника.

  • Обновите сетевые параметры с помощью adamupdate (Deep Learning Toolbox) функция и информация о градиенте.

  • Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 88 часов на Титане NVIDIA™ RTX и может взять еще дольше в зависимости от вашего оборудования графического процессора.

doTraining = false;
if doTraining
    
    % Create a directory to store checkpoints
    checkpointDir = fullfile(imageDir,'checkpoints',filesep);
    if ~exist(checkpointDir,'dir')
        mkdir(checkpointDir);
    end
    
    % Initialize training plot
    [hFig,batchLine,validationLine] = initializeTrainingPlotRAWPipeline;
    
    % Initialize Adam solver state
    [averageGrad,averageSqGrad] = deal([]);
    iteration = 0;
    
    start = tic;
    for epoch = 1:numEpochs
        reset(trainingQueue);
        shuffle(trainingQueue);
        while hasdata(trainingQueue)
            [inputRAW,targetRGB] = next(trainingQueue);  
            
            [grad,loss] = dlfeval(@modelGradients,net,vggNet,inputRAW,targetRGB,weightContent);
            
            iteration = iteration + 1;
            
            [net,averageGrad,averageSqGrad] = adamupdate(net, ...
                grad,averageGrad,averageSqGrad,iteration,learnRate);
              
            updateTrainingPlotRAWPipeline(batchLine,validationLine,iteration,loss,start,epoch,...
                validationQueue,numValImages,valBatchSize,net,vggNet,weightContent);
        end
        % Save checkpoint of network state
        save(checkpointDir + "epoch" + epoch,'net');
    end
    % Save the final network state
    save(checkpointDir + "trainedRAWToRGBNet.mat",'net');
    
else
    trainedRAWToRGBNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedRAWToRGBNet.mat';
    downloadTrainedRAWToRGBNet(trainedRAWToRGBNet_url,imageDir);
    load(fullfile(imageDir,'trainedRAWToRGBNet.mat'));
end
Pretrained RAW-to-RGB network already exists.

Вычислите метрики качества изображения

Основанные на ссылке метрики качества, такие как MSSIM или PSNR включают количественные показатели качества изображения. Можно вычислить MSSIM и PSNR исправленных тестовых изображений, потому что они пространственно указаны и тот же размер.

Выполните итерации через набор тестов исправленных изображений с помощью minibatchqueue объект.

patchTestSet = combine(dsTestRAW,dsTestRGB);
testPatchQueue = minibatchqueue(patchTestSet,'MiniBatchSize',16,'MiniBatchFormat','SSCB');

Выполните итерации через набор тестов и вычислите MSSIM и PSNR для каждого тестового изображения с помощью multissim и psnr функции. Вычислите MSSIM для цветных изображений при помощи среднего значения метрики для каждого цветового канала как приближение, поскольку метрика не четко определена для многоканальных входных параметров.

totalMSSIM = 0;
totalPSNR = 0;
while hasdata(testPatchQueue)
    [inputRAW,targetRGB] = next(testPatchQueue);
    outputRGB = forward(net,inputRAW);
    targetRGB = targetRGB ./ 255; 
    mssimOut = sum(mean(multissim(outputRGB,targetRGB),3),4);
    psnrOut = sum(psnr(outputRGB,targetRGB),4);
    totalMSSIM = totalMSSIM + mssimOut;
    totalPSNR = totalPSNR + psnrOut;
end

Вычислите средний MSSIM и означайте PSNR по набору тестов. Этот результат сопоставим с подобным подходом U-Net от [3] для среднего MSSIM и конкурентоспособен по отношению к подходу PyNet в [3] в среднем PSNR. Различия в функциях потерь и использовании пиксельной повышающей дискретизации перестановки по сравнению с [3] вероятный счет на эти различия.

numObservations = dsTestRGB.numpartitions;
meanMSSIM = totalMSSIM / numObservations
meanMSSIM = 
  1(S) × 1(S) × 1(C) × 1(B) single dlarray

    0.8425

meanPSNR = totalPSNR / numObservations
meanPSNR = 
  1(S) × 1(S) × 1(C) × 1(B) single dlarray

   21.1213

Оцените обученный трубопровод обработки изображений на полноразмерных изображениях

Из-за различий в датчике между телефонной камерой и DSLR, используемым, чтобы получить тестовые изображения полного разрешения, сцены не указаны и не являются тем же размером. Основанное на ссылке сравнение изображений полного разрешения от сети и DSLR ISP затрудняет. Однако качественное сравнение изображений полезно, потому что цель обработки изображений состоит в том, чтобы создать эстетически приятное изображение.

Создайте datastore изображений, который содержит полноразмерные НЕОБРАБОТАННЫЕ изображения, полученные телефонной камерой.

testImageDir = fullfile(imageDir,'test');
testImageDirRAW = "huawei_full_resolution";
dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));

Получите имена файлов изображений в полноразмерном НЕОБРАБОТАННОМ наборе тестов.

targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files),fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");

Предварительно обработайте Необработанные данные путем преобразования данных в форму, ожидаемую сетью с помощью transform функция. transform функциональные процессы данные с помощью операций заданы в preprocessRAWDataForRAWToRGB функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.

dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);

Создайте datastore изображений, который содержит полноразмерные тестовые изображения RGB, полученные от высокопроизводительного DSLR. Цюрихский набор данных RAW-to-RGB содержит больше полноразмерных изображений RGB, чем НЕОБРАБОТАННЫЕ изображения, поэтому включайте только изображения RGB с соответствующим НЕОБРАБОТАННЫМ изображением.

dsTestFullRGB = imageDatastore(fullfile(imageDir,'full_resolution','canon'));
dsTestFullRGB.Files = dsTestFullRGB.Files(contains(dsTestFullRGB.Files,targetFilesToInclude));

Читайте в целевых изображениях RGB. Получите смысл полного выхода путем рассмотрения представления монтажа.

targetRGB = readall(dsTestFullRGB);
montage(targetRGB,"Size",[5 2],"Interpolation","bilinear")

Выполните итерации через набор тестов полноразмерных изображений с помощью minibatchqueue объект. Если у вас есть устройство графического процессора с достаточной памятью, чтобы обработать изображения полного разрешения, то можно запустить предсказание на графическом процессоре путем определения выходной среды как "gpu".

testQueue = minibatchqueue(dsTestFullRAW,"MiniBatchSize",1, ...
    "MiniBatchFormat","SSCB","OutputEnvironment","cpu");

Для каждого полноразмерного НЕОБРАБОТАННОГО тестового изображения предскажите изображение выхода RGB путем вызова forward (Deep Learning Toolbox) в сети.

outputSize = 2*size(preview(dsTestFullRAW),[1 2]);
outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],'uint8');

idx = 1;
while hasdata(testQueue)
    inputRAW = next(testQueue);
    rgbOut = forward(net,inputRAW);
    rgbOut = gather(extractdata(rgbOut));    
    outputImages(:,:,:,idx) = im2uint8(rgbOut);
    idx = idx+1;
end

Получите смысл полного выхода путем рассмотрения представления монтажа. Сеть производит изображения, которые эстетически приятны с подобными характеристиками.

montage(outputImages,"Size",[5 2],"Interpolation","bilinear")

Сравните одно целевое изображение RGB с соответствующим изображением, предсказанным сетью. Сеть производит цвета, которые более насыщаются, чем целевые изображения DSLR. Несмотря на то, что цвета от простой архитектуры U-Net различные, когда DSLR предназначается, изображения все еще качественно приятны во многих случаях.

imgIdx = 1;
imTarget = targetRGB{imgIdx};
imPredicted = outputImages(:,:,:,imgIdx);
figure
montage({imTarget,imPredicted},"Interpolation","bilinear")

Чтобы улучшать производительность сети RAW-to-RGB, сетевая архитектура училась бы подробный, локализовал пространственные функции с помощью нескольких шкал от глобальных функций, которые описывают цвет и контраст [3].

Вспомогательные Функции

Функция градиентов модели

modelGradients функция помощника вычисляет градиенты и полную потерю. Информация о градиенте возвращена как таблица, которая включает слой, название параметра и значение для каждого настраиваемого параметра в модели.

function [gradients,loss] = modelGradients(dlnet,vggNet,Xpatch,Target,weightContent)
    Y = forward(dlnet,Xpatch);
    lossMAE = maeLoss(Y,Target);
    lossContent = contentLoss(vggNet,Y,Target);
    loss = lossMAE + weightContent.*lossContent;
    gradients = dlgradient(loss,dlnet.Learnables);
end

Средняя функция потерь абсолютной погрешности

Функция помощника maeLoss вычисляет среднюю абсолютную погрешность между сетевыми предсказаниями, Y, и целевые изображения, T.

function loss = maeLoss(Y,T)
    loss = mean(abs(Y-T),'all');
end

Функция потерь содержимого

Функция помощника contentLoss вычисляет взвешенную сумму MSE между сетевыми предсказаниями, Y, и целевые изображения, T, для каждого слоя активации. contentLoss функция помощника вычисляет MSE для каждого слоя активации с помощью mseLoss функция помощника. Веса выбраны таким образом, что потеря от каждой активации слои способствует примерно одинаково полной потере содержимого.

function loss = contentLoss(net,Y,T)

    layers = ["relu1_1","relu1_2","relu2_1","relu2_2","relu3_1","relu3_2","relu3_3","relu4_1"];
    [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,'Outputs',layers);
    [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,'Outputs',layers);
    
    l1 = mseLoss(X1,T1);
    l2 = mseLoss(X2,T2);
    l3 = mseLoss(X3,T3);
    l4 = mseLoss(X4,T4);
    l5 = mseLoss(X5,T5);
    l6 = mseLoss(X6,T6);
    l7 = mseLoss(X7,T7);
    l8 = mseLoss(X8,T8);
    
    layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8];
    weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04];
    loss = sum(layerLosses.*weights);  
end

Функция потерь среднеквадратичной погрешности

Функция помощника mseLoss вычисляет MSE между сетевыми предсказаниями, Y, и целевые изображения, T.

function loss = mseLoss(Y,T)
    loss = mean((Y-T).^2,'all');
end

Ссылки

1) Самнер, Ограбить. "Обрабатывая НЕОБРАБОТАННЫЕ Изображения в MATLAB". 19 мая 2014. https://rcsumner.net/raw_guide/RAWguide.pdf.

2) Чэнь, Чэнь, Цифэн Чэнь, Цзя Сюй и Владлен Кольтун. “Учась Видеть в темноте”. ArXiv:1805.01934 [Cs], 4 мая 2018. http://arxiv.org/abs/1805.01934.

3) Игнатов, Андрей, Люк Ван Гул и Рэду Тимофт. “Заменяя Мобильную Камеру ISP на Одну Модель Глубокого обучения”. ArXiv:2002.05509 [Cs, Eess], 13 февраля 2020. http://arxiv.org/abs/2002.05509. Веб-сайт проекта.

4) Чжао, Висите, Орацио Галло, Юри Фрозьо и Ян Коц. “Функции потерь для Нейронных сетей для Обработки изображений”. ArXiv:1511.08861 [Cs], 20 апреля 2018. http://arxiv.org/abs/1511.08861.

5) Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. “Перцепционные Потери для Передачи Стиля В реальном времени и Суперразрешения”. ArXiv:1603.08155 [Cs], 26 марта 2016. http://arxiv.org/abs/1603.08155.

6) Ши, Wenzhe, Хосе Кэбаллеро, Ференц Хусзар, Джоханнс Тоц, Эндрю П. Эйткен, Грабит епископа, Дэниела Руекерта, и Зехэна Вана. “Одно Изображение в реальном времени и Видео Суперразрешение Используя Эффективную Субпиксельную Сверточную нейронную сеть”. ArXiv:1609.05158 [Cs, Статистика], 23 сентября 2016. http://arxiv.org/abs/1609.05158.

Смотрите также

| (Deep Learning Toolbox) | (Deep Learning Toolbox) | |

Связанные примеры

Больше о