exponenta event banner

Сеть быстрой передачи стилей поездов

В этом примере показано, как обучить сеть переносить стиль изображения на второе изображение. Он основан на архитектуре, определенной в [1].

Этот пример похож на Neural Style Transfer Using Deep Learning, но он работает быстрее после того, как вы обучили сеть изображению стиля S. Это потому, что для получения стилизованного изображения Y нужно только сделать прямой проход входного изображения X в сеть.

Ниже приведена высокоуровневая схема алгоритма обучения. Для вычисления потерь используются три изображения: входное изображение X, преобразованное изображение Y и изображение стиля S.

Обратите внимание, что функция потери использует предварительно подготовленные сетевые VGG-16 для извлечения элементов из изображений. Его реализацию и математическое определение можно найти в разделе «Потери при переносе стилей» этого примера.

Загрузка данных обучения

Загрузите и извлеките изображения поезда COCO 2014 и подписи из http://cocodataset.org/#download, нажав кнопку «Изображения поезда 2014». Сохранить данные в папке, указанной imageFolder. Извлечение изображений в imageFolder. COCO 2014 была собрана Консорциумом Coco.

Создайте каталоги для хранения набора данных COCO.

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,'dir')
    mkdir(imageFolder);
end

Создайте хранилище данных образа, содержащее изображения COCO.

imds = imageDatastore(imageFolder,'IncludeSubfolders',true);

Обучение может занять много времени. Если вы хотите сократить время обучения за счет точности результирующей сети, выберите подмножество хранилища данных образа путем установки fraction до меньшего значения.

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));

Чтобы изменить размер изображений и преобразовать их все в RGB, создайте хранилище данных дополненного изображения.

augimds = augmentedImageDatastore([256 256],imds,'ColorPreprocessing',"gray2rgb");

Прочтите изображение стиля.

styleImage = imread('starryNight.jpg');
styleImage = imresize(styleImage,[256 256]);

Отображение выбранного изображения стиля.

figure
imshow(styleImage)
title("Style Image")

Определение сети трансформаторов изображений

Определите сеть трансформатора изображений. Это сеть «образ-образ». Сеть состоит из 3 частей:

  1. Первая часть сети принимает в качестве входных данных изображение RGB размером [256x256x3] и понижает его выборку до карты характеристик размером [64x64x128].

  2. Вторая часть сети состоит из пяти идентичных остаточных блоков, определенных в поддерживающей функции residualBlock.

  3. Третья и последняя часть сети увеличивает выборку карты признаков до исходного размера изображения и возвращает преобразованное изображение. В этой последней части используется upsampleLayer, который является пользовательским слоем, присоединенным к этому примеру в качестве вспомогательного файла.

layers = [
    
    % First part.
    imageInputLayer([256 256 3], 'Name', 'input', 'Normalization','none')
    
    convolution2dLayer([9 9], 32, 'Padding','same','Name', 'conv1')
    groupNormalizationLayer('channel-wise','Name','norm1')
    reluLayer('Name', 'relu1')
    
    convolution2dLayer([3 3], 64, 'Stride', 2,'Padding','same','Name', 'conv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm2')
    reluLayer('Name', 'relu2')
    
    convolution2dLayer([3 3], 128, 'Stride', 2, 'Padding','same','Name', 'conv3')
    groupNormalizationLayer('channel-wise' ,'Name','norm3')
    reluLayer('Name', 'relu3')
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer('up1')
    convolution2dLayer([3 3], 64, 'Stride', 1, 'Padding','same','Name', 'upconv1')
    groupNormalizationLayer('channel-wise' ,'Name','norm6')
    reluLayer('Name', 'relu5')
    
    upsampleLayer('up2')
    convolution2dLayer([3 3], 32, 'Stride', 1, 'Padding','same','Name', 'upconv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm7')
    reluLayer('Name', 'relu6')
    
    convolution2dLayer(9,3,'Padding','same','Name','conv_out')];

lgraph = layerGraph(layers);

Добавление отсутствующих соединений в остаточные блоки.

lgraph = connectLayers(lgraph,"relu3","add1/in2");
lgraph = connectLayers(lgraph,"add1","add2/in2");
lgraph = connectLayers(lgraph,"add2","add3/in2");
lgraph = connectLayers(lgraph,"add3","add4/in2");
lgraph = connectLayers(lgraph,"add4","add5/in2");

Визуализация сети трансформаторов изображений на графике.

figure
plot(lgraph)
title("Transform Network")

Создать dlnetwork объект из графа слоев.

dlnetTransform = dlnetwork(lgraph);

Сеть потери стиля

В этом примере используется заранее обученная VGG-16 глубокая нейронная сеть для извлечения особенностей содержимого и изображений стилей на различных слоях. Эти многослойные элементы используются для вычисления соответствующих потерь содержимого и стиля.

Чтобы получить предварительно обученную сеть VGG-16, используйте vgg16 функция. Если необходимые пакеты поддержки не установлены, программа предоставляет ссылку для загрузки.

netLoss = vgg16;

Для извлечения элемента, необходимого для вычисления потерь, необходимы только первые 24 слоя. Извлеките и преобразуйте в график слоев.

lossLayers = netLoss.Layers(1:24);
lgraph = layerGraph(lossLayers);

Преобразовать в dlnetwork.

dlnetLoss = dlnetwork(lgraph);

Определение функции потерь и матрицы грамма

Создать styleTransferLoss функция, определенная в разделе «Потери при переносе стилей» данного примера.

Функция styleTransferLoss принимает в качестве входных данных сеть с потерями dlnetLoss, мини-пакет входных преобразованных изображений dlX, мини-пакет преобразованных изображений dlY, массив, содержащий матрицы Gram изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Функция возвращает общий убыток loss и отдельные компоненты: потеря содержимого lossContent и потери стиля lossStyle.

styleTransferLoss функция использует вспомогательную функцию createGramMatrix при вычислении потери стиля.

createGramMatrix функция принимает в качестве входных данных функции, извлеченные сетью потерь, и возвращает стилистическое представление для каждого изображения в мини-пакете. Реализацию и математическое определение матрицы Gram можно найти в разделе Матрица Gram.

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в разделе «Функция градиентов модели» в примере. Эта функция используется в качестве входного сигнала сети с потерями dlnetLoss, сеть трансформатора изображений dlnetTransform, мини-пакет входных изображений dlX, массив, содержащий матрицы Gram изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Функция возвращает значение gradients потери по отношению к обучаемым параметрам трансформатора изображения, состоянию сети трансформатора изображения, преобразованным изображениям dlY, общий убыток loss, потери, связанные с содержимым lossContent и потери, связанные со стилем lossStyle.

Укажите параметры обучения

Поезд с размером мини-партии 4 на 2 эпохи, как в [1].

numEpochs = 2;
miniBatchSize = 4;

Задайте размер чтения хранилища данных дополненного изображения, равный размеру мини-пакета.

augimds.MiniBatchSize = miniBatchSize;

Укажите параметры оптимизации ADAM. Укажите скорость обучения 0,001 с коэффициентом градиентного затухания 0,01 и коэффициентом градиентного затухания 0,999 в квадрате.

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

executionEnvironment = "auto";

Укажите вес, заданный для потери стиля, и вес, заданный для потери содержимого при расчете общей потери.

Обратите внимание, что для того, чтобы найти хороший баланс между содержанием и потерей стиля, может потребоваться экспериментировать с различными комбинациями весов.

weightContent = 1e-4;
weightStyle = 3e-8; 

Выберите частоту графика выполнения обучения. Это указывает количество итераций между каждым обновлением печати.

plotFrequency = 10;

Модель поезда

Чтобы иметь возможность вычислить потери во время обучения, вычислите матрицы Gram для изображения стиля.

Преобразование изображения стиля в dlarray.

dlS = dlarray(single(styleImage),'SSC');

Чтобы вычислить матрицу Gram, подайте изображение стиля в VGG-16 сеть и извлеките активации на четырех различных слоях.

[dlSActivations1,dlSActivations2,dlSActivations3,dlSActivations4] = forward(dlnetLoss,dlS, ...
    'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

Вычислите матрицу Gram для каждого набора активизаций с помощью поддерживающей функции createGramMatrix.

dlSGram{1} = createGramMatrix(dlSActivations1);
dlSGram{2} = createGramMatrix(dlSActivations2);
dlSGram{3} = createGramMatrix(dlSActivations3);
dlSGram{4} = createGramMatrix(dlSActivations4);

Учебные планы состоят из двух цифр:

  1. Рисунок, показывающий график потерь во время тренировки

  2. Фигура, содержащая входное и выходное изображение сети трансформатора изображения

Инициализируйте графики обучения. Подробные данные инициализации можно проверить в вспомогательной функции. initializeFigures. Эта функция возвращает: ось ax1 где вы строите график потерь, ось ax2 при печати подтверждающих изображений анимированная линия lineLossContent которая содержит потерю содержимого, анимированную строку lineLossStyle который содержит потерю стиля и анимированную линию lineLossTotal который содержит общий убыток.

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal]=initializeStyleTransferPlots();

Инициализируйте гиперпараметры среднего градиента и среднего квадрата градиента для оптимизатора ADAM.

averageGrad = [];
averageSqGrad = [];

Вычислите общее количество итераций обучения.

numIterations = floor(augimds.NumObservations*numEpochs/miniBatchSize);

Инициализируйте номер итерации и таймер перед обучением.

iteration = 0;
start = tic;

Тренируйте модель. Это может занять много времени.

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and the network state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradients,state,dlY,loss,lossContent,lossStyle] = dlfeval(@modelGradients, ...
            dlnetLoss,dlnetTransform,dlX,dlSGram,weightContent,weightStyle);
        
        dlnetTransform.State = state;
        
        % Update the network parameters.
        [dlnetTransform,averageGrad,averageSqGrad] = ...
            adamupdate(dlnetTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
              
        
        % Every plotFequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(gather(extractdata(loss))))
            addpoints(lineLossContent,iteration,double(gather(extractdata(lossContent))))
            addpoints(lineLossStyle,iteration,double(gather(extractdata(lossStyle))))
            
            % Use the first image of the mini-batch as a validation image.
            dlV = dlX(:,:,:,1);
            % Use the transformed validation image computed previously.
            dlVY = dlY(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(dlV)));
            transformedValidationImage = uint8(gather(extractdata(dlVY)));
            
            % Plot the input image and the output image and increase size
            imshow(imtile({validationImage,transformedValidationImage}),'Parent',ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration +" of "+ numIterations + "(" + completionPercentage + "%)" +", Elapsed: " + string(D))
        drawnow
        
    end
end

Стилизация изображения

После завершения обучения вы можете использовать трансформатор изображения на любом изображении по своему выбору.

Загрузите изображение, которое требуется преобразовать.

imFilename = 'peppers.png';
im = imread(imFilename);

Измените размер входного изображения на входные размеры трансформатора изображения.

im = imresize(im,[256,256]);

Преобразовать в dlarray.

dlX = dlarray(single(im),'SSCB');

Использование графического процессора для преобразования в gpuArray если он доступен.

if canUseGPU
    dlX = gpuArray(dlX);
end

Чтобы применить стиль к изображению, передайте его преобразователю изображения с помощью функции predict.

dlY = predict(dlnetTransform,dlX);

Измените масштаб изображения в диапазоне [0 255]. Сначала используйте функцию tanh повторно измерять dlY в диапазон [-1 1]. Затем сдвиньте и масштабируйте выходной сигнал для масштабирования в диапазон [0 255].

Y = 255*(tanh(dlY)+1)/2;

Подготовиться Y для печати. Используйте функцию extraxtdata для извлечения данных из dlarray.Используйте функцию «Собрать» для переноса Y из графического процессора в локальное рабочее пространство.

Y = uint8(gather(extractdata(Y)));

Отображение входного изображения (слева) рядом со стилизованным изображением (справа).

figure
m = imtile({im,Y});
imshow(m)

Функция градиентов модели

Функция modelGradients принимает в качестве входных данных сеть с потерями dlnetLoss, сеть трансформатора изображений dlnetTransform, мини-пакет входных изображений dlX, массив, содержащий матрицы Gram изображения стиля dlSGram, вес, связанный с потерей содержимого contentWeight и вес, связанный с потерей стиля styleWeight. Он возвращает gradients потери по отношению к обучаемым параметрам трансформатора изображения, состоянию сети трансформатора изображения, преобразованным изображениям dlY, общий убыток loss, потери, связанные с содержимым lossContent и потери, связанные со стилем lossStyle.

function [gradients,state,dlY,loss,lossContent,lossStyle] = ...
    modelGradients(dlnetLoss,dlnetTransform,dlX,dlSGram,contentWeight,styleWeight)

[dlY,state] = forward(dlnetTransform,dlX);

dlY = 255*(tanh(dlY)+1)/2;

[loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX,dlSGram,contentWeight,styleWeight);

gradients = dlgradient(loss,dlnetTransform.Learnables);

end

Потеря переноса стиля

Функция styleTransferLoss принимает в качестве входных данных сеть с потерями dlnetLoss, мини-пакет входных изображений dlX, мини-пакет преобразованных изображений dlY, массив, содержащий матрицы Gram изображения стиля dlSGram, веса, связанные с содержимым и стилем contentWeight и styleWeight, соответственно. Он возвращает общий убыток loss и отдельные компоненты: потеря содержимого lossContent и потери стиля lossStyle.

Потеря содержимого является мерой того, насколько велика разница в пространственной структуре между входным изображением X и выходные изображения Y.

С другой стороны, потеря стиля говорит вам, насколько разница в стилистическом облике между образом стиля S и выходное изображение Y.

График ниже объясняет алгоритм, что styleTransferLoss реализует для расчета общего убытка.

Сначала функция передает входные изображения X, преобразованные изображения Y и изображение стиля S в предварительно обученный сетевой VGG-16. Эта предварительно обученная сеть извлекает несколько функций из этих образов. Затем алгоритм вычисляет потерю содержимого, используя пространственные особенности входного изображения X и выходного изображения Y. Кроме того, он вычисляет потерю стиля, используя стилистические особенности выходного изображения Y и изображения стиля S. Наконец, он получает полную потерю, добавляя содержимое и потери стиля.

Потеря содержимого

Для каждого изображения в мини-пакете функция потери содержимого сравнивает характеристики исходного изображения и преобразованного изображения, выводимого слоем. relu_3_3. В частности, вычисляется среднеквадратическая ошибка между включениями и возвращается средняя потеря для мини-партии:

lossContent=1N∑n=1Nmean ([start( Xn) -start( Yn)] 2),

где X содержит входные изображения, Y содержит преобразованные изображения, N - размер мини-партии, и relu_3_3.

Потеря стиля

Чтобы вычислить потерю стиля, для каждого отдельного изображения в мини-пакете:

  1. Извлеките активации на слоях relu1_2, relu2_2, relu3_3 и relu4_3.

  2. Для каждой из четырех активизаций/j вычисляют Gram-матрицу/G/.

  3. Вычислите квадрат разности между соответствующими матрицами Gram.

  4. Сложите четыре вывода для каждого уровня j из предыдущего шага.

Чтобы получить потерю стиля для всего мини-пакета, вычислите среднее значение потери стиля для каждого изображения n в мини-пакете:

lossStyle=1N∑n=1N∑j=14 [G (ϕj (Xn))-G (ϕj (S))] 2,

где j - индекс слоя, а G () - Gram Matrix.

Общий убыток

function [loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX, ...
    dlSGram,weightContent,weightStyle)

% Extract activations.
dlYActivations = cell(1,4);
[dlYActivations{1},dlYActivations{2},dlYActivations{3},dlYActivations{4}] = ...
    forward(dlnetLoss,dlY,'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

dlXActivations = forward(dlnetLoss,dlX,'Outputs','relu3_3');

% Calculate the mean square error between activations.
lossContent = mean((dlYActivations{3} - dlXActivations).^2,'all');

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(dlYActivations{j});
    lossStyle = lossStyle + sum((G - dlSGram{j}).^2,'all');
end

% Average the loss over the mini-batch.
miniBatchSize = size(dlX,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

Остаточный блок

residualBlock функция возвращает массив из шести слоев. Он состоит из слоев свертки, слоев нормализации экземпляра, слоя ReLu и слоя сложения. Обратите внимание, что groupNormalizationLayer('channel-wise') является просто уровнем нормализации экземпляра.

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same','Name', "convRes"+name+"_1")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_1")
    reluLayer('Name', "reluRes"+name+"_1")
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same', 'Name', "convRes"+name+"_2")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_2")
    additionLayer(2,'Name',"add"+name)];

end

Граммовая матрица

Функция createGramMatrix принимает в качестве входных данных активации одного слоя и возвращает стилистическое представление для каждого изображения в мини-пакете. Входной сигнал представляет собой характерную карту размера [H, W, C, N], где H - высота, W - ширина, C - количество каналов и N - размер мини-партии. Функция выводит массив G размера [C, C, N]. Каждый подчистокG(:,:,k) - матрица Gram, соответствующая k-му изображению в мини-партии. Каждая запись G (i, j, k) матрицы Gram представляет корреляцию между каналами ci и cj, поскольку каждая запись в канале ci умножает запись в соответствующей позиции в канале cj:

G (i, j, k) =1C×H×W∑h=1H∑w=1Wϕk (h, w, ci) (h, w, cj),

где «k» - активации для k-го изображения в мини-партии.

Матрица Gram содержит информацию о том, какие элементы активизируются вместе, но не содержит информации о том, где элементы появляются на изображении. Это происходит потому, что суммирование по высоте и ширине теряет информацию о пространственной структуре. Функция потерь использует эту матрицу в качестве стилистического представления изображения.

function G = createGramMatrix(activations)

[h,w,numChannels] = size(activations,1:3);

features = reshape(activations,h*w,numChannels,[]);
featuresT = permute(features,[2 1 3]);

G = dlmtimes(featuresT,features) / (h*w*numChannels);

end

Ссылки

  1. Джонсон, Джастин, Александр Алахи и Ли Фэй-Фэй. «Потери восприятия для передачи в режиме реального времени и сверхразрешения». Европейская конференция по компьютерному зрению. Спрингер, Чам, 2016.

См. также

| | | | | |

Связанные темы