Обучите Условную порождающую соперничающую сеть (CGAN)

В этом примере показано, как обучить условную порождающую соперничающую сеть, чтобы сгенерировать изображения.

Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как входные обучающие данные.

ГАНЬ состоит из двух сетей, которые обучаются вместе:

  1. Генератор —, Учитывая вектор из случайных значений, как введено, эта сеть генерирует данные с той же структурой как обучающие данные.

  2. Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Условная порождающая соперничающая сеть (CGAN) является типом GAN, который также использует в своих интересах метки во время учебного процесса.

  1. Генератор —, Учитывая метку и случайный массив, как введено, эта сеть генерирует данные с той же структурой как наблюдения обучающих данных, соответствующие той же метке.

  2. Различитель — Данный пакеты маркированных данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Чтобы обучить условный GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обоих:

  • Обучите генератор генерировать данные, которые "дурачат" различитель.

  • Обучите различитель различать действительные и сгенерированные данные.

Чтобы максимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные маркированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать маркированные данные, которые различитель классифицирует как "real".

Чтобы максимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных маркированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.

Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные, которые соответствуют входным меткам и различителю, который изучил представления сильной черты, которые являются характеристическими для обучающих данных для каждой метки.

Загрузите обучающие данные

Загрузите и извлеките Цветочный набор данных [1].

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте datastore изображений, содержащий фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");

Просмотрите количество классов.

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение и изменить размер изображений, чтобы иметь размер 64 64.

augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);

Задайте сеть генератора

Задайте следующую 2D входную сеть, которая генерирует изображения, данные случайные векторы из размера 100 и соответствующие метки.

Эта сеть:

  • Преобразует случайные векторы из размера от 100 до 4 4 1 024 массивами с помощью полносвязного слоя, сопровождаемого изменять операцией.

  • Преобразует категориальные метки во встраивание векторов и изменяет их к массиву 4 на 4.

  • Конкатенации получившихся изображений от двух входных параметров по измерению канала. Выход является 4 4 1 025 массивами.

  • Увеличивает масштаб полученные массивы к 64 64 3 массивами с помощью серии транспонированных слоев свертки со слоев ReLU и нормализацией партии.

Задайте эту сетевую архитектуру как график слоев и задайте следующие сетевые свойства.

  • Для категориальных входных параметров используйте размерность встраивания 50.

  • Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, шага 2, и "same" обрезка выхода.

  • Поскольку финал транспонировал слой свертки, задайте три фильтра 5 на 5, соответствующие трем каналам RGB сгенерированных изображений.

  • В конце сети включайте tanh слой.

Чтобы спроектировать и изменить шумовой вход, используйте полносвязный слой, сопровождаемый изменять операцией specifed как функциональный слой с функцией, данной feature2image функция, присоединенная к этому примеру как вспомогательный файл. Чтобы встроить категориальные метки, используйте пользовательский слой embeddingLayer присоединенный к этому примеру как вспомогательный файл. Чтобы получить доступ к этим вспомогательным файлам, откройте пример как live скрипт.

numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;

filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs)
    fullyConnectedLayer(prod(projectionSize))
    functionLayer(@(X) feature2image(X,projectionSize),Formattable=true)
    concatenationLayer(3,2,Name="cat");
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1)
    embeddingLayer(embeddingDimension,numClasses)
    fullyConnectedLayer(prod(projectionSize(1:2)))
    functionLayer(@(X) feature2image(X,[projectionSize(1:2) 1]),Formattable=true,Name="emb_reshape")];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'input'  'input_1'}
    OutputNames: {'layer_2'}
    Initialized: 1

Задайте сеть различителя

Задайте следующую 2D входную сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает, учитывая набор изображений и соответствующих меток.

Создайте сеть, которая берет в качестве входа 64 64 1 изображением и соответствующими метками и выводит скалярный счет предсказания с помощью серии слоев свертки с нормализацией партии. и текучих слоев ReLU. Добавьте шум во входные изображения с помощью уволенного.

  • Для слоя уволенного задайте вероятность уволенного 0,75.

  • Для слоев свертки задайте фильтры 5 на 5 с растущим числом фильтров для каждого слоя. Также задайте шаг 2 и дополнение выхода на каждом ребре.

  • Для текучих слоев ReLU задайте шкалу 0,2.

  • Для последнего слоя задайте слой свертки с одним фильтром 4 на 4.

dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    concatenationLayer(3,2,Name="cat")
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1)
    embeddingLayer(embeddingDimension,numClasses)
    fullyConnectedLayer(prod(inputSize(1:2)))
    functionLayer(@(X) feature2image(X,[inputSize(1:2) 1]),Formattable=true,Name="emb_reshape")];

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,"emb_reshape","cat/in2");

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'  'input'}
    OutputNames: {'conv_5'}
    Initialized: 1

Градиенты модели Define и функции потерь

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function примера, который берет в качестве входа генератор и сети различителя, мини-пакет входных данных и массив случайных значений, и возвращает градиенты потери относительно настраиваемых параметров в сетях и массиве сгенерированных изображений.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 128 в течение 500 эпох.

numEpochs = 500;
miniBatchSize = 128;

Задайте опции для оптимизации Адама. Для обеих сетей используйте:

  • Скорость обучения 0,0002

  • Фактор затухания градиента 0,5

  • Градиент в квадрате затухает фактор 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обновитесь процесс обучения строит каждые 100 итераций.

validationFrequency = 100;

Если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, случайным образом инвертируйте метки пропорции действительных изображений. Задайте зеркально отраженный фактор 0,5.

flipFactor = 0.5;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор и сетевые баллы.

Используйте minibatchqueue обработать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы перемасштабировать изображения в области значений [-1,1].

  • Отбросьте любые частичные мини-пакеты меньше чем с 128 наблюдениями.

  • Формат данные изображения с размерностью маркирует "SSCB" (пространственный, пространственный, канал, пакет).

  • Формат данные о метке с размерностью маркирует "BC" (пакет, канал).

  • Обучайтесь на графическом процессоре, если вы доступны. Когда OutputEnvironment опция minibatchqueue "auto", minibatchqueue преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

minibatchqueue объект, по умолчанию, преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessData, ...
    MiniBatchFormat=["SSCB" "BC"], ...
    OutputEnvironment=executionEnvironment);    

Инициализируйте параметры для оптимизатора Адама.

velocityDiscriminator = [];
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Инициализируйте график процесса обучения. Создайте фигуру и измените размер его, чтобы иметь дважды ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте подграфики сгенерированных изображений и графика баллов.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика баллов.

lineScoreGenerator = animatedline(scoreAxes,Color=[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes,Color=[0.85 0.325 0.098]);

Настройте внешний вид графиков.

legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Чтобы контролировать процесс обучения, создайте протянутый пакет 25 случайных векторов, и соответствующий набор маркирует 1 through 5 (соответствующий классам), повторился пять раз.

numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,"single");

TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));

Преобразуйте данные в dlarray объекты и указывают, что размерность маркирует "CB" (образуйте канал, пакет).

dlZValidation = dlarray(ZValidation,"CB");
dlTValidation = dlarray(TValidation,"CB");

Для обучения графического процессора преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
    dlTValidation = gpuArray(dlTValidation);
end

Обучите условный GAN. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных.

Для каждого мини-пакета:

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Постройте множество этих двух сетей.

  • После каждого validationFrequency итерации, отобразитесь, пакет сгенерированных изображений для фиксированного протянул вход генератора.

Обучение может занять время, чтобы запуститься.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels "CB" (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        dlZ = dlarray(Z,"CB");
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation,dlTValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation), ...
                GridSize=[numValidationImagesPerClass numClasses]);
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot.
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений. В свою очередь генератор изучил представление столь же сильной черты, которое позволяет ему генерировать изображения, похожие на обучающие данные. Каждый столбец соответствует единому классу.

Учебный график показывает множество сетей различителя и генератора. Чтобы узнать больше, как интерпретировать сетевые баллы, смотрите Монитор Процесс обучения GAN и Идентифицируйте Общие Типы отказа.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения конкретного класса, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов и массив меток, соответствующих желаемым классам. Преобразуйте данные в dlarray объекты и указывают, что размерность маркирует "CB" (образуйте канал, пакет). Для предсказания графического процессора преобразуйте данные в gpuArray объекты. Чтобы отобразить изображения вместе, используйте imtile функционируйте и перемасштабируйте изображения с помощью rescale функция.

Создайте массив 36 векторов из случайных значений, соответствующих первому классу.

numObservationsNew = 36;
idxClass = 1;
Z = randn(numLatentInputs,numObservationsNew,"single");
T = repmat(single(idxClass),[1 numObservationsNew]);

Преобразуйте данные в dlarray объекты с размерностью маркируют "SSCB" (пространственный, пространственный, каналы, пакет).

dlZ = dlarray(Z,"CB");
dlT = dlarray(T,"CB");

Чтобы сгенерировать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZ = gpuArray(dlZ);
    dlT = gpuArray(dlT);
end

Сгенерируйте изображения с помощью predict функция с сетью генератора.

dlXGenerated = predict(dlnetGenerator,dlZ,dlT);

Отобразите сгенерированные изображения в графике.

figure
I = imtile(extractdata(dlXGenerated));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))

Здесь, сеть генератора генерирует изображения, обусловленные на заданном классе.

Функция градиентов модели

Функциональный modelGradients берет в качестве входа генератор и различитель dlnetwork объекты dlnetGenerator и dlnetDiscriminator, мини-пакет входных данных dlX, соответствие маркирует dlT, и массив случайных значений dlZ, и возвращает градиенты потери относительно настраиваемых параметров в сетях, состоянии генератора и сетевых баллах.

Если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, случайным образом инвертируйте метки пропорции действительных изображений.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX, dlT);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator, dlZ, dlT);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated, dlT);

% Calculate probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the generator and discriminator scores.
scoreGenerator = mean(probGenerated);
scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels.
numObservations = size(dlYPred,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal, probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,RetainData=true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь ГАНЯ

Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "real". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия.

Учитывая выход Y из различителя:

  • Yˆ=σ(Y) вероятность, что входное изображение принадлежит классу "real".

  • 1-Yˆ вероятность, что входное изображение принадлежит классу "generated".

Отметьте сигмоидальную операцию σ происходит в modelGradients функция. Функцией потерь для генератора дают

lossGenerator=-mean(log(YˆGenerated)),

где YˆGenerated содержит различитель выходные вероятности для сгенерированных изображений.

Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия. Функцией потерь для различителя дают

lossDiscriminator=-mean(log(YˆReal))-mean(log(1-YˆGenerated)),

где YˆReal содержит различитель выходные вероятности для действительных изображений.

function [lossGenerator, lossDiscriminator] = ganLoss(scoresReal,scoresGenerated)

% Calculate losses for the discriminator network.
lossGenerated = -mean(log(1 - scoresGenerated));
lossReal = -mean(log(scoresReal));

% Combine the losses for the discriminator network.
lossDiscriminator = lossReal + lossGenerated;

% Calculate the loss for the generator network.
lossGenerator = -mean(log(scoresGenerated));

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките изображение и пометьте данные из входных массивов ячеек и конкатенируйте их в числовые массивы.

  2. Перемасштабируйте изображения, чтобы быть в области значений [-1,1].

function [X,T] = preprocessData(XCell,TCell)

% Extract image data from cell and concatenate
X = cat(4,XCell{:});

% Extract label data from cell and concatenate
T = cat(1,TCell{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);

end

Ссылки

Смотрите также

| | | | | |

Похожие темы