Обучите быструю сеть передачи стиля

В этом примере показано, как обучить сеть, чтобы передать стиль изображения к второму изображению. Это основано на архитектуре, заданной в [1].

Этот пример похож на Нейронную Передачу Стиля Используя Глубокое обучение, но это работает быстрее, если вы обучили сеть на S стиля изображений. Это вызвано тем, что чтобы получить стилизованное изображение Y только необходимо сделать, прямая передача входа отображает X к сети.

Найдите высокоуровневую схему алгоритма настройки ниже. Это использует три изображения, чтобы вычислить потерю: входное изображение X, преобразованное изображение Y и стиль отображает S.

Обратите внимание на то, что функция потерь использует предварительно обученную сеть VGG-16, чтобы извлечь функции из изображений. Можно найти его реализацию и математическое определение в разделе Style Transfer Loss этого примера.

Загрузите обучающие данные

Загрузите и извлеките COCO 2014, обучают изображения и заголовки под эгидой https://cocodataset.org/#download путем нажатия на "2014 Train images". Сохраните в папке данные, заданные imageFolder. Извлеките изображения в imageFolder. COCO 2014 был собран Кокосовым Консорциумом.

Создайте директории, чтобы сохранить набор данных COCO.

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,'dir')
    mkdir(imageFolder);
end

Создайте datastore изображений, содержащий изображения COCO.

imds = imageDatastore(imageFolder,'IncludeSubfolders',true);

Обучение может занять много времени, чтобы запуститься. Если вы хотите уменьшить учебное время за счет точности получившейся сети, то выберите подмножество datastore изображений установкой fraction к меньшему значению.

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));

Чтобы изменить размер изображений и преобразовать их всех в RGB, создайте увеличенный datastore изображений.

augimds = augmentedImageDatastore([256 256],imds,'ColorPreprocessing',"gray2rgb");

Считайте изображение стиля.

styleImage = imread('starryNight.jpg');
styleImage = imresize(styleImage,[256 256]);

Отобразите выбранное изображение стиля.

figure
imshow(styleImage)
title("Style Image")

Задайте сеть трансформатора изображений

Задайте сеть трансформатора изображений. Это - сеть от изображения к изображению. Сеть состоит из 3 частей:

  1. Первая часть сети берет в качестве входа изображение RGB размера [256x256x3] и прореживает его к карте функции размера [64x64x128].

  2. Вторая часть сети состоит из пяти идентичных остаточных блоков, заданных в функции поддержки residualBlock.

  3. Третья и итоговая часть сети сверхдискретизировала карту функции к первоначальному размеру изображения и возвращает преобразованное изображение. Эта последняя часть использует upsampleLayer, который является пользовательским слоем, присоединенным к этому примеру как вспомогательный файл.

layers = [
    
    % First part.
    imageInputLayer([256 256 3], 'Name', 'input', 'Normalization','none')
    
    convolution2dLayer([9 9], 32, 'Padding','same','Name', 'conv1')
    groupNormalizationLayer('channel-wise','Name','norm1')
    reluLayer('Name', 'relu1')
    
    convolution2dLayer([3 3], 64, 'Stride', 2,'Padding','same','Name', 'conv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm2')
    reluLayer('Name', 'relu2')
    
    convolution2dLayer([3 3], 128, 'Stride', 2, 'Padding','same','Name', 'conv3')
    groupNormalizationLayer('channel-wise' ,'Name','norm3')
    reluLayer('Name', 'relu3')
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer('up1')
    convolution2dLayer([3 3], 64, 'Stride', 1, 'Padding','same','Name', 'upconv1')
    groupNormalizationLayer('channel-wise' ,'Name','norm6')
    reluLayer('Name', 'relu5')
    
    upsampleLayer('up2')
    convolution2dLayer([3 3], 32, 'Stride', 1, 'Padding','same','Name', 'upconv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm7')
    reluLayer('Name', 'relu6')
    
    convolution2dLayer(9,3,'Padding','same','Name','conv_out')];

lgraph = layerGraph(layers);

Добавьте недостающие связи в остаточных блоках.

lgraph = connectLayers(lgraph,"relu3","add1/in2");
lgraph = connectLayers(lgraph,"add1","add2/in2");
lgraph = connectLayers(lgraph,"add2","add3/in2");
lgraph = connectLayers(lgraph,"add3","add4/in2");
lgraph = connectLayers(lgraph,"add4","add5/in2");

Визуализируйте сеть трансформатора изображений в графике.

figure
plot(lgraph)
title("Transform Network")

Создайте dlnetwork объект от графика слоев.

dlnetTransform = dlnetwork(lgraph);

Разработайте сеть потерь

Этот пример использует предварительно обученную глубокую нейронную сеть VGG-16, чтобы извлечь функции содержимого и изображений стиля на различных слоях. Эти многоуровневые функции используются для расчета соответствующее содержимое и разрабатывают потери.

Чтобы получить предварительно обученную сеть VGG-16, используйте vgg16 функция. Если вам не установили необходимые пакеты поддержки, то программное обеспечение обеспечивает ссылку на загрузку.

netLoss = vgg16;

Чтобы извлечь функцию, необходимую, чтобы вычислить потерю, вам нужны первые 24 слоя только. Извлеките и преобразуйте в график слоев.

lossLayers = netLoss.Layers(1:24);
lgraph = layerGraph(lossLayers);

Преобразуйте в dlnetwork.

dlnetLoss = dlnetwork(lgraph);

Задайте матрицу функции потерь и грамма

Создайте styleTransferLoss функция, определяемая в разделе Style Transfer Loss этого примера.

Функциональный styleTransferLoss берет в качестве входа сеть dlnetLoss потерь, мини-пакет входных преобразованных изображений dlX, мини-пакет преобразованных изображений dlY, массив, содержащий матрицы Грамма стиля, отображает dlSGram, вес сопоставил с потерей содержимого contentWeight и вес сопоставил с потерей стиля styleWeight. Функция возвращает общую сумму убытков loss и отдельные компоненты: потеря содержимого lossContent и потеря стиля lossStyle.

styleTransferLoss функционируйте использует функцию поддержки createGramMatrix в расчете потери стиля.

createGramMatrix функционируйте берет в качестве входа функции, извлеченные сетью потерь, и возвращает стилистическое представление для каждого изображения в мини-пакете. Можно найти реализацию и математическое определение матрицы Грамма в разделе Gram Matrix.

Задайте функцию градиентов модели

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера. Эта функция берет в качестве входа сеть dlnetLoss потерь, сеть dlnetTransform трансформатора изображений, мини-пакет входа отображает dlX, массив, содержащий матрицы Грамма стиля, отображает dlSGram, вес сопоставил с потерей содержимого contentWeight и вес сопоставил с потерей стиля styleWeight. Функция возвращает gradients из потери относительно настраиваемых параметров трансформатора изображений, состояния сети трансформатора изображений, преобразованные изображения dlY, общая сумма убытков loss, потеря сопоставила с содержимым lossContent и потеря сопоставила со стилем lossStyle.

Задайте опции обучения

Обучайтесь с размером мини-пакета, равным 4, в течение 2 эпох как в [1].

numEpochs = 2;
miniBatchSize = 4;

Установите размер чтения увеличенного datastore изображений к мини-пакетному размеру.

augimds.MiniBatchSize = miniBatchSize;

Задайте опции для оптимизации ADAM. Задайте изучить уровень 0,001 с фактором затухания градиента 0,01 и фактором затухания градиента в квадрате 0,999.

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

executionEnvironment = "auto";

Задайте вес, данный потере стиля и один данный к потере содержимого в вычислении общей суммы убытков.

Обратите внимание на то, что, для того, чтобы найти хороший баланс между содержимым и потерей стиля, вы можете должны быть экспериментировать с различными комбинациями весов.

weightContent = 1e-4;
weightStyle = 3e-8; 

Выберите частоту графика процесса обучения. Это задает, сколько итерации там между каждым обновлением графика.

plotFrequency = 10;

Обучите модель

Для того, чтобы смочь вычислить потерю во время обучения, вычислить матрицы Грамма для изображения стиля.

Преобразуйте изображение стиля в dlarray.

dlS = dlarray(single(styleImage),'SSC');

Для того, чтобы вычислить матрицу Грамма, накормите изображением стиля сеть VGG-16 и извлеките активации на четырех различных слоях.

[dlSActivations1,dlSActivations2,dlSActivations3,dlSActivations4] = forward(dlnetLoss,dlS, ...
    'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

Вычислите матрицу Грамма для каждого набора активаций с помощью функции поддержки createGramMatrix.

dlSGram{1} = createGramMatrix(dlSActivations1);
dlSGram{2} = createGramMatrix(dlSActivations2);
dlSGram{3} = createGramMatrix(dlSActivations3);
dlSGram{4} = createGramMatrix(dlSActivations4);

Учебные графики состоят из двух фигур:

  1. Фигура, показывающая график потерь во время обучения

  2. Фигура, содержащая вход и выходное изображение сети трансформатора изображений

Инициализируйте учебные графики. Можно проверять детали инициализации в функции поддержки initializeFigures. Эта функция возвращается: ось ax1 где вы строите потерю, ось ax2 где вы строите изображения валидации, анимированная линия lineLossContent который содержит потерю содержимого, анимированная линия lineLossStyle который содержит потерю стиля и анимированную линию lineLossTotal который содержит общую сумму убытков.

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal]=initializeStyleTransferPlots();

Инициализируйте средний градиент и средние гиперпараметры градиента в квадрате для оптимизатора ADAM.

averageGrad = [];
averageSqGrad = [];

Вычислите общее количество учебных итераций.

numIterations = floor(augimds.NumObservations*numEpochs/miniBatchSize);

Инициализируйте номер итерации и таймер перед обучением.

iteration = 0;
start = tic;

Обучите модель. Это могло занять много времени, чтобы запуститься.

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and the network state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradients,state,dlY,loss,lossContent,lossStyle] = dlfeval(@modelGradients, ...
            dlnetLoss,dlnetTransform,dlX,dlSGram,weightContent,weightStyle);
        
        dlnetTransform.State = state;
        
        % Update the network parameters.
        [dlnetTransform,averageGrad,averageSqGrad] = ...
            adamupdate(dlnetTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
              
        
        % Every plotFequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(gather(extractdata(loss))))
            addpoints(lineLossContent,iteration,double(gather(extractdata(lossContent))))
            addpoints(lineLossStyle,iteration,double(gather(extractdata(lossStyle))))
            
            % Use the first image of the mini-batch as a validation image.
            dlV = dlX(:,:,:,1);
            % Use the transformed validation image computed previously.
            dlVY = dlY(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(dlV)));
            transformedValidationImage = uint8(gather(extractdata(dlVY)));
            
            % Plot the input image and the output image and increase size
            imshow(imtile({validationImage,transformedValidationImage}),'Parent',ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration +" of "+ numIterations + "(" + completionPercentage + "%)" +", Elapsed: " + string(D))
        drawnow
        
    end
end

Стилизуйте изображение

Если обучение закончилось, можно использовать трансформатор изображений на любом изображении по вашему выбору.

Загрузите изображение, которое требуется преобразовать.

imFilename = 'peppers.png';
im = imread(imFilename);

Измените размер входного изображения к входным размерностям трансформатора изображений.

im = imresize(im,[256,256]);

Преобразуйте его в dlarray.

dlX = dlarray(single(im),'SSCB');

Чтобы использовать графический процессор преобразуют в gpuArray если вы доступны.

if canUseGPU
    dlX = gpuArray(dlX);
end

Чтобы применить стиль к изображению, передайте передачу это трансформатору изображений с помощью функционального predict.

dlY = predict(dlnetTransform,dlX);

Перемасштабируйте изображение в область значений [0 255]. Во-первых, используйте функциональный tanh перемасштабировать dlY к области значений [-1 1]. Затем сдвиг и шкала выход, чтобы перемасштабировать в [0 255] область значений.

Y = 255*(tanh(dlY)+1)/2;

Подготовьте Y для графического вывода. Используйте функциональный extractdata извлекать данные из dlarray.Использование функция собирается, чтобы передать Y от графического процессора до локальной рабочей области.

Y = uint8(gather(extractdata(Y)));

Покажите входное изображение (слева) рядом со стилизованным изображением (справа).

figure
m = imtile({im,Y});
imshow(m)

Функция градиентов модели

Функциональный modelGradients берет в качестве входа сеть dlnetLoss потерь, сеть dlnetTransform трансформатора изображений, мини-пакет входа отображает dlX, массив, содержащий матрицы Грамма стиля, отображает dlSGram, вес сопоставил с потерей содержимого contentWeight и вес сопоставил с потерей стиля styleWeight. Это возвращает gradients из потери относительно настраиваемых параметров трансформатора изображений, состояния сети трансформатора изображений, преобразованные изображения dlY, общая сумма убытков loss, потеря сопоставила с содержимым lossContent и потеря сопоставила со стилем lossStyle.

function [gradients,state,dlY,loss,lossContent,lossStyle] = ...
    modelGradients(dlnetLoss,dlnetTransform,dlX,dlSGram,contentWeight,styleWeight)

[dlY,state] = forward(dlnetTransform,dlX);

dlY = 255*(tanh(dlY)+1)/2;

[loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX,dlSGram,contentWeight,styleWeight);

gradients = dlgradient(loss,dlnetTransform.Learnables);

end

Разработайте потерю передачи

Функциональный styleTransferLoss берет в качестве входа сеть dlnetLoss потерь, мини-пакет входа отображает dlX, мини-пакет преобразованных изображений dlY, массив, содержащий матрицы Грамма стиля, отображает dlSGram, веса сопоставили с содержимым и стилем contentWeight и styleWeight, соответственно. Это возвращает общую сумму убытков loss и отдельные компоненты: потеря содержимого lossContent и потеря стиля lossStyle.

Потеря содержимого является мерой сколько различия в пространственной структуре существует между входным изображением X и выходные изображения Y.

С другой стороны, потеря стиля говорит вам сколько различия в стилистическом внешнем виде существует между изображением стиля S и выходное изображение Y.

График ниже объясняет алгоритм что styleTransferLoss реализации, чтобы вычислить общую сумму убытков.

Во-первых, функция передает входные изображения X, преобразованные изображения Y и стиль отображает S к предварительно обученной сети VGG-16. Эта предварительно обученная сеть извлекает несколько функций из этих изображений. Алгоритм затем вычисляет, потеря содержимого при помощи пространственных функций входа отображают X и выходного изображения Y. Кроме того, это вычисляет потерю стиля при помощи стилистических функций выходного изображения Y, и стиля отображают S. Наконец, это получает общую сумму убытков путем добавления потерь стиля и содержимого.

Потеря содержимого

Для каждого изображения в мини-пакете функция потерь содержимого сравнивает функции оригинального изображения и преобразованного изображения, выведенного слоем relu_3_3. В частности, это вычисляет среднеквадратичную погрешность между активациями и возвращает среднюю потерю для мини-пакета:

lossContent=1Nn=1Nсреднее значение([ϕ(Xn)-ϕ(Yn)]2),

где X содержит входные изображения, Y содержит преобразованные изображения, N мини-пакетный размер, и ϕ() представляет активации, извлеченные на слое relu_3_3.

Разработайте потерю

Вычислить потерю стиля, для каждого одного изображения в мини-пакете:

  1. Извлеките активации на слоях relu1_2, relu2_2, relu3_3 и relu4_3.

  2. Для каждой из этих четырех активаций ϕj вычислите матрицу Грамма G(ϕj).

  3. Вычислите разность в квадрате между соответствующими матрицами Грамма.

  4. Сложите эти четыре выходных параметров для каждого слоя j от предыдущего шага.

Чтобы получить потерю стиля для целого мини-пакета, вычислите среднее значение потери стиля для каждого изображения n в мини-пакете:

lossStyle=1Nn=1Nj=14[G(ϕj(Xn))-G(ϕj(S))]2,

где j индекс слоя, и G() Матрица Грамма.

Общая сумма убытков

function [loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX, ...
    dlSGram,weightContent,weightStyle)

% Extract activations.
dlYActivations = cell(1,4);
[dlYActivations{1},dlYActivations{2},dlYActivations{3},dlYActivations{4}] = ...
    forward(dlnetLoss,dlY,'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

dlXActivations = forward(dlnetLoss,dlX,'Outputs','relu3_3');

% Calculate the mean square error between activations.
lossContent = mean((dlYActivations{3} - dlXActivations).^2,'all');

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(dlYActivations{j});
    lossStyle = lossStyle + sum((G - dlSGram{j}).^2,'all');
end

% Average the loss over the mini-batch.
miniBatchSize = size(dlX,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

Остаточный блок

residualBlock функция возвращает массив шести слоев. Это состоит из слоев свертки, слоев нормализации экземпляра, слоя ReLu и слоя сложения. Обратите внимание на то, что groupNormalizationLayer('channel-wise') просто слой нормализации экземпляра.

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same','Name', "convRes"+name+"_1")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_1")
    reluLayer('Name', "reluRes"+name+"_1")
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same', 'Name', "convRes"+name+"_2")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_2")
    additionLayer(2,'Name',"add"+name)];

end

Матрица грамма

Функциональный createGramMatrix берет в качестве входа активации единственного слоя и возвращает стилистическое представление для каждого изображения в mini-batch. Вход является картой функции размера [H, W, C, N], где H является высотой, W является шириной, C является количеством каналов, и N является мини-пакетным размером. Функциональные выходные параметры массив G из размера [C, C, N]. Каждая подрешетка G(:,:,k) матричное соответствие Грамма kth отобразите в мини-пакете. Каждая запись G(i,j,k) из Грамма матрица представляет корреляцию между каналами ci и cj, потому что каждая запись в канале ci умножает запись в соответствующем положении в канале cj:

G(i,j,k)=1C×H×Wh=1Hw=1Wϕk(h,w,ci)ϕk(h,w,cj),

где ϕk активации для kth отобразите в мини-пакете.

Матрица Грамма содержит информацию, о которой функции активируются вместе, но не имеет никакой информации о том, где функции происходят в изображении. Это вызвано тем, что суммирование по высоте и ширине теряет информацию о пространственной структуре. Функция потерь использует эту матрицу в качестве стилистического представления изображения.

function G = createGramMatrix(activations)

[h,w,numChannels] = size(activations,1:3);

features = reshape(activations,h*w,numChannels,[]);
featuresT = permute(features,[2 1 3]);

G = dlmtimes(featuresT,features) / (h*w*numChannels);

end

Ссылки

  1. Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. "Перцепционные потери для передачи стиля в реальном времени и суперразрешения". Европейская конференция по компьютерному зрению. Спрингер, Хан, 2016.

Смотрите также

| | | | | |

Похожие темы