Обучите Вассерштейна GAN со штрафом градиента (WGAN-GP)

В этом примере показано, как обучить Вассерштейна порождающая соперничающая сеть со штрафом градиента (WGAN-GP) генерировать изображения.

Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как вход действительные данные.

ГАНЬ состоит из двух сетей, которые обучаются вместе:

  1. Генератор —, Учитывая вектор из случайных значений (скрытые входные параметры), как введено, эта сеть генерирует данные с той же структурой как обучающие данные.

  2. Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обоих:

  • Обучите генератор генерировать данные, которые "дурачат" различитель.

  • Обучите различитель различать действительные и сгенерированные данные.

Чтобы оптимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы оптимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.

Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных. Однако [2] утверждает, что расхождения, которые обычно минимизируют GANs, потенциально не непрерывны относительно параметров генератора, ведя к учебной трудности, и вводит модель Wasserstein GAN (WGAN), которая использует утрату Вассерштейна, чтобы помочь стабилизировать обучение. Модель WGAN может все еще произвести плохие выборки или может не сходиться, потому что взаимодействия между ограничением веса и функцией стоимости могут привести к исчезновению или взрыву градиентов. Решать эти проблемы, [3] вводит штраф градиента, который улучшает устойчивость путем наложения штрафа на градиенты с большими значениями нормы за счет более длительного вычислительного времени. Этот тип модели известен как модель WGAN-GP.

В этом примере показано, как обучить модель WGAN-GP, которая может сгенерировать изображения с подобными характеристиками к набору обучающих данных изображений.

Загрузите обучающие данные

Загрузите и извлеките Цветочный набор данных [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте datastore изображений, содержащий фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение и изменить размер изображений, чтобы иметь размер 64 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Задайте сеть различителя

Задайте следующую сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.

Создайте сеть, которая берет 64 64 3 изображениями и возвращает скалярный счет предсказания с помощью серии слоев свертки с нормализацией партии. и текучих слоев ReLU. Чтобы вывести вероятности в области значений [0,1], используйте сигмоидальный слой.

  • Для слоев свертки задайте фильтры 5 на 5 с растущим числом фильтров для каждого слоя. Также задайте шаг 2 и дополнение выхода.

  • Для текучих слоев ReLU задайте шкалу 0,2.

  • Для итогового слоя свертки задайте один фильтр 4 на 4.

numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersD = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    layerNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    layerNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    layerNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')
    sigmoidLayer('Name','sigmoid')];

lgraphD = layerGraph(layersD);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetD = dlnetwork(lgraphD);

Задайте сеть генератора

Задайте следующую сетевую архитектуру, которая генерирует изображения от 1 1 100 массивами случайных значений:

Эта сеть:

  • Преобразует случайные векторы из размера от 100 до 7 7 128 массивами с помощью проекта, и измените слой.

  • Увеличивает масштаб полученные массивы к 64 64 3 массивами с помощью серии транспонированных слоев свертки и слоев ReLU.

Задайте эту сетевую архитектуру как график слоев и задайте следующие сетевые свойства.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, шага 2 и обрезки выхода на каждом ребре.

  • Поскольку финал транспонировал слой свертки, задайте три фильтра 5 на 5, соответствующие трем каналам RGB сгенерированных изображений и выходному размеру предыдущего слоя.

  • В конце сети включайте tanh слой.

Чтобы спроектировать и изменить шумовой вход, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. projectAndReshape слой увеличивает масштаб вход с помощью полностью связанной операции и изменяет выход к заданному размеру.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersG = [
    featureInputLayer(numLatentInputs,'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphG = layerGraph(layersG);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetG = dlnetwork(lgraphG);

Функции градиентов модели Define

Создайте функции modelGradientsD и modelGradientsG перечисленный в разделе Model Gradients Function примера, которые вычисляют градиенты различителя и потери генератора относительно настраиваемых параметров различителя и сетей генератора, соответственно.

Функциональный modelGradientsD берет в качестве входа генератор и различитель dlnetG и dlnetD, мини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для штрафа градиента, и, возвращает градиенты потери относительно настраиваемых параметров в различителе и потери.

Функциональный modelGradientsG берет в качестве входа генератор и различитель dlnetG и dlnetD, и массив случайных значений dlZ, и возвращает градиенты потери относительно настраиваемых параметров в генераторе и потери.

Задайте опции обучения

Чтобы обучить модель WGAN-GP, необходимо обучить различитель большему количеству итераций, чем генератор. Другими словами, для каждой итерации генератора, необходимо обучить различитель нескольким итерациям.

Обучайтесь с мини-пакетным размером 64 для 10 000 итераций генератора. Для больших наборов данных вы можете должны быть обучаться для большего количества итераций.

miniBatchSize = 64;
numIterationsG = 10000;

Для каждой итерации генератора обучите различитель 5 итерациям.

numIterationsDPerG = 5;

За потерю WGAN-GP задайте значение lambda 10. Значение lambda управляет величиной штрафа градиента, добавленного к потере различителя.

lambda = 10;

Задайте опции для оптимизации Адама:

  • Для сети различителя задайте скорость обучения 0,0002.

  • Для сети генератора задайте скорость обучения 0,001.

  • Для обеих сетей задайте фактор затухания градиента 0 и фактор затухания градиента в квадрате 0,9.

learnRateD = 2e-4;
learnRateG = 1e-3;
gradientDecayFactor = 0;
squaredGradientDecayFactor = 0.9;

Отобразитесь сгенерированная валидация отображает каждые 20 итераций генератора.

validationFrequency = 20;

Обучите модель

Используйте minibatchqueue обработать и управлять мини-пакетами изображений. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы перемасштабировать изображения в области значений [-1,1].

  • Отбросьте любые частичные мини-пакеты.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Обучайтесь на графическом процессоре, если вы доступны. Когда 'OutputEnvironment' опция minibatchqueue 'auto', minibatchqueue преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) (Parallel Computing Toolbox).

minibatchqueue объект, по умолчанию, преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор, а также график баллов.

Инициализируйте параметры для Адама.

trailingAvgD = [];
trailingAvgSqD = [];
trailingAvgG = [];
trailingAvgSqG = [];

Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого пакета фиксированных массивов случайных значений, поданных в генератор, и постройте сетевые баллы.

Создайте массив протянутых случайных значений.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');

Преобразуйте данные в dlarray объекты и указывают, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

dlZValidation = dlarray(ZValidation,'CB');

Для обучения графического процессора преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики процесса обучения. Создайте фигуру и измените размер его, чтобы иметь дважды ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфик для сгенерированных изображений и сетевых баллов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика потерь.

C = colororder;
lineLossD = animatedline(scoreAxes,'Color',C(1,:));
lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:));
legend('With Gradient Penalty','Unregularized')
xlabel("Generator Iteration")
ylabel("Discriminator Loss")
grid on

Обучите модель WGAN-GP цикличным выполнением по мини-пакетам данных.

Для numIterationsDPerG итерации, обучите различитель только. Для каждого мини-пакета:

  • Оцените градиенты модели различителя с помощью dlfeval и modelGradientsD функция.

  • Обновите параметры сети различителя с помощью adamupdate функция.

После обучения различитель для numIterationsDPerG итерации, обучите генератор на одном мини-пакете.

  • Оцените градиенты модели генератора с помощью dlfeval и modelGradientsG функция.

  • Обновите параметры сети генератора с помощью adamupdate функция.

После обновления сети генератора:

  • Постройте потери этих двух сетей.

  • После каждого validationFrequency итерации генератора, отображение пакет сгенерированных изображений для фиксированного протянул вход генератора.

После прохождения через набор данных переставьте мини-пакетную очередь.

Обучение может занять время, чтобы запуститься и может потребовать, чтобы много итераций вывели хорошие изображения.

iterationG = 0;
iterationD = 0;
start = tic;

% Loop over mini-batches
while iterationG < numIterationsG
    iterationG = iterationG + 1;

    % Train discriminator only
    for n = 1:numIterationsDPerG
        iterationD = iterationD + 1;

        % Reset and shuffle mini-batch queue when there is no more data.
        if ~hasdata(mbq)
            shuffle(mbq);
        end

        % Read mini-batch of data.
        dlX = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
        dlZ = dlarray(Z,'CB');

        % Evaluate the discriminator model gradients using dlfeval and the
        % modelGradientsD function listed at the end of the example.
        [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda);

        % Update the discriminator network parameters.
        [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iterationD, ...
            learnRateD, gradientDecayFactor, squaredGradientDecayFactor);
    end

    % Generate latent inputs for the generator network. Convert to dlarray
    % and specify the dimension labels 'CB' (channel, batch).
    Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
    dlZ = dlarray(Z,'CB');

    % Evaluate the generator model gradients using dlfeval and the
    % modelGradientsG function listed at the end of the example.
    gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ);

    % Update the generator network parameters.
    [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ...
        trailingAvgG, trailingAvgSqG, iterationG, ...
        learnRateG, gradientDecayFactor, squaredGradientDecayFactor);

    % Every validationFrequency generator iterations, display batch of
    % generated images using the held-out generator input
    if mod(iterationG,validationFrequency) == 0 || iterationG == 1
        % Generate images using the held-out generator input.
        dlXGeneratedValidation = predict(dlnetG,dlZValidation);

        % Tile and rescale the images in the range [0 1].
        I = imtile(extractdata(dlXGeneratedValidation));
        I = rescale(I);

        % Display the images.
        subplot(1,2,1);
        image(imageAxes,I)
        xticklabels([]);
        yticklabels([]);
        title("Generated Images");
    end

    % Update the scores plot
    subplot(1,2,2)

    lossD = double(gather(extractdata(lossD)));
    lossDUnregularized = double(gather(extractdata(lossDUnregularized)));
    addpoints(lineLossD,iterationG,lossD);
    addpoints(lineLossDUnregularized,iterationG,lossDUnregularized);

    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    title( ...
        "Iteration: " + iterationG + ", " + ...
        "Elapsed: " + string(D))
    drawnow
end

Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений. В свою очередь генератор изучил представление столь же сильной черты, которое позволяет ему генерировать изображения, похожие на обучающие данные.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile функционируйте и перемасштабируйте изображения с помощью rescale функция.

Создайте dlarray объект, содержащий пакет 25 случайных векторов, чтобы ввести к сети генератора.

ZNew = randn(numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'CB');

Чтобы сгенерировать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetG,dlZNew);

Отобразите изображения.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели различителя

Функциональный modelGradientsD берет в качестве входа генератор и различитель dlnetwork объекты dlnetG и dlnetD, мини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для штрафа градиента, и, возвращает градиенты потери относительно настраиваемых параметров в различителе и потери.

Учитывая изображение X, сгенерированное изображение XDefine Xˆ=ϵX+(1-ϵ)X для некоторых случайных ϵU(0,1).

Для модели WGAN-GP, учитывая значение lambda λ, потерей различителя дают

lossD=Y-Y+λ(XˆYˆ2-1)2,

где Y, Y, и Yˆ обозначьте выход различителя для входных параметров X, X, и Xˆ, соответственно, и XˆYˆ обозначает градиенты выхода Yˆ относительно Xˆ. Для мини-пакета данных используйте различное значение ϵ для каждого obersvation и вычисляют среднюю потерю.

Штраф градиента λ(XˆYˆ2-1)2 улучшает устойчивость путем наложения штрафа на градиенты с большими значениями нормы. Значение lambda управляет величиной штрафа градиента, добавленного к потере различителя.

function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetD, dlX);

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossDUnregularized = mean(dlYPredGenerated - dlYPred);

% Calculate and add the gradient penalty. 
epsilon = rand([1 1 1 size(dlX,4)],'like',dlX);
dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated;
dlYHat = forward(dlnetD, dlXHat);

% Calculate gradients. To enable computing higher-order derivatives, set
% 'EnableHigherDerivatives' to true.
gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true);
gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10);
gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2);

% Penalize loss.
lossD = lossDUnregularized + gradientPenalty;

% Calculate the gradients of the penalized loss with respect to the
% learnable parameters.
gradientsD = dlgradient(lossD, dlnetD.Learnables);

end

Функция градиентов модели генератора

Функциональный modelGradientsG берет в качестве входа генератор и различитель dlnetwork объекты dlnetG и dlnetD, и массив случайных значений dlZ, и возвращает градиенты потери относительно настраиваемых параметров в генераторе.

Учитывая сгенерированное изображение X, потерей для сети генератора дают

lossG=-Y,

где Y обозначает выход различителя для сгенерированного изображения X. Для мини-пакета сгенерированных изображений вычислите среднюю потерю.

function gradientsG =  modelGradientsG(dlnetG, dlnetD, dlZ)

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossG = -mean(dlYPredGenerated);

% Calculate the gradients of the loss with respect to the learnable
% parameters.
gradientsG = dlgradient(lossG, dlnetG.Learnables);

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из входного массива ячеек и конкатенируйте в числовой массив.

  2. Перемасштабируйте изображения, чтобы быть в области значений [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Arjovsky, Мартин, Сумит Чинтэла и Леон Ботту. "Вассерштейн GAN". arXiv предварительно распечатывают arXiv:1701.07875 (2017).

  3. Gulrajani, Ishaan, Фэрук Ахмед, Мартин Арйовский, Винсент Думулин и Аарон К. Коервилл. "Улучшенное обучение Вассерштейна GANs". В Усовершенствованиях в нейронных системах обработки информации, стр 5767-5777. 2017.

Смотрите также

| | | | | |

Похожие темы