exponenta event banner

Поезд Вассерштейн GAN с градиентным штрафом (WGAN-GP)

В этом примере показано, как обучить генеративную состязательную сеть Вассерштейна с градиентным штрафом (WGAN-GP) для создания изображений.

Генеративная состязательная сеть (GAN) - это тип сети глубокого обучения, которая может генерировать данные с теми же характеристиками, что и входные реальные данные.

GAN состоит из двух сетей, которые взаимодействуют друг с другом:

  1. Генератор - учитывая вектор случайных значений (скрытых входных данных) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - Данная сеть пытается классифицировать наблюдения как «реальные» или «сгенерированные».

Для обучения GAN одновременно обучайте обе сети, чтобы максимизировать производительность обеих:

  • Обучите генератор генерировать данные, которые «обманывают» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Для оптимизации производительности генератора максимизируйте потери дискриминатора при предоставлении сгенерированных данных. То есть целью генератора является генерирование данных, которые дискриминатор классифицирует как «реальные». Для оптимизации производительности дискриминатора минимизируйте потери дискриминатора при предоставлении партий как реальных, так и сгенерированных данных. То есть цель дискриминатора - не быть «обманутым» генератором.

В идеале эти стратегии приводят к созданию генератора, который генерирует убедительно реалистичные данные, и дискриминатора, который усвоил сильные представления признаков, характерные для обучающих данных. Однако [2] утверждает, что расхождения, которые GAN обычно минимизируют, потенциально не являются непрерывными по отношению к параметрам генератора, что приводит к трудностям обучения, и вводит модель Wasserstein GAN (WGAN), которая использует потери Wasserstein, чтобы помочь стабилизировать обучение. Модель WGAN может по-прежнему производить плохие выборки или не сходиться, поскольку взаимодействия между ограничением веса и функцией затрат могут привести к исчезновению или разборке градиентов. Для решения этих проблем [3] вводит штраф за градиент, который улучшает стабильность, наказывая градиенты большими значениями норм за счет более длительного времени вычислений. Этот тип модели известен как модель WGAN-GP.

В этом примере показано, как обучить модель WGAN-GP, которая может генерировать изображения со сходными характеристиками с обучающим набором изображений.

Загрузка данных обучения

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте хранилище данных изображения, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Дополните данные, чтобы включить случайное горизонтальное переворачивание и изменить размер изображений, чтобы они имели размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Определение сети дискриминаторов

Определите следующую сеть, которая классифицирует реальные и созданные 64 на 64 изображения.

Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет прогноза, используя серию слоев скручивания с пакетной нормализацией и прохудившихся слоев ReLU. Для вывода вероятностей в диапазоне [0,1] используется сигмоидальный слой.

  • Для слоев свертки укажите фильтры 5 на 5 с увеличением числа фильтров для каждого слоя. Также укажите шаг 2 и заполнение выходных данных.

  • Для протекающих слоев ReLU укажите масштаб 0,2.

  • Для конечного слоя свертки укажите один фильтр 4 на 4.

numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersD = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    layerNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    layerNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    layerNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')
    sigmoidLayer('Name','sigmoid')];

lgraphD = layerGraph(layersD);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetD = dlnetwork(lgraphD);

Определение сети генератора

Определите следующую сетевую архитектуру, которая производит изображения от множеств случайных ценностей 1 на 1 на 100:

Эта сеть:

  • Преобразовывает случайные векторы размера 100 ко множествам 7 на 7 на 128, используя проект, и измените слой.

  • Масштабирует результирующие массивы до массивов 64 на 3, используя ряд транспонированных слоев свертки и слоев ReLU.

Определите эту сетевую архитектуру как график уровня и укажите следующие свойства сети.

  • Для транспонированных слоев свертки укажите фильтры 5 на 5 с уменьшающимся числом фильтров для каждого слоя, шагом 2 и обрезкой выходных данных на каждом ребре.

  • Для конечного транспонированного слоя свертки укажите три фильтра 5 на 5, соответствующих трем каналам RGB генерируемых изображений, и размер вывода предыдущего слоя.

  • В конце сети включите уровень tanh.

Для проецирования и изменения формы входного шума используйте пользовательский слой projectAndReshapeLayer, прилагается к этому примеру как вспомогательный файл. projectAndReshape уровень увеличивает масштаб входа, используя полностью подключенную операцию, и изменяет форму выхода до заданного размера.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersG = [
    featureInputLayer(numLatentInputs,'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphG = layerGraph(layersG);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetG = dlnetwork(lgraphG);

Определение функций градиентов модели

Создание функций modelGradientsD и modelGradientsG перечисленные в разделе «Функция градиентов модели» примера, которые вычисляют градиенты дискриминатора и потери генератора относительно обучаемых параметров дискриминатора и сетей генераторов соответственно.

Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetG и dlnetD, мини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для штрафа градиента, и возвращает градиенты потерь относительно обучаемых параметров в дискриминаторе и потерь.

Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetG и dlnetDи массив случайных значений dlZи возвращает градиенты потерь относительно обучаемых параметров в генераторе и потерь.

Укажите параметры обучения

Чтобы обучить модель WGAN-GP, необходимо обучить дискриминатор большему количеству итераций, чем генератор. Другими словами, для каждой итерации генератора необходимо обучить дискриминатор нескольким итерациям.

Поезд с размером мини-партии 64 на 10 000 итераций генератора. Для больших наборов данных может потребоваться обучение для дополнительных итераций.

miniBatchSize = 64;
numIterationsG = 10000;

Для каждой итерации генератора выполните обучение дискриминатора для 5 итераций.

numIterationsDPerG = 5;

Для потери WGAN-GP укажите лямбда-значение 10. Лямбда-значение управляет величиной штрафа градиента, добавляемого к потере дискриминатора.

lambda = 10;

Укажите параметры оптимизации Adam:

  • Для дискриминаторной сети укажите скорость обучения 0,0002.

  • Для генераторной сети укажите коэффициент обучения 0,001.

  • Для обеих сетей задайте коэффициент градиентного затухания, равный 0, и коэффициент градиентного затухания в квадрате, равный 0,9.

learnRateD = 2e-4;
learnRateG = 1e-3;
gradientDecayFactor = 0;
squaredGradientDecayFactor = 0.9;

Отображение созданных изображений проверки каждые 20 итераций генератора.

validationFrequency = 20;

Модель поезда

Использовать minibatchqueue для обработки и управления мини-партиями изображений. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определено в конце этого примера) для масштабирования изображений в диапазоне [-1,1].

  • Удалите все частичные мини-партии.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Обучение на GPU, если он доступен. Когда 'OutputEnvironment' вариант minibatchqueue является 'auto', minibatchqueue преобразует каждый выход в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox) (Parallel Computing Toolbox).

minibatchqueue по умолчанию преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

Обучение модели с помощью пользовательского цикла обучения. Закольцовывать обучающие данные и обновлять сетевые параметры на каждой итерации. Чтобы контролировать ход обучения, отобразите пакет сгенерированных изображений, используя задержанный массив случайных значений для ввода в генератор, а также график оценок.

Инициализируйте параметры для Adam.

trailingAvgD = [];
trailingAvgSqD = [];
trailingAvgG = [];
trailingAvgSqG = [];

Для контроля хода обучения отображают пакет сгенерированных изображений, используя задержанный пакет фиксированных массивов случайных значений, подаваемых в генератор, и строят сетевую оценку.

Создайте массив задержанных случайных значений.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');

Преобразовать данные в dlarray объекты и указать метки размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

dlZValidation = dlarray(ZValidation,'CB');

Для обучения GPU преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики хода обучения. Создайте фигуру и измените ее размер, чтобы она имела удвоенную ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте вложенный график для созданных изображений и оценок сети.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика потерь.

C = colororder;
lineLossD = animatedline(scoreAxes,'Color',C(1,:));
lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:));
legend('With Gradient Penanlty','Unregularized')
xlabel("Generator Iteration")
ylabel("Discriminator Loss")
grid on

Обучение модели WGAN-GP путем закольцовывания мини-пакетов данных.

Для numIterationsDPerG итерации, обучение только дискриминатора. Для каждой мини-партии:

  • Оценка градиентов модели дискриминатора с помощью dlfeval и modelGradientsD функция.

  • Обновите параметры сети дискриминатора с помощью adamupdate функция.

После обучения дискриминатор для numIterationsDPerG итерации, обучение генератора на одной мини-партии.

  • Оценка градиентов модели генератора с помощью dlfeval и modelGradientsG функция.

  • Обновление параметров сети генератора с помощью adamupdate функция.

После обновления сети генератора:

  • Постройте график потерь двух сетей.

  • После каждого validationFrequency генераторные итерации, отображают пакет сгенерированных изображений для фиксированного задержанного входного сигнала генератора.

Пройдя через набор данных, перетасуйте очередь мини-пакета.

Обучение может занять некоторое время и может потребовать много итераций для вывода хороших изображений.

iterationG = 0;
iterationD = 0;
start = tic;

% Loop over mini-batches
while iterationG < numIterationsG
    iterationG = iterationG + 1;

    % Train discriminator only
    for n = 1:numIterationsDPerG
        iterationD = iterationD + 1;

        % Reset and shuffle mini-batch queue when there is no more data.
        if ~hasdata(mbq)
            shuffle(mbq);
        end

        % Read mini-batch of data.
        dlX = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
        dlZ = dlarray(Z,'CB');

        % Evaluate the discriminator model gradients using dlfeval and the
        % modelGradientsD function listed at the end of the example.
        [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda);

        % Update the discriminator network parameters.
        [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iterationD, ...
            learnRateD, gradientDecayFactor, squaredGradientDecayFactor);
    end

    % Generate latent inputs for the generator network. Convert to dlarray
    % and specify the dimension labels 'CB' (channel, batch).
    Z = randn([numLatentInputs size(dlX,4)],'like',dlX);
    dlZ = dlarray(Z,'CB');

    % Evaluate the generator model gradients using dlfeval and the
    % modelGradientsG function listed at the end of the example.
    gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ);

    % Update the generator network parameters.
    [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ...
        trailingAvgG, trailingAvgSqG, iterationG, ...
        learnRateG, gradientDecayFactor, squaredGradientDecayFactor);

    % Every validationFrequency generator iterations, display batch of
    % generated images using the held-out generator input
    if mod(iterationG,validationFrequency) == 0 || iterationG == 1
        % Generate images using the held-out generator input.
        dlXGeneratedValidation = predict(dlnetG,dlZValidation);

        % Tile and rescale the images in the range [0 1].
        I = imtile(extractdata(dlXGeneratedValidation));
        I = rescale(I);

        % Display the images.
        subplot(1,2,1);
        image(imageAxes,I)
        xticklabels([]);
        yticklabels([]);
        title("Generated Images");
    end

    % Update the scores plot
    subplot(1,2,2)

    lossD = double(gather(extractdata(lossD)));
    lossDUnregularized = double(gather(extractdata(lossDUnregularized)));
    addpoints(lineLossD,iterationG,lossD);
    addpoints(lineLossDUnregularized,iterationG,lossDUnregularized);

    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    title( ...
        "Iteration: " + iterationG + ", " + ...
        "Elapsed: " + string(D))
    drawnow
end

Здесь дискриминатор усвоил сильное характерное представление, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор усвоил такое же сильное представление признаков, которое позволяет ему генерировать изображения, подобные обучающим данным.

Создать новые изображения

Для создания новых изображений используйте predict функция на генераторе с dlarray объект, содержащий партию случайных векторов. Чтобы отобразить изображения вместе, используйте imtile и масштабировать изображения с помощью rescale функция.

Создать dlarray объект, содержащий партию из 25 случайных векторов для ввода в генераторную сеть.

ZNew = randn(numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'CB');

Чтобы создать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Создание новых изображений с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetG,dlZNew);

Отображение изображений.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели дискриминатора

Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetD, мини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для штрафа градиента, и возвращает градиенты потерь относительно обучаемых параметров в дискриминаторе и потерь.

Учитывая изображение X, сформированное изображение X∼, определите Xˆ=ϵX+ (1-ϵ) X∼ для некоторых случайных ϵ∈U (0,1).

Для модели WGAN-GP, учитывая лямбда-значение λ, потеря дискриминатора задается

lossD=Y∼-Y+λ (∇XˆYˆ‖2-1) 2,

где Y, Y∼ и обозначают выход дискриминатора для входов X, X∼ и соответственно, а ∇XˆYˆ обозначает градиенты выходного относительно . Для мини-партии данных используйте различное значение ϵ для каждой оберсации и рассчитайте среднюю потерю.

Градиентный штраф λ (∇XˆYˆ‖2-1) 2 улучшает стабильность, наказывая градиенты большими значениями нормы. Лямбда-значение управляет величиной штрафа градиента, добавляемого к потере дискриминатора.

function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetD, dlX);

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossDUnregularized = mean(dlYPredGenerated - dlYPred);

% Calculate and add the gradient penalty. 
epsilon = rand([1 1 1 size(dlX,4)],'like',dlX);
dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated;
dlYHat = forward(dlnetD, dlXHat);

% Calculate gradients. To enable computing higher-order derivatives, set
% 'EnableHigherDerivatives' to true.
gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true);
gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10);
gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2);

% Penalize loss.
lossD = lossDUnregularized + gradientPenalty;

% Calculate the gradients of the penalized loss with respect to the
% learnable parameters.
gradientsD = dlgradient(lossD, dlnetD.Learnables);

end

Функция градиентов модели генератора

Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetDи массив случайных значений dlZи возвращает градиенты потерь относительно обучаемых параметров в генераторе.

При наличии сформированного X∼ изображения потери для генераторной сети задаются

lossG=-Y∼,

где Y∼ обозначает выход дискриминатора для сгенерированного X∼ изображения. Для мини-партии сгенерированных изображений вычислите средний убыток.

function gradientsG =  modelGradientsG(dlnetG, dlnetD, dlZ)

% Calculate the predictions for generated data with the discriminator
% network.
dlXGenerated = forward(dlnetG,dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);

% Calculate the loss.
lossG = -mean(dlYPredGenerated);

% Calculate the gradients of the loss with respect to the learnable
% parameters.
gradientsG = dlgradient(lossG, dlnetG.Learnables);

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из массива входных ячеек и объедините их в числовой массив.

  2. Масштабировать изображения в диапазоне [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Арьовский, Мартин, Сомит Чинтала и Леон Ботту. «Wasserstein GAN». arXiv препринт arXiv:1701.07875 (2017).

  3. Гулраджани, Ишаан, Фарук Ахмед, Мартин Арджовский, Винсент Дюмулин и Аарон К. Курвилл. «Улучшение подготовки GAN Вассерштайна». В «Достижениях в системах обработки нейронной информации», стр. 5767-5777. 2017.

См. также

| | | | | |

Связанные темы