exponenta event banner

Условная состязательная сеть поездов (CGAN)

В этом примере показано, как обучить условную порождающую состязательную сеть генерировать изображения.

Генеративная состязательная сеть (GAN) - это тип сети глубокого обучения, которая может генерировать данные с теми же характеристиками, что и входные обучающие данные.

GAN состоит из двух сетей, которые взаимодействуют друг с другом:

  1. Генератор - учитывая вектор случайных значений в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - данные, содержащие наблюдения как из обучающих данных, так и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Условная генеративная состязательная сеть (CGAN) - это тип GAN, который также использует преимущества меток в процессе обучения.

  1. Генератор - учитывая метку и случайный массив в качестве входных данных, эта сеть генерирует данные с той же структурой, что и наблюдения обучающих данных, соответствующие той же метке.

  2. Дискриминатор - учитывая партии помеченных данных, содержащих наблюдения как из обучающих данных, так и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Для обучения условной GAN одновременно обучайте обе сети, чтобы максимизировать производительность обеих:

  • Обучите генератор генерировать данные, которые «обманывают» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Чтобы максимизировать производительность генератора, максимизируйте потери дискриминатора при получении сгенерированных помеченных данных. То есть целью генератора является генерирование помеченных данных, которые дискриминатор классифицирует как "real".

Чтобы максимизировать производительность дискриминатора, минимизируйте потери дискриминатора при предоставлении партий как реальных, так и сгенерированных помеченных данных. То есть цель дискриминатора - не быть «обманутым» генератором.

В идеале эти стратегии приводят к созданию генератора, который генерирует убедительно реалистичные данные, соответствующие входным меткам, и дискриминатора, который усвоил сильные представления признаков, характерные для обучающих данных для каждой метки.

Загрузка данных обучения

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте хранилище данных изображения, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Просмотр количества классов.

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

Дополните данные, чтобы включить случайное горизонтальное переворачивание и изменить размер изображений, чтобы они имели размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Определение сети генератора

Определите следующую сеть с двумя входами, которая генерирует изображения, заданные случайными векторами размера 100 и соответствующими метками.

Эта сеть:

  • Преобразовывает случайные векторы размера 100 к 4 на 4 1 024 множествами.

  • Преобразует категориальные метки во встраиваемые векторы и изменяет их форму в массив 4 на 4.

  • Конкатенация полученных изображений с двумя входами по размеру канала. Продукция составляет 4 на 4 1 025 множествами.

  • Масштабирует результирующие массивы до массивов 64 на 3, используя ряд транспонированных слоев свертки с пакетной нормализацией и уровнями ReLU.

Определите эту сетевую архитектуру как график уровня и укажите следующие свойства сети.

  • Для категориальных входных данных используйте размер вложения 50.

  • Для транспонированных слоев свертки укажите фильтры 5 на 5 с уменьшающимся числом фильтров для каждого слоя, шагом 2 и "same" обрезка выходных данных.

  • Для конечного транспонированного слоя свертки укажите три фильтра 5 на 5, соответствующих трем каналам RGB генерируемых изображений.

  • В конце сети включите уровень tanh.

Для проецирования и изменения формы входного шума используйте пользовательский слой projectAndReshapeLayer, прилагается к этому примеру как вспомогательный файл. projectAndReshapeLayer объект увеличивает масштаб входа, используя полностью подключенную операцию, и изменяет форму выхода на указанный размер.

Для ввода меток в сеть используйте featureInputLayer и укажите один элемент. Для встраивания и изменения формы вводимой метки используйте пользовательский слой embedAndReshapeLayer, прилагается к этому примеру как вспомогательный файл. embedAndReshapeLayer объект преобразует категориальную метку в одноканальное изображение заданного размера с помощью встраивания и полностью связанной операции.

numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;

filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs,'Name','noise')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    concatenationLayer(3,2,'Name','cat');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1,'Name','labels')
    embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'Name','emb')];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'emb','cat/in2');

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = 
  dlnetwork with properties:

         Layers: [16×1 nnet.cnn.layer.Layer]
    Connections: [15×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'noise'  'labels'}
    OutputNames: {'tanh'}

Определение сети дискриминаторов

Определите следующую сеть с двумя входами, которая классифицирует реальные и сгенерированные изображения 64 на 64 при наличии набора изображений и соответствующих меток.

Создайте сеть, которая берет в качестве входа изображения 64 на 64 на 1 и соответствующие этикетки и производит скалярный счет прогноза, используя серию слоев скручивания с пакетной нормализацией и прохудившихся слоев ReLU. Добавление шума к входным изображениям с помощью отсева.

  • Для уровня отсева укажите вероятность отсева 0,75.

  • Для слоев свертки укажите фильтры 5 на 5 с увеличением числа фильтров для каждого слоя. Также укажите шаг 2 и заполнение выходных данных на каждом ребре.

  • Для протекающих слоев ReLU укажите масштаб 0,2.

  • Для конечного слоя укажите слой свертки с одним фильтром 4 на 4.

dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','images')
    dropoutLayer(dropoutProb,'Name','dropout')
    concatenationLayer(3,2,'Name','cat')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1,'Name','labels')
    embedAndReshapeLayer(inputSize(1:2),embeddingDimension,numClasses,'Name','emb')];

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'emb','cat/in2');

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = 
  dlnetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [16×2 table]
     Learnables: [19×3 table]
          State: [6×3 table]
     InputNames: {'images'  'labels'}
    OutputNames: {'conv5'}

Определение градиентов модели и функций потерь

Создание функции modelGradients, перечисленных в разделе «Функция градиентов модели» примера, который принимает в качестве входных данных генераторные и дискриминаторные сети, мини-пакет входных данных и массив случайных значений, и возвращает градиенты потерь относительно обучаемых параметров в сетях и массив сгенерированных изображений.

Укажите параметры обучения

Поезд с размером мини-партии 128 на 500 эпох.

numEpochs = 500;
miniBatchSize = 128;

Укажите параметры оптимизации Adam. Для обеих сетей используйте:

  • Коэффициент обучения 0,0002

  • Коэффициент градиентного спада 0,5

  • Квадрат градиентного коэффициента распада 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обновите графики хода обучения каждые 100 итераций.

validationFrequency = 100;

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не тренироваться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом переверните метки части реальных изображений. Задайте коэффициент отражения 0,5.

flipFactor = 0.5;

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. Закольцовывать обучающие данные и обновлять сетевые параметры на каждой итерации. Чтобы контролировать ход обучения, отобразите пакет сгенерированных изображений, используя задержанный массив случайных значений для ввода в генератор и оценки сети.

Использовать minibatchqueue для обработки и управления мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определено в конце этого примера) для масштабирования изображений в диапазоне [-1,1].

  • Удалите все частичные мини-партии с менее чем 128 наблюдениями.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Форматирование данных метки с метками размеров 'BC' (партия, канал).

  • Обучение на GPU, если он доступен. Когда 'OutputEnvironment' вариант minibatchqueue является 'auto', minibatchqueue преобразует каждый выход в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

minibatchqueue по умолчанию преобразует данные в dlarray объекты с базовым типом single.

augimds.MiniBatchSize = miniBatchSize;
executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessData,...
    'MiniBatchFormat',{'SSCB','BC'},...
    'OutputEnvironment',executionEnvironment);    

Инициализируйте параметры оптимизатора Adam.

velocityDiscriminator = [];
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Инициализируйте график хода обучения. Создайте фигуру и измените ее размер, чтобы она имела удвоенную ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте вложенные графики созданных изображений и графика оценок.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализация анимированных линий для графика показателей.

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);

Настройка внешнего вида графиков.

legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Чтобы контролировать ход обучения, создайте задержанную партию из 25 случайных векторов и соответствующего набора меток 1-5 (соответствующих классам), повторенных пять раз.

numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,'single');

TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));

Преобразовать данные в dlarray объекты и указать метки размеров 'CB' (канал, партия).

dlZValidation = dlarray(ZValidation,'CB');
dlTValidation = dlarray(TValidation,'CB');

Для обучения GPU преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
    dlTValidation = gpuArray(dlTValidation);
end

Обучение условного GAN. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных.

Для каждой мини-партии:

  • Оценка градиентов модели с помощью dlfeval и modelGradients функция.

  • Обновление параметров сети с помощью adamupdate функция.

  • Постройте графики двух сетей.

  • После каждого validationFrequency итерации, отображение пакета сгенерированных изображений для фиксированного ввода генератора с задержкой.

Обучение может занять некоторое время.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,'single');
        dlZ = dlarray(Z,'CB');
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation,dlTValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation), ...
                'GridSize',[numValidationImagesPerClass numClasses]);
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot.
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь дискриминатор усвоил сильное характерное представление, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор усвоил такое же сильное представление признаков, которое позволяет ему генерировать изображения, подобные обучающим данным. Каждый столбец соответствует одному классу.

На обучающем графике показаны баллы сетей генератора и дискриминатора. Дополнительные сведения о том, как интерпретировать оценки сети, см. в разделах Мониторинг хода обучения GAN и Определение общих видов отказов.

Создать новые изображения

Для создания новых изображений определенного класса используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов и массив меток, соответствующих требуемым классам. Преобразовать данные в dlarray объекты и указать метки размеров 'CB' (канал, партия). Для прогнозирования GPU преобразуйте данные в gpuArray объекты. Чтобы отобразить изображения вместе, используйте imtile и масштабировать изображения с помощью rescale функция.

Создайте массив из 36 векторов случайных значений, соответствующих первому классу.

numObservationsNew = 36;
idxClass = 1;
Z = randn(numLatentInputs,numObservationsNew,'single');
T = repmat(single(idxClass),[1 numObservationsNew]);

Преобразовать данные в dlarray объекты с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

dlZ = dlarray(Z,'CB');
dlT = dlarray(T,'CB');

Чтобы создать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZ = gpuArray(dlZ);
    dlT = gpuArray(dlT);
end

Создание изображений с помощью predict работа с генераторной сетью.

dlXGenerated = predict(dlnetGenerator,dlZ,dlT);

Отображение созданных изображений на графике.

figure
I = imtile(extractdata(dlXGenerated));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))

Здесь генераторная сеть генерирует изображения, обусловленные указанным классом.

Функция градиентов модели

Функция modelGradients принимает за вход генератор и дискриминатор dlnetwork объекты dlnetGenerator и dlnetDiscriminator, мини-пакет входных данных dlX, соответствующие метки dlTи массив случайных значений dlZи возвращает градиенты потерь относительно обучаемых параметров в сетях, состояния генератора и баллов сети.

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не тренироваться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом переверните метки части реальных изображений.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX, dlT);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator, dlZ, dlT);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated, dlT);

% Calculate probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the generator and discriminator scores.
scoreGenerator = mean(probGenerated);
scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels.
numObservations = size(dlYPred,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal, probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потери GAN

Целью генератора является генерирование данных, которые дискриминатор классифицирует как "real". Чтобы максимизировать вероятность того, что изображения из генератора классифицируются дискриминатором как реальные, минимизируйте функцию отрицательного логарифмического правдоподобия.

Учитывая выход Y дискриминатора:

  • Yˆ=σ (Y) - вероятность того, что входное изображение принадлежит классу "real".

  • 1-Yˆ - вероятность того, что входное изображение принадлежит классу "generated".

Обратите внимание на то, что сигмоидальная операция λ происходит в modelGradients функция. Функция потерь для генератора задается

lossGenerator = -mean (log (YˆGenerated)),

где YˆGenerated содержит вероятности вывода дискриминатора для сгенерированных изображений.

Цель дискриминатора - не быть «обманутым» генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает действительное и сгенерированное изображения, минимизируйте сумму соответствующих функций отрицательного логарифмического правдоподобия. Функция потерь для дискриминатора задается

lossDiscriminator = -mean (log (YˆReal)) -mean (log (1-YˆGenerated)),

где YˆReal содержит вероятности вывода дискриминатора для реальных изображений.

function [lossGenerator, lossDiscriminator] = ganLoss(scoresReal,scoresGenerated)

% Calculate losses for the discriminator network.
lossGenerated = -mean(log(1 - scoresGenerated));
lossReal = -mean(log(scoresReal));

% Combine the losses for the discriminator network.
lossDiscriminator = lossReal + lossGenerated;

% Calculate the loss for the generator network.
lossGenerator = -mean(log(scoresGenerated));

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения и метки из массивов входных ячеек и объедините их в числовые массивы.

  2. Масштабировать изображения в диапазоне [-1,1].

function [X,T] = preprocessData(XCell,TCell)

% Extract image data from cell and concatenate
X = cat(4,XCell{:});

% Extract label data from cell and concatenate
T = cat(1,TCell{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

См. также

| | | | | |

Связанные темы