Нейронная передача стиля Используя глубокое обучение

В этом примере показано, как применить стилистический внешний вид одного изображения к содержимому сцены второго изображения с помощью предварительно обученной сети VGG-19 [1].

Загрузка данных

Загрузите изображение стиля и изображение содержимого. Этот пример использует отличительного Ван Гога, рисующего "Звездную Ночь" как изображение стиля и фотография маяка как довольное изображение.

styleImage = im2double(imread('starryNight.jpg'));
contentImage = imread('lighthouse.png');

Отобразите изображение стиля и изображение содержимого как монтаж.

imshow(imtile({styleImage,contentImage},'BackgroundColor','w'));

Загрузите сеть извлечения признаков

В этом примере вы используете модифицированную предварительно обученную глубокую нейронную сеть VGG-19, чтобы извлечь функции содержимого и изображения стиля на различных слоях. Эти многоуровневые функции используются для расчета соответствующее содержимое и разрабатывают потери. Сеть генерирует стилизованное изображение передачи с помощью объединенной потери.

Чтобы получить предварительно обученную сеть VGG-19, установите vgg19 (Deep Learning Toolbox). Если вам не установили необходимые пакеты поддержки, то программное обеспечение обеспечивает ссылку на загрузку.

net = vgg19;

Чтобы сделать сеть VGG-19 подходящей для извлечения признаков, удалите все полносвязные слоя от сети.

lastFeatureLayerIdx = 38;
layers = net.Layers;
layers = layers(1:lastFeatureLayerIdx);

Макс. слои объединения сети VGG-19 вызывают исчезающий эффект. Чтобы уменьшить исчезающий эффект и увеличить поток градиента, замените все макс. слои объединения на средние слои объединения [1].

for l = 1:lastFeatureLayerIdx
    layer = layers(l);
    if isa(layer,'nnet.cnn.layer.MaxPooling2DLayer')
        layers(l) = averagePooling2dLayer(layer.PoolSize,'Stride',layer.Stride,'Name',layer.Name);
    end
end

Создайте график слоев с модифицированными слоями.

lgraph = layerGraph(layers);

Визуализируйте сеть извлечения признаков в графике.

plot(lgraph)
title('Feature Extraction Network')

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnet = dlnetwork(lgraph);

Предварительная Обработка Данных

Измените размер изображения стиля и изображения содержимого к меньшему размеру для более быстрой обработки.

imageSize = [384,512];
styleImg = imresize(styleImage,imageSize);
contentImg = imresize(contentImage,imageSize);

Предварительно обученная сеть VGG-19 выполняет классификацию на вычтенном изображении мудрого каналом среднего значения. Получите мудрое каналом среднее значение от входного слоя изображений, который является первым слоем в сети.

imgInputLayer = lgraph.Layers(1);
meanVggNet = imgInputLayer.Mean(1,1,:);

Значения мудрого каналом среднего значения подходят для изображений типа данных с плавающей запятой с пиксельными значениями в области значений [0, 255]. Преобразуйте изображение стиля и изображение содержимого к типу данных single с областью значений [0, 255]. Затем вычтите мудрое каналом среднее значение из изображения стиля и изображения содержимого.

styleImg = rescale(single(styleImg),0,255) - meanVggNet;
contentImg = rescale(single(contentImg),0,255) - meanVggNet;

Инициализируйте изображение передачи

Изображение передачи является выходным изображением в результате передачи стиля. Можно инициализировать изображение передачи изображением стиля, изображением содержимого или любым случайным изображением. Инициализация с изображением стиля или изображением содержимого смещает процесс переноса стиля и производит изображение передачи, более похожее на входное изображение. В отличие от этого инициализация с белым шумом удаляет смещение, но занимает больше времени, чтобы сходиться на стилизованном изображении. Для лучшей стилизации и более быстрой сходимости, этот пример инициализирует выходное изображение передачи как взвешенную комбинацию довольного изображение и изображение белого шума.

noiseRatio = 0.7;
randImage = randi([-20,20],[imageSize 3]);
transferImage = noiseRatio.*randImage + (1-noiseRatio).*contentImg;

Задайте функции потерь и разработайте параметры передачи

Потеря содержимого

Цель потери содержимого состоит в том, чтобы заставить функции изображения передачи совпадать с функциями довольного изображение. Потеря содержимого вычисляется как среднеквадратическое различие между довольным функции изображений и функциями передачи изображений каждого довольного слой [1] функции. Yˆ предсказанная карта функции для изображения передачи и Y предсказанная карта функции для довольного изображение. Wcl вес содержательного слоя для lth слой. H,W,Cвысота, ширина и каналы карт функции, соответственно.

Lcontent=lWcl×1HWCi,j(Yˆi,jl-Yi,jl)2

Задайте имена слоя извлечения признаков содержимого. Функции, извлеченные из этих слоев, используются для расчета потеря содержимого. В сети VGG-19 обучение является более эффективными функциями использования от более глубоких слоев, а не функциями от мелких слоев. Поэтому задайте слой извлечения признаков содержимого как четвертый сверточный слой.

styleTransferOptions.contentFeatureLayerNames = {'conv4_2'};

Задайте веса слоев извлечения признаков содержимого.

styleTransferOptions.contentFeatureLayerWeights = 1;

Разработайте потерю

Цель потери стиля состоит в том, чтобы заставить структуру изображения передачи совпадать со структурой изображения стиля. Представление стиля изображения представлено как матрица Грамма. Поэтому потеря стиля вычисляется как среднеквадратическое различие между матрицей Грамма изображения стиля и матрицей Грамма изображения передачи [1]. Z и Zˆ предсказанные карты функции для стиля и передают изображение, соответственно. GZ и GZˆ матрицы Грамма для функций стиля и передают функции, соответственно. Wsl вес слоя стиля для lth разработайте слой.

GZˆ=i,jZˆi,j×Zˆj,i

GZ=i,jZi,j×Zj,i

Lstyle=lWsl×1(2HWC)2(GZˆl-GZl)2

Задайте имена слоев извлечения признаков стиля. Функции, извлеченные из этих слоев, используются для расчета потеря стиля.

styleTransferOptions.styleFeatureLayerNames = {'conv1_1','conv2_1','conv3_1','conv4_1','conv5_1'};

Задайте веса слоев извлечения признаков стиля. Задайте маленькие веса для изображений простого стиля и увеличьте веса для комплексных изображений стиля.

styleTransferOptions.styleFeatureLayerWeights = [0.5,1.0,1.5,3.0,4.0];

Общая сумма убытков

Общая сумма убытков является взвешенной комбинацией потери содержимого и потери стиля. α и β весовые коэффициенты за потерю содержимого и разрабатывают потерю, соответственно.

Ltotal=α×Lcontent+β×Lstyle

Задайте весовые коэффициенты alpha и beta за потерю содержимого и потерю стиля. Отношение alpha к beta должен быть вокруг 1e-3 или 1e-4 [1].

styleTransferOptions.alpha = 1; 
styleTransferOptions.beta = 1e3;

Задайте опции обучения

Обучайтесь для 2 500 итераций.

numIterations = 2500;

Задайте опции для оптимизации Адама. Установите скорость обучения на 2 для более быстрой сходимости. Можно экспериментировать со скоростью обучения путем наблюдения выходного изображения и потерь. Инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [].

learningRate = 2;
trailingAvg = [];
trailingAvgSq = [];

Обучите сеть

Преобразуйте изображение стиля, изображение содержимого, и передайте изображение dlarray (Deep Learning Toolbox) возражает с базовым типом single и размерность маркирует 'SSC'.

dlStyle = dlarray(styleImg,'SSC');
dlContent = dlarray(contentImg,'SSC');
dlTransfer = dlarray(transferImage,'SSC');

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Для обучения графического процессора преобразуйте данные в gpuArray.

if canUseGPU
    dlContent = gpuArray(dlContent);
    dlStyle = gpuArray(dlStyle);
    dlTransfer = gpuArray(dlTransfer);
end

Извлеките функции содержимого из довольного изображение.

numContentFeatureLayers = numel(styleTransferOptions.contentFeatureLayerNames);
contentFeatures = cell(1,numContentFeatureLayers);
[contentFeatures{:}] = forward(dlnet,dlContent,'Outputs',styleTransferOptions.contentFeatureLayerNames);

Извлеките функции стиля из изображения стиля.

numStyleFeatureLayers = numel(styleTransferOptions.styleFeatureLayerNames);
styleFeatures = cell(1,numStyleFeatureLayers);
[styleFeatures{:}] = forward(dlnet,dlStyle,'Outputs',styleTransferOptions.styleFeatureLayerNames);

Обучите модель с помощью пользовательского учебного цикла. Для каждой итерации:

  • Вычислите потерю содержимого и разработайте потерю, использующую функции довольного изображение, разработайте изображение и передайте изображение. Чтобы вычислить потерю и градиенты, используйте функцию помощника imageGradients (заданный в разделе Supporting Functions этого примера).

  • Обновите изображение передачи с помощью adamupdate (Deep Learning Toolbox) функция.

  • Выберите лучшее изображение передачи стиля как изображение окончательного результата.

figure

minimumLoss = inf;

for iteration = 1:numIterations
    % Evaluate the transfer image gradients and state using dlfeval and the
    % imageGradients function listed at the end of the example. 
    [grad,losses] = dlfeval(@imageGradients,dlnet,dlTransfer,contentFeatures,styleFeatures,styleTransferOptions);
    [dlTransfer,trailingAvg,trailingAvgSq] = adamupdate(dlTransfer,grad,trailingAvg,trailingAvgSq,iteration,learningRate);
  
    if losses.totalLoss < minimumLoss
        minimumLoss = losses.totalLoss;
        dlOutput = dlTransfer;        
    end   
    
    % Display the transfer image on the first iteration and after every 50
    % iterations. The postprocessing steps are described in the "Postprocess
    % Transfer Image for Display" section of this example.
    if mod(iteration,50) == 0 || (iteration == 1)
        
        transferImage = gather(extractdata(dlTransfer));
        transferImage = transferImage + meanVggNet;
        transferImage = uint8(transferImage);
        transferImage = imresize(transferImage,size(contentImage,[1 2]));
        
        image(transferImage)
        title(['Transfer Image After Iteration ',num2str(iteration)])
        axis off image
        drawnow
    end   
    
end

Постобработайте изображение передачи для отображения

Получите обновленное изображение передачи.

transferImage = gather(extractdata(dlOutput));

Добавьте обученное сетью среднее для изображения передачи.

transferImage = transferImage + meanVggNet;

Некоторые пиксельные значения могут превысить исходную область значений [0, 255] содержимого и разработать изображение. Можно отсечь значения к области значений [0, 255] путем преобразования типа данных в uint8.

transferImage = uint8(transferImage);

Измените размер изображения передачи к первоначальному размеру довольного изображение.

transferImage = imresize(transferImage,size(contentImage,[1 2]));

Отобразите довольное изображение, передайте изображение и разработайте изображение в монтаже.

imshow(imtile({contentImage,transferImage,styleImage}, ...
    'GridSize',[1 3],'BackgroundColor','w'));

Вспомогательные Функции

Вычислите потерю изображений и градиенты

imageGradients функция помощника возвращает потерю и градиенты, использующие функции довольного изображение, изображение стиля и изображение передачи.

function [gradients,losses] = imageGradients(dlnet,dlTransfer,contentFeatures,styleFeatures,params)
 
    % Initialize transfer image feature containers. 
    numContentFeatureLayers = numel(params.contentFeatureLayerNames);
    numStyleFeatureLayers = numel(params.styleFeatureLayerNames);
 
    transferContentFeatures = cell(1,numContentFeatureLayers);
    transferStyleFeatures = cell(1,numStyleFeatureLayers);
 
    % Extract content features of transfer image.
    [transferContentFeatures{:}] = forward(dlnet,dlTransfer,'Outputs',params.contentFeatureLayerNames);
     
    % Extract style features of transfer image.
    [transferStyleFeatures{:}] = forward(dlnet,dlTransfer,'Outputs',params.styleFeatureLayerNames);
 
    % Compute content loss. 
    cLoss = contentLoss(transferContentFeatures,contentFeatures,params.contentFeatureLayerWeights);
 
    % Compute style loss. 
    sLoss = styleLoss(transferStyleFeatures,styleFeatures,params.styleFeatureLayerWeights);
 
    % Compute final loss as weighted combination of content and style loss. 
    loss = (params.alpha * cLoss) + (params.beta * sLoss);
 
    % Calculate gradient with respect to transfer image.
    gradients = dlgradient(loss,dlTransfer);
    
    % Extract various losses. 
    losses.totalLoss = gather(extractdata(loss));
    losses.contentLoss = gather(extractdata(cLoss));
    losses.styleLoss = gather(extractdata(sLoss));
 
end

Вычислите потерю содержимого

contentLoss функция помощника вычисляет различие взвешенного среднего в квадрате между довольным функции изображений и функциями передачи изображений.

function loss = contentLoss(transferContentFeatures,contentFeatures,contentWeights)

    loss = 0;
    for i=1:numel(contentFeatures)
        temp = 0.5 .* mean((transferContentFeatures{1,i} - contentFeatures{1,i}).^2,'all');
        loss = loss + (contentWeights(i)*temp);
    end
end

Вычислите потерю стиля

styleLoss функция помощника вычисляет различие взвешенного среднего в квадрате между матрицей Грамма функций стиля изображений и матрицей Грамма функций передачи изображений.

function loss = styleLoss(transferStyleFeatures,styleFeatures,styleWeights)

    loss = 0;
    for i=1:numel(styleFeatures)
        
        tsf = transferStyleFeatures{1,i};
        sf = styleFeatures{1,i};    
        [h,w,c] = size(sf);
        
        gramStyle = computeGramMatrix(sf);
        gramTransfer = computeGramMatrix(tsf);
        sLoss = mean((gramTransfer - gramStyle).^2,'all') / ((h*w*c)^2);
        
        loss = loss + (styleWeights(i)*sLoss);
    end
end

Вычислите матрицу грамма

computeGramMatrix функция помощника используется styleLoss функция помощника, чтобы вычислить матрицу Грамма карты функции.

function gramMatrix = computeGramMatrix(featureMap)
    [H,W,C] = size(featureMap);
    reshapedFeatures = reshape(featureMap,H*W,C);
    gramMatrix = reshapedFeatures' * reshapedFeatures;
end

Ссылки

[1] Леон А. Гэтис, Александр С. Экер и Мэттиас Бетдж. "Нейронный Алгоритм Художественного стиля". Предварительно распечатайте, представленный 2 сентября 2015. https://arxiv.org/abs/1508.06576

Смотрите также

(Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте