Обучите передаточную сеть быстрого стиля

В этом примере показано, как обучить сеть переносить стиль изображения во второе изображение. Он основан на архитектуре, определенной в [1].

Этот пример похож на передачу нейронного стиля с использованием глубокого обучения, но он работает быстрее, когда вы обучили сеть на изображении стиля S. Это потому, что, чтобы получить стилизованное изображение Y, вам нужно только сделать прямой проход входного изображения X в сеть.

Найдите высокоуровневую схему алгоритма настройки ниже. Для вычисления потерь используются три изображения: Вход изображение X, преобразованное изображение Y и изображение стиля S.

Обратите внимание, что функция потерь использует предварительно обученную сетевую VGG-16 для извлечения функций из изображений. Его реализацию и математическое определение можно найти в разделе Style Transfer Loss этого примера.

Загрузка обучающих данных

Загрузите и извлеките изображения и подписи обучения COCO 2014 из http://cocodataset.org/#download, нажав на «2014 Train image». Сохраните данные в папке, заданной как imageFolder. Извлеките изображения в imageFolder. COCO 2014 был собран консорциумом Coco.

Создайте директории для хранения набора данных COCO.

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,'dir')
    mkdir(imageFolder);
end

Создайте datastore, содержащее изображения COCO.

imds = imageDatastore(imageFolder,'IncludeSubfolders',true);

Обучение может занять много времени. Если вы хотите уменьшить время обучения за счет точности полученной сети, выберите подмножество datastore путем установки fraction к меньшему значению.

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));

Чтобы изменить размер изображений и преобразовать их все в RGB, создайте дополненный image datastore.

augimds = augmentedImageDatastore([256 256],imds,'ColorPreprocessing',"gray2rgb");

Считайте изображение стиля.

styleImage = imread('starryNight.jpg');
styleImage = imresize(styleImage,[256 256]);

Отображение выбранного изображения стиля.

figure
imshow(styleImage)
title("Style Image")

Определите сеть трансформатора изображений

Определите сеть трансформатора изображений. Это сеть «изображение-изображение». Сеть состоит из 3 частей:

  1. Первая часть сети принимает за вход изображение RGB размера [256x256x3] и понижает его до карты функций размера [64x64x128].

  2. Вторая часть сети состоит из пяти одинаковых остаточных блоков, определенных в поддерживающей функции residualBlock.

  3. Третья и последняя часть сети повышает качество карты функций до исходного размера изображения и возвращает преобразованное изображение. Эта последняя часть использует upsampleLayer, который является пользовательским слоем, присоединенным к этому примеру как вспомогательный файл.

layers = [
    
    % First part.
    imageInputLayer([256 256 3], 'Name', 'input', 'Normalization','none')
    
    convolution2dLayer([9 9], 32, 'Padding','same','Name', 'conv1')
    groupNormalizationLayer('channel-wise','Name','norm1')
    reluLayer('Name', 'relu1')
    
    convolution2dLayer([3 3], 64, 'Stride', 2,'Padding','same','Name', 'conv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm2')
    reluLayer('Name', 'relu2')
    
    convolution2dLayer([3 3], 128, 'Stride', 2, 'Padding','same','Name', 'conv3')
    groupNormalizationLayer('channel-wise' ,'Name','norm3')
    reluLayer('Name', 'relu3')
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer('up1')
    convolution2dLayer([3 3], 64, 'Stride', 1, 'Padding','same','Name', 'upconv1')
    groupNormalizationLayer('channel-wise' ,'Name','norm6')
    reluLayer('Name', 'relu5')
    
    upsampleLayer('up2')
    convolution2dLayer([3 3], 32, 'Stride', 1, 'Padding','same','Name', 'upconv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm7')
    reluLayer('Name', 'relu6')
    
    convolution2dLayer(9,3,'Padding','same','Name','conv_out')];

lgraph = layerGraph(layers);

Добавьте отсутствующие соединения в остаточные блоки.

lgraph = connectLayers(lgraph,"relu3","add1/in2");
lgraph = connectLayers(lgraph,"add1","add2/in2");
lgraph = connectLayers(lgraph,"add2","add3/in2");
lgraph = connectLayers(lgraph,"add3","add4/in2");
lgraph = connectLayers(lgraph,"add4","add5/in2");

Визуализируйте сеть трансформатора изображений на графике.

figure
plot(lgraph)
title("Transform Network")

Создайте dlnetwork объект из графика слоев.

dlnetTransform = dlnetwork(lgraph);

Сеть потерь стиля

Этот пример использует предварительно обученную VGG-16 глубокую нейронную сеть, чтобы извлечь функции изображений содержимого и стиля в различных слоях. Эти многослойные функции используются для вычисления соответствующего содержимого и потерь стиля.

Чтобы получить предварительно обученную VGG-16 сеть, используйте vgg16 функция. Если у вас нет установленных необходимых пакетов поддержки, то программное обеспечение предоставляет ссылку на загрузку.

netLoss = vgg16;

Чтобы извлечь функцию, необходимую для вычисления потерь, вам нужны только первые 24 слоя. Извлечение и преобразование в график слоев.

lossLayers = netLoss.Layers(1:24);
lgraph = layerGraph(lossLayers);

Преобразуйте в dlnetwork.

dlnetLoss = dlnetwork(lgraph);

Определите функцию потерь и матрицу граммов

Создайте styleTransferLoss функция, заданная в разделе Style Transfer Loss этого примера.

Функция styleTransferLoss принимает за вход сеть потерь dlnetLossмини-пакет входа преобразованных изображений dlX, мини-пакет преобразованных изображений dlYмассив, содержащий матрицы Грамма изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Функция возвращает общую потерю loss и отдельные компоненты: потеря содержимого lossContent и потеря стиля lossStyle.

The styleTransferLoss функция использует вспомогательную функцию createGramMatrix при расчете потери стиля.

The createGramMatrix функция принимает за вход функций, извлеченную сетью потерь, и возвращает стилистическое представление для каждого изображения в мини-пакете. Вы можете найти реализацию и математическое определение матрицы Gram в разделе Gram Matrix.

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в разделе Model Gradients Function примера. Эта функция принимает как вход сеть потерь dlnetLoss, сеть трансформатора изображений dlnetTransformмини-пакет входа изображений dlXмассив, содержащий матрицы Грамма изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Функция возвращает gradients потерь относительно настраиваемых параметров трансформатора изображения, состояния сети трансформатора изображения, преобразованных изображений dlY, общая потеря loss, потеря, связанная с содержимым lossContent и потеря, связанная со стилем lossStyle.

Настройка опций обучения

Обучите с размером мини-пакета, равный 4, для 2 эпох, как в [1].

numEpochs = 2;
miniBatchSize = 4;

Установите размер чтения дополненного datastore изображения в мини-пакет.

augimds.MiniBatchSize = miniBatchSize;

Задайте опции для оптимизации ADAM. Задайте скорость обучения 0,001 с коэффициентом градиентного распада 0,01 и квадратным коэффициентом градиентного распада 0,999.

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Укажите вес, придаваемый потере стиля, и вес, придаваемый потере содержимого при расчете общей потери.

Обратите внимание, что, порядком найти хороший баланс между содержимым и потерей стиля, вам может потребоваться экспериментировать с различными комбинациями весов.

weightContent = 1e-4;
weightStyle = 3e-8; 

Выберите частоту построения процесса обучения. Это определяет, сколько итераций существует между каждым обновлением графика.

plotFrequency = 10;

Обучите модель

В порядок, чтобы вычислить потери во время обучения, вычислите матрицы Gram для изображения стиля.

Преобразуйте изображение стиля в dlarray.

dlS = dlarray(single(styleImage),'SSC');

В порядок вычисления матрицы Gram передайте изображение стиля в VGG-16 сеть и извлеките активации в четырех разных слоях.

[dlSActivations1,dlSActivations2,dlSActivations3,dlSActivations4] = forward(dlnetLoss,dlS, ...
    'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

Вычислите матрицу Gram для каждого набора активаций с помощью вспомогательной функции createGramMatrix.

dlSGram{1} = createGramMatrix(dlSActivations1);
dlSGram{2} = createGramMatrix(dlSActivations2);
dlSGram{3} = createGramMatrix(dlSActivations3);
dlSGram{4} = createGramMatrix(dlSActivations4);

Обучающие графики состоят из двух рисунков:

  1. Рисунок, показывающий график потерь во время обучения

  2. Рисунок, содержащий входное и выходное изображение сети трансформатора изображений

Инициализируйте обучающие графики. Детали инициализации можно проверить в вспомогательной функции initializeFigures. Эта функция возвращает: ось ax1 где вы строите график потерь, ось ax2 где вы строите графики изображений валидации, анимированной линии lineLossContent который содержит потерю содержимого, анимированную линию lineLossStyle который содержит потерю стиля и анимированную линию lineLossTotal который содержит общую потерю.

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal]=initializeStyleTransferPlots();

Инициализируйте средний и средний квадратные градиентные гиперпараметры для оптимизатора ADAM.

averageGrad = [];
averageSqGrad = [];

Вычислите общее количество итераций обучения.

numIterations = floor(augimds.NumObservations*numEpochs/miniBatchSize);

Инициализируйте номер итерации и таймер перед обучением.

iteration = 0;
start = tic;

Обучите модель. Это может занять много времени, чтобы бежать.

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and the network state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradients,state,dlY,loss,lossContent,lossStyle] = dlfeval(@modelGradients, ...
            dlnetLoss,dlnetTransform,dlX,dlSGram,weightContent,weightStyle);
        
        dlnetTransform.State = state;
        
        % Update the network parameters.
        [dlnetTransform,averageGrad,averageSqGrad] = ...
            adamupdate(dlnetTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
              
        
        % Every plotFequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(gather(extractdata(loss))))
            addpoints(lineLossContent,iteration,double(gather(extractdata(lossContent))))
            addpoints(lineLossStyle,iteration,double(gather(extractdata(lossStyle))))
            
            % Use the first image of the mini-batch as a validation image.
            dlV = dlX(:,:,:,1);
            % Use the transformed validation image computed previously.
            dlVY = dlY(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(dlV)));
            transformedValidationImage = uint8(gather(extractdata(dlVY)));
            
            % Plot the input image and the output image and increase size
            imshow(imtile({validationImage,transformedValidationImage}),'Parent',ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration +" of "+ numIterations + "(" + completionPercentage + "%)" +", Elapsed: " + string(D))
        drawnow
        
    end
end

Стилизовать изображение

После завершения обучения можно использовать трансформатор изображений на любом изображении по своему выбору.

Загрузите изображение, которое вы хотите преобразовать.

imFilename = 'peppers.png';
im = imread(imFilename);

Измените размер входа изображения на вход размерностей трансформатора изображения.

im = imresize(im,[256,256]);

Преобразуйте его в dlarray.

dlX = dlarray(single(im),'SSCB');

Как использовать преобразование графического процессора в gpuArray если он доступен.

if canUseGPU
    dlX = gpuArray(dlX);
end

Чтобы применить стиль к изображению, передайте его в изображение трансформатора используя функцию predict.

dlY = predict(dlnetTransform,dlX);

Переформулируйте изображение в область значений [0 255]. Во-первых, используйте функцию tanh для пересмотра dlY в область значений [-1 1]. Затем сдвиньте и масштабируйте выход, чтобы перерассчитать в область значений [0 255].

Y = 255*(tanh(dlY)+1)/2;

Подготовка Y для графического изображения. Используйте функцию extraxtdata для извлечения данных из dlarray.Используйте функцию gather, чтобы перенести Y из графический процессор в локальную рабочую область.

Y = uint8(gather(extractdata(Y)));

Отображение входа изображения (слева) рядом со стилизованным изображением (справа).

figure
m = imtile({im,Y});
imshow(m)

Функция градиентов модели

Функция modelGradients принимает за вход сеть потерь dlnetLoss, сеть трансформатора изображений dlnetTransformмини-пакет входа изображений dlXмассив, содержащий матрицы Грамма изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Это возвращает gradients потерь относительно настраиваемых параметров трансформатора изображения, состояния сети трансформатора изображения, преобразованных изображений dlY, общие потери loss, потеря, связанная с содержимым lossContent и потеря, связанная со стилем lossStyle.

function [gradients,state,dlY,loss,lossContent,lossStyle] = ...
    modelGradients(dlnetLoss,dlnetTransform,dlX,dlSGram,contentWeight,styleWeight)

[dlY,state] = forward(dlnetTransform,dlX);

dlY = 255*(tanh(dlY)+1)/2;

[loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX,dlSGram,contentWeight,styleWeight);

gradients = dlgradient(loss,dlnetTransform.Learnables);

end

Передаточные потери стиля

Функция styleTransferLoss принимает за вход сеть потерь dlnetLossмини-пакет входа изображений dlX, мини-пакет преобразованных изображений dlYмассив, содержащий матрицы Грамма изображения стиля dlSGram, веса, связанные с содержимым и стилем contentWeight и styleWeight, соответственно. Возвращает общую потерю loss и отдельные компоненты: потеря содержимого lossContent и потеря стиля lossStyle.

Потеря содержимого является мерой того, сколько различие в пространственной структуре находится между входом изображением X и выходные изображения Y.

С другой стороны, потеря стиля подсказывает, насколько сильно отличается стилистический внешний вид между стилевым изображением S и выходное изображение Y.

График ниже объясняет алгоритм, который styleTransferLoss реализует, чтобы вычислить общую потерю.

Во-первых, функция передает входные изображения X, преобразованные изображения Y и изображение стиля S к предварительно обученной сетевой VGG-16. Эта предварительно обученная сеть извлекает несколько функций из этих изображений. Алгоритм затем вычисляет потерю содержимого с помощью пространственных функций входного изображения X и выходного изображения Y. Кроме того, он вычисляет потерю стиля с помощью стилистических функций выходного изображения Y и изображения стиля S. Наконец, он получает общую потерю, добавляя содержимое и потери стиля.

Потеря содержимого

Для каждого изображения в мини-пакете функция потери содержимого сравнивает функции оригинального изображения и преобразованного изображения, выводимого слоем relu_3_3. В частности, он вычисляет среднюю квадратную ошибку между активациями и возвращает среднюю потерю для мини-пакета:

lossContent=1Nn=1Nсредний([ϕ(Xn)-ϕ(Yn)]2),

где X содержит вход изображения, Y содержит преобразованные изображения, N размер мини-пакета, и ϕ() представляет активации, извлеченные на слое relu_3_3.

Потеря стиля

Чтобы вычислить потерю стиля, для каждого отдельного изображения в мини-пакете:

  1. Извлеките активации в слоях relu1_2, relu2_2, relu3_3 и relu4_3.

  2. Для каждой из четырех активаций ϕj вычислить матрицу Грамма G(ϕj).

  3. Вычислите квадратное различие между соответствующими матрицами Грамма.

  4. Сложите четыре выхода для каждого слоя j с предыдущего шага.

Чтобы получить потерю стиля для всего мини-пакета, вычислите среднее значение потери стиля для каждого изображения n в мини-пакете:

lossStyle=1Nn=1Nj=14[G(ϕj(Xn))-G(ϕj(S))]2,

где j - индекс слоя, и G() - матрица граммов.

Общая потеря

function [loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX, ...
    dlSGram,weightContent,weightStyle)

% Extract activations.
dlYActivations = cell(1,4);
[dlYActivations{1},dlYActivations{2},dlYActivations{3},dlYActivations{4}] = ...
    forward(dlnetLoss,dlY,'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

dlXActivations = forward(dlnetLoss,dlX,'Outputs','relu3_3');

% Calculate the mean square error between activations.
lossContent = mean((dlYActivations{3} - dlXActivations).^2,'all');

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(dlYActivations{j});
    lossStyle = lossStyle + sum((G - dlSGram{j}).^2,'all');
end

% Average the loss over the mini-batch.
miniBatchSize = size(dlX,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

Остаточный блок

The residualBlock функция возвращает массив из шести слоев. Он состоит из слоев свертки, слоев нормализации образцов, слоя ReLu и слоя сложения. Обратите внимание, что groupNormalizationLayer('channel-wise') является просто слоем нормализации образцов.

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same','Name', "convRes"+name+"_1")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_1")
    reluLayer('Name', "reluRes"+name+"_1")
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same', 'Name', "convRes"+name+"_2")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_2")
    additionLayer(2,'Name',"add"+name)];

end

Грамм- Матрица

Функция createGramMatrix принимает за вход активации одного слоя и возвращает стилистическое представление для каждого изображения в мини-пакете .Вход является картой функций размера [H, W, C, N], где H - высота, W - ширина, C - количество каналов и N - размер мини-пакета. Функция выводит массив G размера [C, C, N]. Каждая подрешетка G(:,:,k) - матрица Грамма, соответствующая kth изображение в мини-пакете. Каждая запись G(i,j,k) матрицы Грамма представляет корреляцию между каналами ci и cj, потому что каждая запись в канале ci умножает запись в соответствующем положении в канале cj:

G(i,j,k)=1C×H×Wh=1Hw=1Wϕk(h,w,ci)ϕk(h,w,cj),

где ϕk являются активациями для kth изображение в мини-пакете.

Матрица Gram содержит информацию о том, какие функции активируются вместе, но не имеет информации о том, где функции происходят в изображении. Это связано с тем, что суммирование по высоте и ширине теряет информацию о пространственной структуре. Функция потерь использует эту матрицу как стилистическое представление изображения.

function G = createGramMatrix(activations)

[h,w,numChannels] = size(activations,1:3);

features = reshape(activations,h*w,numChannels,[]);
featuresT = permute(features,[2 1 3]);

G = dlmtimes(featuresT,features) / (h*w*numChannels);

end

Ссылки

  1. Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. «Ощутимые потери для передачи стиля в реальном времени и суперразрешение». Европейская конференция по компьютерному зрению. Спрингер, Чэм, 2016.

См. также

| | | | | |

Похожие темы