Обучите Порождающую соперничающую сеть (GAN)

В этом примере показано, как обучить порождающую соперничающую сеть (GAN) генерировать изображения.

Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как вход действительные данные.

ГАНЬ состоит из двух сетей, которые обучаются вместе:

  1. Генератор —, Учитывая вектор случайных значений (скрытые входные параметры), как введено, эта сеть генерирует данные с той же структурой как обучающие данные.

  2. Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать производительность обоих:

  • Обучите генератор генерировать данные, которые "дурачат" различитель.

  • Обучите различитель различать действительные и сгенерированные данные.

Чтобы оптимизировать производительность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные".

Чтобы оптимизировать производительность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.

Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.

Загрузите обучающие данные

Загрузите и извлеките Цветочный набор данных [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте datastore изображений, содержащий фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение и изменить размер изображений, чтобы иметь размер 64 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Задайте сеть генератора

Задайте следующую сетевую архитектуру, которая генерирует изображения от 1 1 100 массивами случайных значений:

Эта сеть:

  • Преобразует 1 1 100 массивами шума к 7 7 128 массивами с помощью проекта, и измените слой.

  • Увеличивает масштаб полученные массивы к 64 64 3 массивами с помощью серии транспонированных слоев свертки со слоев ReLU и нормализацией партии.

Задайте эту сетевую архитектуру как график слоев и задайте следующие сетевые свойства.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, шага 2 и обрезки выхода на каждом ребре.

  • Поскольку финал транспонировал слой свертки, задайте три фильтра 5 на 5, соответствующие трем каналам RGB сгенерированных изображений и выходному размеру предыдущего слоя.

  • В конце сети включайте tanh слой.

Чтобы спроектировать и изменить шумовой вход, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. projectAndReshapeLayer слой увеличивает масштаб вход с помощью полностью связанной операции и изменяет выход к заданному размеру.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bnorm1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bnorm2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bnorm3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator);

Задайте сеть различителя

Задайте следующую сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.

Создайте сеть, которая берет 64 64 3 изображениями и возвращает скалярный счет предсказания с помощью серии слоев свертки с нормализацией партии. и текучих слоев ReLU. Добавьте шум во входные изображения с помощью уволенного.

  • Для слоя уволенного задайте вероятность уволенного 0,5.

  • Для слоев свертки задайте фильтры 5 на 5 с растущим числом фильтров для каждого слоя. Также задайте шаг 2 и дополнение выхода.

  • Для текучих слоев ReLU задайте шкалу 0,2.

  • Для последнего слоя задайте сверточный слой с одним фильтром 4 на 4.

Чтобы вывести вероятности в области значений [0,1], используйте sigmoid функция в Градиентах Модели Function.

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    dropoutLayer(0.5,'Name','dropout')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator);

Градиенты модели Define, функции потерь и баллы

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет в качестве входного генератора и сетей различителя, мини-пакета входных данных, массива случайных значений и зеркально отраженного фактора, и возвращает градиенты потери относительно настраиваемых параметров в сетях и множестве этих двух сетей.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 128 в течение 500 эпох. Также установите размер чтения увеличенного datastore изображений к мини-пакетному размеру. Для больших наборов данных вы не можете должны быть обучаться для как много эпох.

numEpochs = 500;
miniBatchSize = 128;
augimds.MiniBatchSize = miniBatchSize;

Задайте опции для оптимизации Адама. Для обеих сетей задать

  • Скорость обучения 0,0002

  • Фактор затухания градиента 0,5

  • Градиент в квадрате затухает фактор 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

executionEnvironment = "auto";

IIf, который различитель учится отличать между действительными и сгенерированными изображениями слишком быстро, затем генератор, может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, добавьте шум в действительные данные путем случайного зеркального отражения меток.

Задайте, чтобы инвертировать 30% действительных меток. Это означает, что 15% общего количества меток будут инвертированы. Обратите внимание на то, что это не повреждает генератор, когда все сгенерированные изображения все еще помечены правильно.

flipFactor = 0.3;

Отобразитесь сгенерированная валидация отображает каждые 100 итераций.

validationFrequency = 100;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор, а также график баллов.

Инициализируйте параметры для Адама.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого пакета фиксированных массивов случайных значений, поданных в генератор, и постройте сетевые баллы.

Создайте массив протянутых случайных значений.

numValidationImages = 25;
ZValidation = randn(1,1,numLatentInputs,numValidationImages,'single');

Преобразуйте данные в dlarray объекты и указывают, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

dlZValidation = dlarray(ZValidation,'SSCB');

Для обучения графического процессора преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики процесса обучения. Создайте фигуру и измените размер его, чтобы иметь дважды ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфик для сгенерированных изображений и newtork баллов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика баллов.

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Обучите GAN. В течение каждой эпохи переставьте datastore и цикл по мини-пакетам данных.

Для каждого мини-пакета:

  • Перемасштабируйте изображения в области значений [-1 1].

  • Преобразуйте данные в dlarray объекты с базовым типом single и укажите, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Сгенерируйте dlarray объект, содержащий массив случайных значений для сети генератора.

  • Для обучения графического процессора преобразуйте данные в gpuArray объекты.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Постройте множество этих двух сетей.

  • После каждого validationFrequency итерации, отобразитесь, пакет сгенерированных изображений для фиксированного протянул вход генератора.

Обучение может занять время, чтобы запуститься.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Concatenate mini-batch of data and generate latent inputs for the
        % generator network.
        X = cat(4,data{:,1}{:});
        X = single(X);
        Z = randn(1,1,numLatentInputs,size(X,4),'single');
        
        % Rescale the images in the range [-1 1].
        X = rescale(X,-1,1,'InputMin',0,'InputMax',255);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        dlZ = dlarray(Z, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений. В свою очередь генератор изучил представление столь же сильной черты, которое позволяет ему генерировать реалистически выглядящие данные.

Учебный график показывает множество сетей различителя и генератора. Чтобы узнать больше, как интерпретировать сетевые баллы, смотрите Монитор Процесс обучения GAN и Идентифицируйте Общие Типы отказа.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет 1 1 100 массивами случайных значений. Чтобы отобразить изображения вместе, используйте imtile функционируйте и перемасштабируйте изображения с помощью rescale функция.

Создайте dlarray объект, содержащий пакет 25 1 1 100 массивами случайных значений, чтобы ввести в сеть генератора.

ZNew = randn(1,1,numLatentInputs,25,'single');
dlZNew = dlarray(ZNew,'SSCB');

Чтобы сгенерировать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

Отобразите изображения.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели

Функциональный modelGradients берет в качестве входа генератор и различитель dlnetwork объекты dlnetGenerator и dlnetDiscriminator, мини-пакет входных данных dlX, массив случайных значений dlZ и процент действительных меток к flip flipFactor, и возвращает градиенты потери относительно настраиваемых параметров в сетях, состоянии генератора и множестве этих двух сетей. Поскольку различитель выход не находится в области значений [0,1], modelGradients применяет сигмоидальную функцию, чтобы преобразовать его в вероятности.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Convert the discriminator outputs to probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the score of the discriminator.
scoreDiscriminator = ((mean(probReal)+mean(1-probGenerated))/2);

% Calculate the score of the generator.
scoreGenerator = mean(probGenerated);

% Randomly flip a fraction of the labels of the real images.
numObservations = size(probReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Flip the labels
probReal(:,:,:,idx) = 1-probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь ГАНЯ и баллы

Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия.

Учитывая выход Y из различителя:

  • Yˆ=σ(Y) вероятность, что входное изображение принадлежит классу "действительный".

  • 1-Yˆ вероятность, что входное изображение принадлежит классу "сгенерированный".

Обратите внимание на то, что сигмоидальная операция σ происходит в modelGradients функция. Функцией потерь для генератора дают

lossGenerator=-mean(log(YˆGenerated)),

где YˆGenerated содержит различитель выходные вероятности для сгенерированных изображений.

Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия.

Функцией потерь для различителя дают

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

где YˆReal содержит различитель выходные вероятности для действительных изображений.

Измеряться по шкале от 0 до 1, как хорошо генератор и различитель достигают их соответствующих целей, можно использовать концепцию счета.

Счет генератора является средним значением вероятностей, соответствующих различителю выход для сгенерированных изображений:

scoreGenerator=mean(YˆGenerated).

Счет различителя является средним значением вероятностей, соответствующих различителю выход для обоих действительные и сгенерированные изображения:

scoreDiscriminator=12mean(YˆReal)+12mean(1-YˆGenerated).

Счет обратно пропорционален потере, но эффективно содержит ту же информацию.

function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)

% Calculate the loss for the discriminator network.
lossDiscriminator =  -mean(log(probReal)) -mean(log(1-probGenerated));

% Calculate the loss for the generator network.
lossGenerator = -mean(log(probGenerated));

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Рэдфорд, Алек, Люк Мец и Сумит Чинтэла. "Безнадзорное представление, учащееся с глубокими сверточными порождающими соперничающими сетями". arXiv предварительно распечатывают arXiv:1511.06434 (2015).

Смотрите также

| | | | | |

Похожие темы