Обучите GAN Вассерштейна с градиентным штрафом (WGAN-GP)

Этот пример показывает, как обучить генеративную состязательную сеть Вассерштейна с градиентным штрафом (WGAN-GP) для генерации изображений.

Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные реальные данные.

GAN состоит из двух сетей, которые обучаются вместе:

  1. Генератор - Учитывая вектор случайных значений (латентные входы) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - Учитывая пакеты данных, содержащих наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как «реальные» или «сгенерированные».

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:

  • Обучите генератор генерировать данные, которые «дурачат» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Чтобы оптимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные данные. То есть цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как «действительные». Чтобы оптимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».

В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, и дискриминатору, который научился сильным представлениям функций, которые характерны для обучающих данных. Однако [2] утверждает, что расхождения, которые GAN обычно минимизируют, потенциально не непрерывны относительно параметров генератора, что приводит к сложности обучения, и представляет модель Wasserstein GAN (WGAN), которая использует потерю Вассерштейна, чтобы помочь стабилизировать обучение. Модель WGAN может все еще производить плохие выборки или не сходиться, потому что взаимодействия между ограничением веса и функцией стоимости могут привести к исчезновению или взрыванию градиентов. Для решения этих проблем [3] вводит градиентный штраф, который улучшает стабильность путем наказания градиентов с большими нормальными значениями за счет более длительного вычислительного времени. Этот тип модели известен как модель WGAN-GP.

В этом примере показано, как обучить модель WGAN-GP, которая может генерировать изображения с аналогичными характеристиками для набора обучающих данных изображений.

Загрузка обучающих данных

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте изображение datastore, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Определите сеть дискриминатора

Задайте следующую сеть, которая классифицирует реальные и сгенерированные изображения 64 на 64.

Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Для вывода вероятностей в области значений [0,1] используйте сигмоидный слой.

  • Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода.

  • Для утечек слоев ReLU задайте шкалу 0,2.

  • Для последнего слоя свертки задайте один фильтр 4 на 4.

numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersD = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    layerNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    layerNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    layerNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')
    sigmoidLayer('Name','sigmoid')];

lgraphD = layerGraph(layersD);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetD = dlnetwork(lgraphD);

Определите сеть генератора

Определите следующую сетевую архитектуру, которая производит изображения от массивов случайных значений 1 на 1 на 100:

Эта сеть:

  • Преобразовывает случайные векторы размера 100 к массивам 7 на 7 на 128, используя проект, и измените слой.

  • Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания и слоев ReLU.

Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и обрезкой выхода на каждом ребре.

  • Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений, и выходной размер предыдущего слоя.

  • В конце сети включают слой танха.

Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. The projectAndReshape слой увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersG = [
    featureInputLayer(numLatentInputs,'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphG = layerGraph(layersG);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetG = dlnetwork(lgraphG);

Задайте функции градиентов модели

Создайте функции modelGradientsD и modelGradientsG перечисленные в разделе Model Gradients Function примера, которые вычисляют градиенты потерь дискриминатора и генератора относительно настраиваемых параметров сетей дискриминатора и генератора, соответственно.

Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetG и dlnetDмини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.

Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetG и dlnetDи массив случайных значений dlZ, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе и потерь.

Настройка опций обучения

Чтобы обучить модель WGAN-GP, вы должны обучить дискриминатор для большего количества итераций, чем генератор. Другими словами, для каждой итерации генератора необходимо обучить дискриминатор для нескольких итераций.

Обучите с мини-пакетом размером 64 для 10 000 итераций генератора. Для больших наборов данных вам, возможно, потребуется обучить для дополнительных итераций.

miniBatchSize = 64;
numIterationsG = 10000;

Для каждой итерации генератора обучите дискриминатор для 5 итераций.

numIterationsDPerG = 5;

Для потерь WGAN-GP укажите значение лямбды 10. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.

lambda = 10;

Задайте опции для оптимизации Адама:

  • Для сети дискриминатора задайте скорость обучения 0,0002.

  • Для сети генератора задайте скорость обучения 0,001.

  • Для обеих сетей задайте коэффициент градиентного распада 0 и квадратный коэффициент градиентного распада 0,9.

learnRateD = 2e-4;
learnRateG = 1e-3;
gradientDecayFactor = 0;
squaredGradientDecayFactor = 0.9;

Отображать сгенерированные изображения валидации каждые 20 итераций генератора.

validationFrequency = 20;

Обучите модель

Использование minibatchqueue для обработки и управления мини-пакетами изображений. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1].

  • Отменить любые частичные мини-пакеты.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment' опция minibatchqueue является 'auto', minibatchqueue преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox) (Parallel Computing Toolbox).

The minibatchqueue объект по умолчанию преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений, используя удерживаемый массив случайных значений для ввода в генератор, а также график счетов.

Инициализируйте параметры для Adam.

trailingAvgD = [];
trailingAvgSqD = [];
trailingAvgG = [];
trailingAvgSqG = [];

Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью задержанного пакета фиксированных массивов случайных значений, подаваемых в генератор, и постройте графики счетов сети.

Создайте массив удерживаемых случайных значений.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');

Преобразуйте данные в dlarray Объекты и задайте метки размерности 'SSCB' (пространственный, пространственный, канальный, пакетный).

dlZValidation = dlarray(ZValidation,'CB');

Для обучения графический процессор преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфик для сгенерированных изображений и сетевых счетов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика потерь.

C = colororder;
lineLossD = animatedline(scoreAxes,'Color',C(1,:));
lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:));
legend('With Gradient Penanlty','Unregularized')
xlabel("Generator Iteration")
ylabel("Discriminator Loss")
grid on

Обучите модель WGAN-GP путем закольцовывания по мини-пакетам данных.

Для numIterationsDPerG итерации, обучите только дискриминатор. Для каждого мини-пакета:

  • Оцените градиенты модели дискриминатора с помощью dlfeval и modelGradientsD функция.

  • Обновите параметры сети дискриминатора, используя adamupdate функция.

После обучения дискриминатора на numIterationsDPerG итерации, обучите генератор на одном мини-пакете.

  • Оцените градиенты модели генератора с помощью dlfeval и modelGradientsG функция.

  • Обновите параметры сети генератора с помощью adamupdate функция.

После обновления сети генератора:

  • Постройте график потерь двух сетей.

  • После каждого validationFrequency итерации генератора, отображают пакет сгенерированных изображений для фиксированного входа задерживаемого генератора.

Пройдя через набор данных, перетащите мини-очередь пакетов.

Обучение может занять некоторое время, и может потребовать многих итераций для вывода хороших изображений.

iterationG = 0;
iterationD = 0;
start = tic;

% Loop over mini-batches
while iterationG < numIterationsG
    iterationG = iterationG + 1;

    % Train discriminator only
    for n = 1:numIterationsDPerG
        iterationD = iterationD + 1;

        % Reset and shuffle mini-batch queue when there is no more data.
        if ~hasdata(mbq)
            shuffle(mbq);
        end

        % Read mini-batch of data.
        dlX = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
        dlZ = dlarray(Z,'CB');

        % Evaluate the discriminator model gradients using dlfeval and the
        % modelGradientsD function listed at the end of the example.
        [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda);

        % Update the discriminator network parameters.
        [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iterationD, ...
            learnRateD, gradientDecayFactor, squaredGradientDecayFactor);
    end

    % Generate latent inputs for the generator network. Convert to dlarray
    % and specify the dimension labels 'CB' (channel, batch).
    Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
    dlZ = dlarray(Z,'CB');

    % Evaluate the generator model gradients using dlfeval and the
    % modelGradientsG function listed at the end of the example.
    gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ);

    % Update the generator network parameters.
    [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ...
        trailingAvgG, trailingAvgSqG, iterationG, ...
        learnRateG, gradientDecayFactor, squaredGradientDecayFactor);

    % Every validationFrequency generator iterations, display batch of
    % generated images using the held-out generator input
    if mod(iterationG,validationFrequency) == 0 || iterationG == 1
        % Generate images using the held-out generator input.
        dlXGeneratedValidation = predict(dlnetG,dlZValidation);

        % Tile and rescale the images in the range [0 1].
        I = imtile(extractdata(dlXGeneratedValidation));
        I = rescale(I);

        % Display the images.
        subplot(1,2,1);
        image(imageAxes,I)
        xticklabels([]);
        yticklabels([]);
        title("Generated Images");
    end

    % Update the scores plot
    subplot(1,2,2)

    lossD = double(gather(extractdata(lossD)));
    lossDUnregularized = double(gather(extractdata(lossDUnregularized)));
    addpoints(lineLossD,iterationG,lossD);
    addpoints(lineLossDUnregularized,iterationG,lossDUnregularized);

    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    title( ...
        "Iteration: " + iterationG + ", " + ...
        "Elapsed: " + string(D))
    drawnow
end

Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile функциональность и изменение масштаба изображений с помощью rescale функция.

Создайте dlarray объект, содержащий пакет из 25 случайных векторов для ввода в сеть генератора.

ZNew = randn(numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'CB');

Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetG,dlZNew);

Отобразите изображения.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели дискриминатора

Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetDмини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.

Заданное изображение X, сгенерированное изображение X, задайте Xˆ=ϵX+(1-ϵ)X для некоторых случайных ϵU(0,1).

Для модели WGAN-GP, учитывая значение лямбды λ, потеря дискриминатора определяется

lossD=Y-Y+λ(XˆYˆ2-1)2,

где Y, Y, и Yˆ обозначить выход дискриминатора для входов X, X, и Xˆ, соответственно, и XˆYˆ обозначает градиенты выхода Yˆ по отношению к Xˆ. Для мини-пакета данных используйте другое значение ϵ для каждого оберсвата и вычислить среднюю потерю.

Градиентный штраф λ(XˆYˆ2-1)2 повышает стабильность за счет наказания градиентов с большими нормальными значениями. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.

function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetD, dlX);

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossDUnregularized = mean(dlYPredGenerated - dlYPred);

% Calculate and add the gradient penalty. 
epsilon = rand([1 1 1 size(dlX,4)],'like',dlX);
dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated;
dlYHat = forward(dlnetD, dlXHat);

% Calculate gradients. To enable computing higher-order derivatives, set
% 'EnableHigherDerivatives' to true.
gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true);
gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10);
gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2);

% Penalize loss.
lossD = lossDUnregularized + gradientPenalty;

% Calculate the gradients of the penalized loss with respect to the
% learnable parameters.
gradientsD = dlgradient(lossD, dlnetD.Learnables);

end

Функция градиентов модели генератора

Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetDи массив случайных значений dlZ, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе.

Учитывая сгенерированное изображение X, потери для сети генератора даются

lossG=-Y,

где Y обозначает выход дискриминатора для сгенерированного изображения X. Для мини-пакета сгенерированных изображений вычислите среднюю потерю.

function gradientsG =  modelGradientsG(dlnetG, dlnetD, dlZ)

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossG = -mean(dlYPredGenerated);

% Calculate the gradients of the loss with respect to the learnable
% parameters.
gradientsG = dlgradient(lossG, dlnetG.Learnables);

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из входного массива ячеек и соедините в числовой массив.

  2. Перерассчитайте изображения, которые будут находиться в области значений [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Арьовский, Мартин, Сомит Чинтала и Леон Ботту. «Wasserstein GAN». arXiv preprint arXiv:1701.07875 (2017).

  3. Гюльраджани, Ишаан, Фарук Ахмед, Мартин Арьовский, Винсент Дюмулен и Аарон К. Курвиль. «Улучшенная подготовка Wasserstein GANs». В усовершенствованиях в системах нейронной обработки информации, стр. 5767-5777. 2017.

См. также

| | | | | |

Похожие темы