В этом примере показано, как преобразовать НЕОБРАБОТАННЫЕ данные о камере в эстетически приятное цветное изображение с помощью U-Net.
Цифровые однообъективные зеркальные фотоаппараты и много современных телефонных камер предлагают способность сохранить данные, собранные непосредственно от датчика камеры как НЕОБРАБОТАННЫЙ файл. Каждый пиксель Необработанных данных соответствует непосредственно на сумму света, полученного соответствующим фотодатчиком камеры. Данные зависят от фиксированных характеристик оборудования камеры, таких как чувствительность каждого фотодатчика к конкретной области значений длин волн электромагнитного спектра. Данные также зависят от настроек захвата камеры, таких как выдержка и факторы сцены, такие как источник света.
Demosaicing является единственной необходимой операцией, чтобы преобразовать одноканальные Необработанные данные в изображение RGB с тремя каналами. Однако без дополнительных операций обработки изображений, получившееся изображение RGB имеет субъективно плохое визуальное качество.
Традиционный трубопровод обработки изображений выполняет комбинацию дополнительных операций включая шумоподавление, линеаризацию, балансировку белого, коррекцию цвета, настройку яркости и контрастную корректировку [1]. Проблема разработки трубопровода находится в совершенствовании алгоритмов, чтобы оптимизировать субъективный внешний вид итогового изображения RGB независимо от изменений настроек захвата и сцене.
Методы глубокого обучения включают прямые СЫРЫЕ ДАННЫЕ к преобразованию RGB без необходимости разработки традиционного конвейера обработки. Например, один метод компенсирует недоэкспонирование при преобразовании НЕОБРАБОТАННЫХ изображений в RGB [2]. В этом примере показано, как преобразовать НЕОБРАБОТАННЫЕ изображения от низкокачественной телефонной камеры до изображений RGB, которые аппроксимируют качество камеры DSLR более высокого качества.
Этот пример использует Цюрихские СЫРЫЕ ДАННЫЕ для набора данных RGB [3]. Размер набора данных составляет 22 Гбайт. Набор данных содержит 48 043 пространственно зарегистрированных пары СЫРЫХ ДАННЫХ и закрашенные фигуры обучения RGB изображений размера 448 448. Набор данных содержит два отдельных набора тестов. Один набор тестов состоит из 1 204 пространственно зарегистрированных пар СЫРЫХ ДАННЫХ и закрашенных фигур RGB изображений размера 448 448. Другой набор тестов состоит из незарегистрированных СЫРЫХ ДАННЫХ полного разрешения и изображений RGB.
Создайте директорию, чтобы сохранить набор данных.
imageDir = fullfile(tempdir,'ZurichRAWToRGB'); if ~exist(imageDir,'dir') mkdir(imageDir); end
Чтобы загрузить набор данных, запросите доступ с помощью Цюрихских СЫРЫХ ДАННЫХ для формы набора данных RGB. Извлеките данные в директорию, заданную imageDir
переменная. Когда извлечено успешно, imageDir
содержит три директории под названием full_resolution
Тест
, и train
.
Создайте imageDatastore
это читает, целевые закрашенные фигуры обучения RGB изображений получили использование высокопроизводительного Canon DSLR.
trainImageDir = fullfile(imageDir,'train'); dsTrainRGB = imageDatastore(fullfile(trainImageDir,'canon'),'ReadSize',16);
Предварительно просмотрите закрашенную фигуру обучения RGB изображений.
groundTruthPatch = preview(dsTrainRGB); imshow(groundTruthPatch)
Создайте imageDatastore
это читает, закрашенные фигуры обучения входа RAW изображений получили использование камеры телефона Huawei. НЕОБРАБОТАННЫЕ изображения получены с 10-битной точностью и представлены и как 8-битные и как 16-битные файлы PNG. 8-битные файлы предоставляют компактному представлению закрашенных фигур с данными в области значений [0, 255]. Никакое масштабирование не было сделано ни на одних из Необработанных данных.
dsTrainRAW = imageDatastore(fullfile(trainImageDir,'huawei_raw'),'ReadSize',16);
Предварительно просмотрите закрашенную фигуру обучения входа RAW изображений. Datastore читает эту закрашенную фигуру как 8-битный uint8
отобразите, потому что количества датчика находятся в области значений [0, 255]. Чтобы симулировать 10-битный динамический диапазон обучающих данных, разделите значения интенсивности изображений на 4. Если вы увеличиваете масштаб изображения, то вы видите шаблон Байера RGGB.
inputPatch = preview(dsTrainRAW); inputPatchRAW = inputPatch/4; imshow(inputPatchRAW)
Симулировать минимальный традиционный конвейер обработки, demosaic шаблон Байера RGGB Необработанных данных с помощью demosaic
функция. Отобразите обработанное изображение и украсьте отображение. По сравнению с целевым изображением RGB минимально обработанное изображение RGB является темным и имеет неустойчивые цвета и значимые артефакты. Обученная сеть RAW-to-RGB выполняет операции предварительной обработки так, чтобы изображение выхода RGB напомнило целевое изображение.
inputPatchRGB = demosaic(inputPatch,'rggb');
imshow(rescale(inputPatchRGB))
Тестовые данные содержат СЫРЫЕ ДАННЫЕ и закрашенные фигуры RGB изображений и полноразмерные изображения. Этот пример делит тестовые закрашенные фигуры изображений в набор валидации и набор тестов. Пример использует полноразмерные тестовые изображения для качественного тестирования только. Смотрите Оценивают Обученный Трубопровод Обработки изображений на Полноразмерных Изображениях.
Создайте хранилища данных изображений, которые читают СЫРЫЕ ДАННЫЕ и тестовые закрашенные фигуры RGB изображений.
testImageDir = fullfile(imageDir,'test'); dsTestRAW = imageDatastore(fullfile(testImageDir,'huawei_raw'),'ReadSize',16); dsTestRGB = imageDatastore(fullfile(testImageDir,'canon'),'ReadSize',16);
Случайным образом разделите тестовые данные в два набора для валидации и обучения. Набор данных валидации содержит 200 изображений. Набор тестов содержит остающиеся изображения.
numTestImages = dsTestRAW.numpartitions; numValImages = 200; testIdx = randperm(numTestImages); validationIdx = testIdx(1:numValImages); testIdx = testIdx(numValImages+1:numTestImages); dsValRAW = subset(dsTestRAW,validationIdx); dsValRGB = subset(dsTestRGB,validationIdx); dsTestRAW = subset(dsTestRAW,testIdx); dsTestRGB = subset(dsTestRGB,testIdx);
Датчик получает цветные данные в повторяющемся шаблоне Байера, который включает один красный, два зеленых, и один синий фотодатчик. Предварительно обработайте данные в изображение с четырьмя каналами, ожидаемое сети с помощью transform
функция. transform
функциональные процессы данные с помощью операций заданы в preprocessRAWDataForRAWToRGB
функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.
preprocessRAWDataForRAWToRGB
функция помощника преобразует H W 1 НЕОБРАБОТАННЫМ изображением к H/2-by-W/2-by-4 многоканальному изображению, состоящему из одного красного, двух зеленых, и один синий канал.
Функция также бросает данные к типу данных single
масштабируемый к области значений [0, 1].
dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB); dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB); dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);
Целевые изображения RGB хранятся на диске как 8-битные данные без знака. Чтобы сделать расчет метрик и проектирование сети более удобными, предварительно обработайте целевые изображения обучения RGB с помощью transform
функционируйте и preprocessRGBDataForRAWToRGB
функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.
preprocessRGBDataForRAWToRGB
функция помощника бросает изображения к типу данных single
масштабируемый к области значений [0, 1].
dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB); dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);
Объедините вход RAW и предназначайтесь для данных о RGB для обучения, валидации, и протестируйте наборы изображений при помощи combine
функция.
dsTrain = combine(dsTrainRAW,dsTrainRGB); dsVal = combine(dsValRAW,dsValRGB); dsTest = combine(dsTestRAW,dsTestRGB);
Случайным образом увеличьте обучающие данные с помощью transform
функционируйте и augmentDataForRAWToRGB
функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.
augmentDataForRAWToRGB
функция помощника случайным образом применяет 90 вращений степени и горизонтальное отражение к парам входа RAW и целевых изображений обучения RGB.
dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);
Предварительно просмотрите увеличенные обучающие данные.
exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
{224×224×4 single} {448×448×3 single}
Отобразите сетевой вход и целевое изображение в монтаже. Сетевой вход имеет четыре канала, так отобразите первый канал, перемасштабированный к области значений [0, 1]. Вход RAW и целевые изображения RGB имеют идентичное увеличение.
exampleInput = exampleAug{1,1}; exampleOutput = exampleAug{1,2}; montage({rescale(exampleInput(:,:,1)),exampleOutput})
Этот пример использует пользовательский учебный цикл. minibatchqueue
Объект (Deep Learning Toolbox) полезен для управления мини-пакетная обработка наблюдений в пользовательских учебных циклах. minibatchqueue
возразите также бросает данные к dlarray
Объект (Deep Learning Toolbox), который включает автоматическое дифференцирование в применении глубокого обучения.
miniBatchSize = 12; valBatchSize = 10; trainingQueue = minibatchqueue(dsTrainAug,'MiniBatchSize',miniBatchSize,'PartialMiniBatch','discard','MiniBatchFormat','SSCB'); validationQueue = minibatchqueue(dsVal,'MiniBatchSize',valBatchSize,'MiniBatchFormat','SSCB');
next
(Deep Learning Toolbox) функция minibatchqueue
дает к следующему мини-пакету данных. Предварительно просмотрите выходные параметры от одного вызова до next
функция. Выходные параметры имеют тип данных dlarray
. Данные уже брошены к gpuArray
, на графическом процессоре, и готовый к обучению.
[inputRAW,targetRGB] = next(trainingQueue);
whos inputRAW
Name Size Bytes Class Attributes inputRAW 224x224x4x12 9633800 dlarray
whos targetRGB
Name Size Bytes Class Attributes targetRGB 448x448x3x12 28901384 dlarray
Этот пример использует изменение сети U-Net. В U-Net начальные серии сверточных слоев вкраплены макс. слоями объединения, последовательно уменьшив разрешение входного изображения. Эти слои сопровождаются серией сверточных слоев, вкрапленных повышающей дискретизацией операторов, последовательно увеличивая разрешение входного изображения. U-Net имени прибывает из того, что сеть может чертиться с симметричной формой как буква U.
Этот пример использует простую архитектуру U-Net с двумя модификациями. Во-первых, сеть заменяет транспонированную операцию свертки финала на пользовательскую пиксельную повышающую дискретизацию перестановки (также известный как глубину к пробелу) операция. Во-вторых, сеть использует пользовательский гиперболический слой активации касательной в качестве последнего слоя в сети.
Свертка, сопровождаемая пиксельной повышающей дискретизацией перестановки, может задать субпиксельную свертку для супер приложений разрешения. Субпиксельная свертка предотвращает checkboard артефакты, которые могут явиться результатом транспонированной свертки [6]. Поскольку модель должна сопоставить H/2-by-W/2-by-4 НЕОБРАБОТАННЫЕ входные параметры с W H 3 RGB выходные параметры, итоговый этап повышающей дискретизации модели может думаться так же к супер разрешению, где количество пространственных выборок растет от входа до выхода.
Рисунок показывает, как пиксельная повышающая дискретизация перестановки работает на 2 2 4 входами. Первые две размерности являются пространственными размерностями, и третья размерность является размерностью канала. В общем случае пиксельная повышающая дискретизация перестановки на коэффициент S берет вход H by W by C и дает к S*H-by-S*W-by- вывод .
Пиксельная функция перестановки выращивает пространственные размерности выхода путем отображения информации от размерностей канала в данном пространственном местоположении в S-by-S пространственные блоки в выходе, в котором каждый канал способствует сопоставимому пространственному положению относительно своих соседей во время повышающей дискретизации.
Гиперболический слой активации касательной применяет tanh
функция на входных параметрах слоя. Этот пример использует масштабированную и переключенную версию tanh
функция, которая поощряет, но строго не осуществляет это сеть RGB выходные параметры, находится в области значений [0, 1].
Используйте tall
вычислить среднее сокращение на канал через обучающий набор данных. Входной слой сети выполняет среднее центрирование входных параметров во время обучения и тестирования использования средней статистики.
dsIn = copy(dsTrainRAW); dsIn.UnderlyingDatastore.ReadSize = 1; t = tall(dsIn); perChannelMean = gather(mean(t,[1 2]));
Evaluating tall expression using the Parallel Pool 'LocalProfile12': - Pass 1 of 1: 0% complete Evaluation 0% complete - Pass 1 of 1: Completed in 1 min 1 sec Evaluation completed in 1 min 2 sec
Создайте слои начальной подсети, задав среднее значение на канал.
inputSize = [256 256 4]; initialLayer = imageInputLayer(inputSize,"Normalization","zerocenter","Mean",perChannelMean, ... "Name","ImageInputLayer");
Добавьте слои первой подсети кодирования. Первый энкодер имеет 32 сверточных фильтра.
numEncoderStages = 4; numFiltersFirstEncoder = 32; encoderNamePrefix = "Encoder-Stage-"; encoderLayers = [ convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-1")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-1")) convolution2dLayer([3 3],numFiltersFirstEncoder,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(encoderNamePrefix,"1-Conv-2")) leakyReluLayer(0.2,"Name",strcat(encoderNamePrefix,"1-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2], ... "Name",strcat(encoderNamePrefix,"1-MaxPool")); ];
Добавьте слои дополнительных подсетей кодирования. Эти подсети добавляют мудрую каналом нормализацию экземпляра после каждого сверточного слоя с помощью groupNormalizationLayer
. Каждая подсеть энкодера имеет дважды количество фильтров как предыдущая подсеть энкодера.
cnIdx = 1; for stage = 2:numEncoderStages numFilters = numFiltersFirstEncoder*2^(stage-1); layerNamePrefix = strcat(encoderNamePrefix,num2str(stage)); encoderLayers = [ encoderLayers convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) maxPooling2dLayer([2 2],"Stride",[2 2],"Name",strcat(layerNamePrefix,"-MaxPool")) ]; cnIdx = cnIdx + 2; end
Добавьте мостоукладчики. Подсеть моста имеет дважды количество фильтров как итоговая подсеть энкодера и первая подсеть декодера.
numFilters = numFiltersFirstEncoder*2^numEncoderStages; bridgeLayers = [ convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Bridge-Conv-1") groupNormalizationLayer("channel-wise","Name","cn7") leakyReluLayer(0.2,"Name","Bridge-ReLU-1") convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Bridge-Conv-2") groupNormalizationLayer("channel-wise","Name","cn8") leakyReluLayer(0.2,"Name","Bridge-ReLU-2")];
Добавьте слои первых трех подсетей декодера.
numDecoderStages = 4; cnIdx = 9; decoderNamePrefix = "Decoder-Stage-"; decoderLayers = []; for stage = 1:numDecoderStages-1 numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage); layerNamePrefix = strcat(decoderNamePrefix,num2str(stage)); decoderLayers = [ decoderLayers transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-UpConv")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU")) depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) groupNormalizationLayer("channel-wise","Name",strcat("cn",num2str(cnIdx+1))) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2")) ]; cnIdx = cnIdx + 2; end
Добавьте слои последней подсети декодера. Эта подсеть исключает мудрую каналом нормализацию экземпляра, выполняемую другими подсетями декодера. Каждая подсеть декодера имеет половину количества фильтров как предыдущая подсеть.
numFilters = numFiltersFirstEncoder; layerNamePrefix = strcat(decoderNamePrefix,num2str(stage+1)); decoderLayers = [ decoderLayers transposedConv2dLayer([3 3],numFilters,"Stride",[2 2],"Cropping","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-UpConv")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-UpReLU")) depthConcatenationLayer(2,"Name",strcat(layerNamePrefix,"-DepthConcatenation")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-1")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-1")) convolution2dLayer([3 3],numFilters,"Padding","same","WeightsInitializer","narrow-normal", ... "Name",strcat(layerNamePrefix,"-Conv-2")) leakyReluLayer(0.2,"Name",strcat(layerNamePrefix,"-ReLU-2"))];
Добавьте последние слои U-Net. Пиксельный слой перестановки перемещается от H/2-by-W/2-by-12 размер канала активаций от итоговой свертки до H W 3 активациями канала с помощью пиксельной повышающей дискретизации перестановки. Последний слой призывает выходные параметры к желаемой области значений [0, 1] использование гиперболической функции тангенса.
finalLayers = [ convolution2dLayer([3 3],12,"Padding","same","WeightsInitializer","narrow-normal", ... "Name","Decoder-Stage-4-Conv-3") pixelShuffleLayer("pixelShuffle",2) tanhScaledAndShiftedLayer("tanhActivation")]; layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers]; lgraph = layerGraph(layers);
Соедините слои кодирования и декодирования подсетей.
lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2","Decoder-Stage-4-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2","Decoder-Stage-3-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2","Decoder-Stage-2-DepthConcatenation/in2"); lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2","Decoder-Stage-1-DepthConcatenation/in2"); net = dlnetwork(lgraph);
Визуализируйте сетевую архитектуру с помощью Deep Network Designer (Deep Learning Toolbox) приложение.
%deepNetworkDesigner(lgraph)
Эта функция изменяет предварительно обученную глубокую нейронную сеть VGG-16, чтобы извлечь функции изображений на различных слоях. Эти многоуровневые функции используются для расчета потеря содержимого.
Чтобы получить предварительно обученную сеть VGG-16, установите vgg16
(Deep Learning Toolbox). Если вам не установили необходимый пакет поддержки, то программное обеспечение обеспечивает ссылку на загрузку.
vggNet = load('vgg16'); vggNet = vggNet.net; %vggNet = vgg16;
Чтобы сделать сеть VGG-16 подходящей для извлечения признаков, используйте слои до 'relu5_3'
.
vggNet = vggNet.Layers(1:31); vggNet = dlnetwork(layerGraph(vggNet));
Функция помощника modelGradients
вычисляет градиенты и полную потерю для пакетов обучающих данных. Эта функция задана в разделе Supporting Functions этого примера.
Полная потеря является взвешенной суммой двух потерь: среднее значение абсолютной погрешности (MAE) потеря и потеря содержимого. Потеря содержимого взвешивается таким образом, что потеря MAE и потеря содержимого способствуют приблизительно одинаково полной потере:
Потеря MAE штрафует расстояние между выборками сетевых предсказаний и выборками целевого изображения. часто лучший выбор, чем для приложений для обработки изображений, потому что это может помочь уменьшать размывающиеся артефакты [4]. Эта потеря реализована с помощью maeLoss
функция, определяемая помощника в разделе Supporting Functions этого примера.
Потеря содержимого помогает сети изучить и высокоуровневое структурное содержимое и низкоуровневое ребро и информацию о цвете. Функция потерь вычисляет взвешенную сумму среднеквадратичной погрешности (MSE) между предсказаниями и целями для каждого слоя активации. Эта потеря реализована с помощью contentLoss
функция, определяемая помощника в разделе Supporting Functions этого примера.
modelGradients
функция помощника требует весового коэффициента содержимого потерь как входного параметра. Вычислите весовой коэффициент для демонстрационного пакета обучающих данных, таким образом, что потеря MAE равна взвешенной потере содержимого.
Предварительно просмотрите пакет обучающих данных, который состоит из пар входных параметров сети RAW, и RGB предназначаются для выходных параметров.
trainingBatch = preview(dsTrainAug); networkInput = dlarray((trainingBatch{1,1}),'SSC'); targetOutput = dlarray((trainingBatch{1,2}),'SSC');
Предскажите ответ нетренированной сети U-Net с помощью forward
(Deep Learning Toolbox) функция.
predictedOutput = forward(net,networkInput);
Вычислите MAE и потери содержимого между предсказанными и целевыми изображениями RGB.
sampleMAELoss = maeLoss(predictedOutput,targetOutput); sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);
Вычислите весовой коэффициент.
weightContent = sampleMAELoss/sampleContentLoss;
Задайте опции обучения, которые используются в пользовательском учебном цикле, чтобы управлять аспектами оптимизации Адама. Обучайтесь в течение 20 эпох.
learnRate = 5e-5; numEpochs = 20;
По умолчанию пример загружает предварительно обученную версию сети RAW-to-RGB при помощи функции помощника downloadTrainedRAWToRGBNet
. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде к true
. Обучите модель в пользовательском учебном цикле. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функция и modelGradients
функция помощника.
Обновите сетевые параметры с помощью adamupdate
(Deep Learning Toolbox) функция и информация о градиенте.
Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 88 часов на Титане NVIDIA™ RTX и может взять еще дольше в зависимости от вашего оборудования графического процессора.
doTraining = false; if doTraining % Create a directory to store checkpoints checkpointDir = fullfile(imageDir,'checkpoints',filesep); if ~exist(checkpointDir,'dir') mkdir(checkpointDir); end % Initialize training plot [hFig,batchLine,validationLine] = initializeTrainingPlotRAWPipeline; % Initialize Adam solver state [averageGrad,averageSqGrad] = deal([]); iteration = 0; start = tic; for epoch = 1:numEpochs reset(trainingQueue); shuffle(trainingQueue); while hasdata(trainingQueue) [inputRAW,targetRGB] = next(trainingQueue); [grad,loss] = dlfeval(@modelGradients,net,vggNet,inputRAW,targetRGB,weightContent); iteration = iteration + 1; [net,averageGrad,averageSqGrad] = adamupdate(net, ... grad,averageGrad,averageSqGrad,iteration,learnRate); updateTrainingPlotRAWPipeline(batchLine,validationLine,iteration,loss,start,epoch,... validationQueue,numValImages,valBatchSize,net,vggNet,weightContent); end % Save checkpoint of network state save(checkpointDir + "epoch" + epoch,'net'); end % Save the final network state save(checkpointDir + "trainedRAWToRGBNet.mat",'net'); else trainedRAWToRGBNet_url = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedRAWToRGBNet.mat'; downloadTrainedRAWToRGBNet(trainedRAWToRGBNet_url,imageDir); load(fullfile(imageDir,'trainedRAWToRGBNet.mat')); end
Pretrained RAW-to-RGB network already exists.
Основанные на ссылке метрики качества, такие как MSSIM или PSNR включают количественные показатели качества изображения. Можно вычислить MSSIM и PSNR исправленных тестовых изображений, потому что они пространственно указаны и тот же размер.
Выполните итерации через набор тестов исправленных изображений с помощью minibatchqueue
объект.
patchTestSet = combine(dsTestRAW,dsTestRGB); testPatchQueue = minibatchqueue(patchTestSet,'MiniBatchSize',16,'MiniBatchFormat','SSCB');
Выполните итерации через набор тестов и вычислите MSSIM и PSNR для каждого тестового изображения с помощью multissim
и psnr
функции. Вычислите MSSIM для цветных изображений при помощи среднего значения метрики для каждого цветового канала как приближение, поскольку метрика не четко определена для многоканальных входных параметров.
totalMSSIM = 0; totalPSNR = 0; while hasdata(testPatchQueue) [inputRAW,targetRGB] = next(testPatchQueue); outputRGB = forward(net,inputRAW); targetRGB = targetRGB ./ 255; mssimOut = sum(mean(multissim(outputRGB,targetRGB),3),4); psnrOut = sum(psnr(outputRGB,targetRGB),4); totalMSSIM = totalMSSIM + mssimOut; totalPSNR = totalPSNR + psnrOut; end
Вычислите средний MSSIM и означайте PSNR по набору тестов. Этот результат сопоставим с подобным подходом U-Net от [3] для среднего MSSIM и конкурентоспособен по отношению к подходу PyNet в [3] в среднем PSNR. Различия в функциях потерь и использовании пиксельной повышающей дискретизации перестановки по сравнению с [3] вероятный счет на эти различия.
numObservations = dsTestRGB.numpartitions; meanMSSIM = totalMSSIM / numObservations
meanMSSIM = 1(S) × 1(S) × 1(C) × 1(B) single dlarray 0.8425
meanPSNR = totalPSNR / numObservations
meanPSNR = 1(S) × 1(S) × 1(C) × 1(B) single dlarray 21.1213
Из-за различий в датчике между телефонной камерой и DSLR, используемым, чтобы получить тестовые изображения полного разрешения, сцены не указаны и не являются тем же размером. Основанное на ссылке сравнение изображений полного разрешения от сети и DSLR ISP затрудняет. Однако качественное сравнение изображений полезно, потому что цель обработки изображений состоит в том, чтобы создать эстетически приятное изображение.
Создайте datastore изображений, который содержит полноразмерные НЕОБРАБОТАННЫЕ изображения, полученные телефонной камерой.
testImageDir = fullfile(imageDir,'test'); testImageDirRAW = "huawei_full_resolution"; dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));
Получите имена файлов изображений в полноразмерном НЕОБРАБОТАННОМ наборе тестов.
targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files),fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");
Предварительно обработайте Необработанные данные путем преобразования данных в форму, ожидаемую сетью с помощью transform
функция. transform
функциональные процессы данные с помощью операций заданы в preprocessRAWDataForRAWToRGB
функция помощника. Функция помощника присоединена к примеру как к вспомогательному файлу.
dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);
Создайте datastore изображений, который содержит полноразмерные тестовые изображения RGB, полученные от высокопроизводительного DSLR. Цюрихский набор данных RAW-to-RGB содержит больше полноразмерных изображений RGB, чем НЕОБРАБОТАННЫЕ изображения, поэтому включайте только изображения RGB с соответствующим НЕОБРАБОТАННЫМ изображением.
dsTestFullRGB = imageDatastore(fullfile(imageDir,'full_resolution','canon')); dsTestFullRGB.Files = dsTestFullRGB.Files(contains(dsTestFullRGB.Files,targetFilesToInclude));
Читайте в целевых изображениях RGB. Получите смысл полного выхода путем рассмотрения представления монтажа.
targetRGB = readall(dsTestFullRGB); montage(targetRGB,"Size",[5 2],"Interpolation","bilinear")
Выполните итерации через набор тестов полноразмерных изображений с помощью minibatchqueue
объект. Если у вас есть устройство графического процессора с достаточной памятью, чтобы обработать изображения полного разрешения, то можно запустить предсказание на графическом процессоре путем определения выходной среды как "gpu"
.
testQueue = minibatchqueue(dsTestFullRAW,"MiniBatchSize",1, ... "MiniBatchFormat","SSCB","OutputEnvironment","cpu");
Для каждого полноразмерного НЕОБРАБОТАННОГО тестового изображения предскажите изображение выхода RGB путем вызова forward
(Deep Learning Toolbox) в сети.
outputSize = 2*size(preview(dsTestFullRAW),[1 2]); outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],'uint8'); idx = 1; while hasdata(testQueue) inputRAW = next(testQueue); rgbOut = forward(net,inputRAW); rgbOut = gather(extractdata(rgbOut)); outputImages(:,:,:,idx) = im2uint8(rgbOut); idx = idx+1; end
Получите смысл полного выхода путем рассмотрения представления монтажа. Сеть производит изображения, которые эстетически приятны с подобными характеристиками.
montage(outputImages,"Size",[5 2],"Interpolation","bilinear")
Сравните одно целевое изображение RGB с соответствующим изображением, предсказанным сетью. Сеть производит цвета, которые более насыщаются, чем целевые изображения DSLR. Несмотря на то, что цвета от простой архитектуры U-Net различные, когда DSLR предназначается, изображения все еще качественно приятны во многих случаях.
imgIdx = 1; imTarget = targetRGB{imgIdx}; imPredicted = outputImages(:,:,:,imgIdx); figure montage({imTarget,imPredicted},"Interpolation","bilinear")
Чтобы улучшать производительность сети RAW-to-RGB, сетевая архитектура училась бы подробный, локализовал пространственные функции с помощью нескольких шкал от глобальных функций, которые описывают цвет и контраст [3].
modelGradients
функция помощника вычисляет градиенты и полную потерю. Информация о градиенте возвращена как таблица, которая включает слой, название параметра и значение для каждого настраиваемого параметра в модели.
function [gradients,loss] = modelGradients(dlnet,vggNet,Xpatch,Target,weightContent) Y = forward(dlnet,Xpatch); lossMAE = maeLoss(Y,Target); lossContent = contentLoss(vggNet,Y,Target); loss = lossMAE + weightContent.*lossContent; gradients = dlgradient(loss,dlnet.Learnables); end
Функция помощника maeLoss
вычисляет среднюю абсолютную погрешность между сетевыми предсказаниями, Y
, и целевые изображения, T
.
function loss = maeLoss(Y,T) loss = mean(abs(Y-T),'all'); end
Функция помощника contentLoss
вычисляет взвешенную сумму MSE между сетевыми предсказаниями, Y
, и целевые изображения, T
, для каждого слоя активации. contentLoss
функция помощника вычисляет MSE для каждого слоя активации с помощью mseLoss
функция помощника. Веса выбраны таким образом, что потеря от каждой активации слои способствует примерно одинаково полной потере содержимого.
function loss = contentLoss(net,Y,T) layers = ["relu1_1","relu1_2","relu2_1","relu2_2","relu3_1","relu3_2","relu3_3","relu4_1"]; [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,'Outputs',layers); [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,'Outputs',layers); l1 = mseLoss(X1,T1); l2 = mseLoss(X2,T2); l3 = mseLoss(X3,T3); l4 = mseLoss(X4,T4); l5 = mseLoss(X5,T5); l6 = mseLoss(X6,T6); l7 = mseLoss(X7,T7); l8 = mseLoss(X8,T8); layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8]; weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04]; loss = sum(layerLosses.*weights); end
Функция помощника mseLoss
вычисляет MSE между сетевыми предсказаниями, Y
, и целевые изображения, T
.
function loss = mseLoss(Y,T) loss = mean((Y-T).^2,'all'); end
1) Самнер, Ограбить. "Обрабатывая НЕОБРАБОТАННЫЕ Изображения в MATLAB". 19 мая 2014. https://rcsumner.net/raw_guide/RAWguide.pdf.
2) Чэнь, Чэнь, Цифэн Чэнь, Цзя Сюй и Владлен Кольтун. “Учась Видеть в темноте”. ArXiv:1805.01934 [Cs], 4 мая 2018. http://arxiv.org/abs/1805.01934.
3) Игнатов, Андрей, Люк Ван Гул и Рэду Тимофт. “Заменяя Мобильную Камеру ISP на Одну Модель Глубокого обучения”. ArXiv:2002.05509 [Cs, Eess], 13 февраля 2020. http://arxiv.org/abs/2002.05509. Веб-сайт проекта.
4) Чжао, Висите, Орацио Галло, Юри Фрозьо и Ян Коц. “Функции потерь для Нейронных сетей для Обработки изображений”. ArXiv:1511.08861 [Cs], 20 апреля 2018. http://arxiv.org/abs/1511.08861.
5) Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. “Перцепционные Потери для Передачи Стиля В реальном времени и Суперразрешения”. ArXiv:1603.08155 [Cs], 26 марта 2016. http://arxiv.org/abs/1603.08155.
6) Ши, Wenzhe, Хосе Кэбаллеро, Ференц Хусзар, Джоханнс Тоц, Эндрю П. Эйткен, Грабит епископа, Дэниела Руекерта, и Зехэна Вана. “Одно Изображение в реальном времени и Видео Суперразрешение Используя Эффективную Субпиксельную Сверточную нейронную сеть”. ArXiv:1609.05158 [Cs, Статистика], 23 сентября 2016. http://arxiv.org/abs/1609.05158.
imageDatastore
| trainingOptions
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox) | transform
| combine