Обучите генеративную состязательную сеть (GAN)

Этот пример показывает, как обучить генеративную состязательную сеть для генерации изображений.

Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные реальные данные.

GAN состоит из двух сетей, которые обучаются вместе:

  1. Генератор - Учитывая вектор случайных значений (латентные входы) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - Учитывая пакеты данных, содержащие наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:

  • Обучите генератор генерировать данные, которые «дурачат» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Чтобы оптимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные данные. То есть цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как "real".

Чтобы оптимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».

В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, и дискриминатору, который научился сильным представлениям функций, которые характерны для обучающих данных.

Загрузка обучающих данных

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте изображение datastore, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);    

Определите сеть генератора

Задайте следующую сетевую архитектуру, которая генерирует изображения из случайных векторов размера 100.

Эта сеть:

  • Преобразовывает случайные векторы размера 100 к массивам 7 на 7 на 128, используя проект, и измените слой.

  • Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания с нормализацией партии. и слоев ReLU.

Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и обрезкой выхода на каждом ребре.

  • Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений, и выходной размер предыдущего слоя.

  • В конце сети включают слой танха.

Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. The projectAndReshapeLayer объект увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs,'Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bnorm1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bnorm2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bnorm3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator);

Определите сеть дискриминатора

Задайте следующую сеть, которая классифицирует реальные и сгенерированные изображения 64 на 64.

Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Добавьте шум к входу изображениям с помощью отсева.

  • Для выпадающего слоя задайте вероятность выпадения 0,5.

  • Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода.

  • Для утечек слоев ReLU задайте шкалу 0,2.

  • Для последнего слоя задайте сверточный слой с одним фильтром 4 на 4.

Чтобы вывести вероятности в области значений [0,1], используйте sigmoid функция в modelGradients функция, перечисленная в разделе Model Gradients Function примера.

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    dropoutLayer(dropoutProb,'Name','dropout')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator);

Задайте градиенты модели и функции потерь

Создайте функцию modelGradients, перечисленный в разделе Model Gradients Function примера, который принимает за вход сети генератора и дискриминатора, мини-пакет входных данных, массив случайных значений и коэффициент отражения и возвращает градиенты потерь относительно настраиваемых параметров в сетях и счетов двух сетей.

Настройка опций обучения

Train с мини-партией размером 128 на 500 эпох. Для больших наборов данных вам, возможно, не нужно обучаться на столько эпох.

numEpochs = 500;
miniBatchSize = 128;

Задайте опции для оптимизации Adam. Для обеих сетей укажите:

  • A скорости обучения 0,0002

  • Коэффициент градиентного распада 0,5

  • Квадратный коэффициент распада градиента 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не обучаться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, добавьте шум к реальным данным, случайным образом развернув метки.

Задайте flipFactor значение 0,3, чтобы развернуть 30% действительных меток (15% от общего количества меток). Обратите внимание, что это не ухудшает работу генератора, так как все сгенерированные изображения все еще помечены правильно.

flipFactor = 0.3;

Отображать сгенерированные изображения валидации каждые 100 итераций.

validationFrequency = 100;

Обучите модель

Использование minibatchqueue для обработки и управления мини-пакетами изображений. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1].

  • Отбрасывайте любые частичные мини-пакеты с менее чем 128 наблюдениями.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single.

  • Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment' опция minibatchqueue является "auto", minibatchqueue преобразует каждый выход в gpuArray при наличии графический процессор. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ® с вычислительными возможностями 3.0 или выше.

augimds.MiniBatchSize = miniBatchSize;

executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений, используя удерживаемый массив случайных значений для ввода в генератор, а также график счетов.

Инициализируйте параметры для Adam.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью задержанного пакета фиксированных случайных векторов, подаваемых в генератор, и постройте графики счетов сети.

Создайте массив удерживаемых случайных значений.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');

Преобразуйте данные в dlarray Объекты и задайте метки размерности 'CB' (канал, пакет).

dlZValidation = dlarray(ZValidation,'CB');

Для обучения графический процессор преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфик для сгенерированных изображений и сетевых счетов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика счетов.

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Обучите GAN. Для каждой эпохи перетащите datastore и закольцовывайте мини-пакеты данных.

Для каждого мини-пакета:

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • Постройте график счетов двух сетей.

  • После каждого validationFrequency итерации, отображают пакет сгенерированных изображений для фиксированного входа задержанного генератора.

Обучение может занять некоторое время.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle datastore.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        dlX = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,'single');
        dlZ = dlarray(Z,'CB');        
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot.
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным.

График обучения показывает счета сетей генератора и дискриминатора. Дополнительные сведения о том, как интерпретировать счета сети, см. в разделах Мониторинг процесса обучения GAN и Идентификация распространенных типах отказа.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile функциональность и изменение масштаба изображений с помощью rescale функция.

Создайте dlarray объект, содержащий пакет из 25 случайных векторов для ввода в сеть генератора.

numObservations = 25;
ZNew = randn(numLatentInputs,numObservations,'single');
dlZNew = dlarray(ZNew,'CB');

Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

Отобразите изображения.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели

Функция modelGradients принимает за вход генератор и дискриминатор dlnetwork объекты dlnetGenerator и dlnetDiscriminatorмини-пакет входных данных dlX, массив случайных значений dlZ, и процент реальных меток, чтобы развернуть flipFactorи возвращает градиенты потерь относительно настраиваемых параметров в сетях, состоянии генератора и счетах двух сетей. Поскольку выход дискриминатора не находится в области значений [0,1], modelGradients функция применяет сигмоидную функцию, чтобы преобразовать ее в вероятности.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Convert the discriminator outputs to probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the score of the discriminator.
scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2;

% Calculate the score of the generator.
scoreGenerator = mean(probGenerated);

% Randomly flip a fraction of the labels of the real images.
numObservations = size(probReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Flip the labels.
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь GAN и счетов

Цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как "real". Чтобы максимизировать вероятность того, что изображения с генератора классифицируются дискриминатором как вещественные, минимизируйте отрицательную функцию журнала правдоподобие.

Учитывая выход Y дискриминатора:

  • Yˆ=σ(Y) - вероятность того, что вход изображение принадлежит классу "real".

  • 1-Yˆ - вероятность того, что вход изображение принадлежит классу "generated".

Обратите внимание, что сигмоидная операция σ происходит в modelGradients функция. Функция потерь для генератора задается как

lossGenerator=-mean(log(YˆGenerated)),

где YˆGenerated содержит выходные вероятности дискриминатора для сгенерированных изображений.

Цель дискриминатора - не быть «обманутым» генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает действительное и сгенерированное изображения, минимизируйте сумму соответствующих отрицательных функций журнала функций правдоподобия.

Функция потерь для дискриминатора определяется

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

где YˆReal содержит выходные вероятности дискриминатора для действительных изображений.

Чтобы измерить по шкале от 0 до 1, насколько хорошо генератор и дискриминатор достигают своих соответствующих целей, можно использовать концепцию счета.

Счет генератора является средним значением вероятностей, соответствующих выходу дискриминатора для сгенерированных изображений:

scoreGenerator=mean(YˆGenerated).

Счет дискриминатора является средним значением вероятностей, соответствующих выходу дискриминатора для как вещественных, так и сгенерированных изображений:

scoreDiscriminator=12mean(YˆReal)+12mean(1-YˆGenerated).

Счет обратно пропорциональна потере, но фактически содержит ту же информацию.

function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)

% Calculate the loss for the discriminator network.
lossDiscriminator = -mean(log(probReal)) - mean(log(1-probGenerated));

% Calculate the loss for the generator network.
lossGenerator = -mean(log(probGenerated));

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и соедините в числовой массив.

  2. Перерассчитайте изображения, которые будут находиться в области значений [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Радфорд, Алек, Люк Мец и Сомитх Чинтала. «Неконтролируемое обучение представлению с глубокими сверточными генеративными состязательными сетями». Препринт, представленный 19 ноября 2015 года. http://arxiv.org/abs/1511.06434.

См. также

| | | | | | |

Похожие темы