Обучите условную генеративную состязательную сеть (CGAN)

Этот пример показывает, как обучить условную генеративную состязательную сеть для генерации изображений.

Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные обучающие данные.

GAN состоит из двух сетей, которые обучаются вместе:

  1. Генератор - учитывая вектор случайных значений как вход, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - Учитывая пакеты данных, содержащие наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Условная генеративная состязательная сеть (CGAN) является типом GAN, который также использует преимущества меток в процессе обучения.

  1. Генератор - учитывая метку и случайный массив как вход, эта сеть генерирует данные с той же структурой, что и наблюдения обучающих данных, соответствующие той же метке.

  2. Дискриминатор - Учитывая пакеты маркированных данных, содержащих наблюдения как из обучающих данных, так и из сгенерированных данных генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Чтобы обучить условную GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:

  • Обучите генератор генерировать данные, которые «дурачат» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Чтобы максимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные маркированные данные. То есть цель генератора состоит в том, чтобы сгенерировать маркированные данные, которые дискриминатор классифицирует как "real".

Чтобы максимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных маркированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».

В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, которые соответствуют входным меткам, и дискриминатору, который выучил сильные представления функций, которые характерны для обучающих данных для каждой метки.

Загрузка обучающих данных

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте изображение datastore, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Просмотрите количество классов.

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Определите сеть генератора

Задайте следующую сеть с двумя входами, которая генерирует изображения, заданные случайными векторами размера 100 и соответствующими метками.

Эта сеть:

  • Преобразовывает случайные векторы размера 100 к 4 на 4 1 024 массивами.

  • Преобразует категориальные метки во внедряющие векторы и изменяет их форму в массив 4 на 4.

  • Конкатенирует получившиеся изображения из двух входов по размерности канала. Выход составляет 4 на 4 1 025 массивами.

  • Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания с нормализацией партии. и слоев ReLU.

Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.

  • Для категориальных входов используйте размерность вложения 50.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и "same" обрезка выхода.

  • Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений.

  • В конце сети включают слой танха.

Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. The projectAndReshapeLayer объект увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.

Чтобы ввести метки в сеть, используйте featureInputLayer и задайте одну функцию. Для встраивания и изменения формы входов меток используйте пользовательский слой embedAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. The embedAndReshapeLayer преобразует категориальную метку в одноканальное изображение заданного размера с помощью встраивания и полносвязной операции.

numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;

filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs,'Name','noise')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    concatenationLayer(3,2,'Name','cat');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1,'Name','labels')
    embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'Name','emb')];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'emb','cat/in2');

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = 
  dlnetwork with properties:

         Layers: [16×1 nnet.cnn.layer.Layer]
    Connections: [15×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'noise'  'labels'}
    OutputNames: {'tanh'}

Определите сеть дискриминатора

Задайте следующую сеть с двумя входами, которая классифицирует реальные и сгенерированные изображения 64 на 64 с учетом набора изображений и соответствующих меток.

Создайте сеть, которая берет в качестве входа изображения 64 на 64 на 1 и соответствующие метки и производит скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Добавьте шум к входу изображениям с помощью отсева.

  • Для выпадающего слоя задайте вероятность выпадения 0,75.

  • Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода на каждом ребре.

  • Для утечек слоев ReLU задайте шкалу 0,2.

  • Для последнего слоя задайте слой свертки с одним фильтром 4 на 4.

dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','images')
    dropoutLayer(dropoutProb,'Name','dropout')
    concatenationLayer(3,2,'Name','cat')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1,'Name','labels')
    embedAndReshapeLayer(inputSize(1:2),embeddingDimension,numClasses,'Name','emb')];

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'emb','cat/in2');

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = 
  dlnetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [16×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'images'  'labels'}
    OutputNames: {'conv5'}

Задайте градиенты модели и функции потерь

Создайте функцию modelGradients, перечисленный в разделе Model Gradients Function примера, который принимает за вход сети генератора и дискриминатора, мини-пакет входных данных и массив случайных значений и возвращает градиенты потерь относительно настраиваемых параметров в сетях и массива сгенерированных изображений.

Настройка опций обучения

Train с мини-партией размером 128 на 500 эпох.

numEpochs = 500;
miniBatchSize = 128;

Задайте опции для оптимизации Adam. Для обеих сетей используйте:

  • A скорости обучения 0,0002

  • Коэффициент градиентного распада 0,5

  • Квадратный коэффициент распада градиента 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обновляйте графики процесса обучения каждые 100 итераций.

validationFrequency = 100;

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не обучаться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом разверните метки части реальных изображений. Задайте коэффициент отражения 0,5.

flipFactor = 0.5;

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью удерживаемого массива случайных значений для ввода в генератор и сетевые счета.

Использование minibatchqueue обрабатывать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1].

  • Отбрасывайте любые частичные мини-пакеты с менее чем 128 наблюдениями.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Форматируйте данные метки с помощью меток размерности 'BC' (пакет, канал).

  • Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment' опция minibatchqueue является 'auto', minibatchqueue преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

The minibatchqueue объект по умолчанию преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessData,...
    'MiniBatchFormat',{'SSCB','BC'},...
    'OutputEnvironment',executionEnvironment);    

Инициализируйте параметры для оптимизатора Адама.

velocityDiscriminator = [];
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Инициализируйте график процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфики сгенерированных изображений и графика счетов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика счетов.

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);

Настройка внешнего вида графиков.

legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Чтобы контролировать процесс обучения, создайте задержанный пакет из 25 случайных векторов и соответствующий набор меток с 1 по 5 (соответствующий классам), повторенных пять раз.

numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,'single');

TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));

Преобразуйте данные в dlarray Объекты и задайте метки размерности 'CB' (канал, пакет).

dlZValidation = dlarray(ZValidation,'CB');
dlTValidation = dlarray(TValidation,'CB');

Для обучения графический процессор преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
    dlTValidation = gpuArray(dlTValidation);
end

Обучите условную GAN. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных.

Для каждого мини-пакета:

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • Постройте график счетов двух сетей.

  • После каждого validationFrequency итерации, отображают пакет сгенерированных изображений для фиксированного входа задержанного генератора.

Обучение может занять некоторое время.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,'single');
        dlZ = dlarray(Z,'CB');
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation,dlTValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation), ...
                'GridSize',[numValidationImagesPerClass numClasses]);
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot.
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным. Каждый столбец соответствует одному классу.

График обучения показывает счета сетей генератора и дискриминатора. Дополнительные сведения о том, как интерпретировать счета сети, см. в разделах Мониторинг процесса обучения GAN и Идентификация распространенных типах отказа.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения конкретного класса, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов и массив меток, соответствующих желаемым классам. Преобразуйте данные в dlarray Объекты и задайте метки размерности 'CB' (канал, пакет). Для предсказания графический процессор преобразуйте данные в gpuArray объекты. Чтобы отобразить изображения вместе, используйте imtile функциональность и изменение масштаба изображений с помощью rescale функция.

Создайте массив из 36 векторов случайных значений, соответствующих первому классу.

numObservationsNew = 36;
idxClass = 1;
Z = randn(numLatentInputs,numObservationsNew,'single');
T = repmat(single(idxClass),[1 numObservationsNew]);

Преобразуйте данные в dlarray объекты с метками размерностей 'SSCB' (пространственный, пространственный, каналы, пакет).

dlZ = dlarray(Z,'CB');
dlT = dlarray(T,'CB');

Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZ = gpuArray(dlZ);
    dlT = gpuArray(dlT);
end

Сгенерируйте изображения с помощью predict функция с сетью генератора.

dlXGenerated = predict(dlnetGenerator,dlZ,dlT);

Отобразите сгенерированные изображения на графике.

figure
I = imtile(extractdata(dlXGenerated));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))

Здесь сеть генератора генерирует изображения, обусловленные заданным классом.

Функция градиентов модели

Функция modelGradients принимает за вход генератор и дискриминатор dlnetwork объекты dlnetGenerator и dlnetDiscriminatorмини-пакет входных данных dlX, соответствующие метки dlTи массив случайных значений dlZ, и возвращает градиенты потерь относительно настраиваемых параметров в сетях, состоянии генератора и счетах сети.

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не обучаться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом разверните метки части реальных изображений.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX, dlT);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator, dlZ, dlT);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated, dlT);

% Calculate probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the generator and discriminator scores.
scoreGenerator = mean(probGenerated);
scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels.
numObservations = size(dlYPred,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal, probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь GAN

Цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как "real". Чтобы максимизировать вероятность того, что изображения с генератора классифицируются дискриминатором как вещественные, минимизируйте отрицательную функцию журнала правдоподобие.

Учитывая выход Y дискриминатора:

  • Yˆ=σ(Y) - вероятность того, что вход изображение принадлежит классу "real".

  • 1-Yˆ - вероятность того, что вход изображение принадлежит классу "generated".

Обратите внимание на операцию сигмоида σ происходит в modelGradients функция. Функция потерь для генератора задается как

lossGenerator=-mean(log(YˆGenerated)),

где YˆGenerated содержит выходные вероятности дискриминатора для сгенерированных изображений.

Цель дискриминатора - не быть «обманутым» генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает действительное и сгенерированное изображения, минимизируйте сумму соответствующих отрицательных функций журнала функций правдоподобия. Функция потерь для дискриминатора определяется

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

где YˆReal содержит выходные вероятности дискриминатора для действительных изображений.

function [lossGenerator, lossDiscriminator] = ganLoss(scoresReal,scoresGenerated)

% Calculate losses for the discriminator network.
lossGenerated = -mean(log(1 - scoresGenerated));
lossReal = -mean(log(scoresReal));

% Combine the losses for the discriminator network.
lossDiscriminator = lossReal + lossGenerated;

% Calculate the loss for the generator network.
lossGenerator = -mean(log(scoresGenerated));

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките изображение и пометьте данные из входных массивов ячеек и соедините их в числовые массивы.

  2. Перерассчитайте изображения, которые будут находиться в области значений [-1,1].

function [X,T] = preprocessData(XCell,TCell)

% Extract image data from cell and concatenate
X = cat(4,XCell{:});

% Extract label data from cell and concatenate
T = cat(1,TCell{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

См. также

| | | | | |

Похожие темы