exponenta event banner

Генеративная состязательная сеть поездов (GAN)

В этом примере показано, как обучить генеративную состязательную сеть генерировать изображения.

Генеративная состязательная сеть (GAN) - это тип сети глубокого обучения, которая может генерировать данные с теми же характеристиками, что и входные реальные данные.

GAN состоит из двух сетей, которые взаимодействуют друг с другом:

  1. Генератор - учитывая вектор случайных значений (скрытых входных данных) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.

  2. Дискриминатор - данные, содержащие наблюдения как из обучающих данных, так и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real" или "generated".

Для обучения GAN одновременно обучайте обе сети, чтобы максимизировать производительность обеих:

  • Обучите генератор генерировать данные, которые «обманывают» дискриминатор.

  • Обучите дискриминатор различать реальные и сгенерированные данные.

Для оптимизации производительности генератора максимизируйте потери дискриминатора при предоставлении сгенерированных данных. То есть целью генератора является генерирование данных, которые дискриминатор классифицирует как "real".

Для оптимизации производительности дискриминатора минимизируйте потери дискриминатора при предоставлении партий как реальных, так и сгенерированных данных. То есть цель дискриминатора - не быть «обманутым» генератором.

В идеале эти стратегии приводят к созданию генератора, который генерирует убедительно реалистичные данные, и дискриминатора, который усвоил сильные представления признаков, характерные для обучающих данных.

Загрузка данных обучения

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте хранилище данных изображения, содержащее фотографии цветов.

datasetFolder = fullfile(imageFolder);

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true);

Дополните данные, чтобы включить случайное горизонтальное переворачивание и изменить размер изображений, чтобы они имели размер 64 на 64.

augmenter = imageDataAugmenter('RandXReflection',true);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);    

Определение сети генератора

Определите следующую архитектуру сети, которая генерирует изображения из случайных векторов размера 100.

Эта сеть:

  • Преобразовывает случайные векторы размера 100 ко множествам 7 на 7 на 128, используя проект, и измените слой.

  • Масштабирует результирующие массивы до массивов 64 на 3, используя ряд транспонированных слоев свертки с пакетной нормализацией и уровнями ReLU.

Определите эту сетевую архитектуру как график уровня и укажите следующие свойства сети.

  • Для транспонированных слоев свертки укажите фильтры 5 на 5 с уменьшающимся числом фильтров для каждого слоя, шагом 2 и обрезкой выходных данных на каждом ребре.

  • Для конечного транспонированного слоя свертки укажите три фильтра 5 на 5, соответствующих трем каналам RGB генерируемых изображений, и размер вывода предыдущего слоя.

  • В конце сети включите уровень tanh.

Для проецирования и изменения формы входного шума используйте пользовательский слой projectAndReshapeLayer, прилагается к этому примеру как вспомогательный файл. projectAndReshapeLayer объект увеличивает масштаб входа, используя полностью подключенную операцию, и изменяет форму выхода на указанный размер.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs,'Name','in')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
    transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bnorm1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
    batchNormalizationLayer('Name','bnorm2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
    batchNormalizationLayer('Name','bnorm3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator);

Определение сети дискриминаторов

Определите следующую сеть, которая классифицирует реальные и созданные 64 на 64 изображения.

Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет прогноза, используя серию слоев скручивания с пакетной нормализацией и прохудившихся слоев ReLU. Добавление шума к входным изображениям с помощью отсева.

  • Для уровня отсева укажите вероятность отсева 0,5.

  • Для слоев свертки укажите фильтры 5 на 5 с увеличением числа фильтров для каждого слоя. Также укажите шаг 2 и заполнение выходных данных.

  • Для протекающих слоев ReLU укажите масштаб 0,2.

  • Для конечного слоя укажите сверточный слой с одним фильтром 4 на 4.

Для вывода вероятностей в диапазоне [0,1] используйте sigmoid функции в modelGradients функция, перечисленная в разделе «Функция градиентов модели» примера.

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,'Normalization','none','Name','in')
    dropoutLayer(dropoutProb,'Name','dropout')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(4,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator);

Определение градиентов модели и функций потерь

Создание функции modelGradients, перечисленных в разделе «Функция градиентов модели» примера, который принимает в качестве входных данных сети генератора и дискриминатора, мини-пакет входных данных, массив случайных значений и коэффициент отражения, и возвращает градиенты потерь относительно обучаемых параметров в сетях и баллов двух сетей.

Укажите параметры обучения

Поезд с размером мини-партии 128 на 500 эпох. Для больших наборов данных, возможно, не потребуется тренироваться столько эпох.

numEpochs = 500;
miniBatchSize = 128;

Укажите параметры оптимизации Adam. Для обеих сетей укажите:

  • Коэффициент обучения 0,0002

  • Коэффициент градиентного спада 0,5

  • Квадрат градиентного коэффициента распада 0,999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не тренироваться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, добавьте шум к реальным данным, случайным образом переворачивая метки.

Укажите flipFactor значение от 0,3 до 30% реальных меток (15% от общего количества меток). Обратите внимание, что это не ухудшает работу генератора, поскольку все сгенерированные изображения по-прежнему помечены правильно.

flipFactor = 0.3;

Отображение созданных изображений проверки каждые 100 итераций.

validationFrequency = 100;

Модель поезда

Использовать minibatchqueue для обработки и управления мини-партиями изображений. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определено в конце этого примера) для масштабирования изображений в диапазоне [-1,1].

  • Удалите все частичные мини-партии с менее чем 128 наблюдениями.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single.

  • Обучение на GPU, если он доступен. Когда 'OutputEnvironment' вариант minibatchqueue является "auto", minibatchqueue преобразует каждый выход в gpuArray если графический процессор доступен. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ® с вычислительными возможностями 3.0 или выше.

augimds.MiniBatchSize = miniBatchSize;

executionEnvironment = "auto";

mbq = minibatchqueue(augimds,...
    'MiniBatchSize',miniBatchSize,...
    'PartialMiniBatch','discard',...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat','SSCB',...
    'OutputEnvironment',executionEnvironment);

Обучение модели с помощью пользовательского цикла обучения. Закольцовывать обучающие данные и обновлять сетевые параметры на каждой итерации. Чтобы контролировать ход обучения, отобразите пакет сгенерированных изображений, используя задержанный массив случайных значений для ввода в генератор, а также график оценок.

Инициализируйте параметры для Adam.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Чтобы контролировать ход обучения, отобразите партию сгенерированных изображений, используя задержанную партию фиксированных случайных векторов, подаваемых в генератор, и постройте график сетевых оценок.

Создайте массив задержанных случайных значений.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');

Преобразовать данные в dlarray объекты и указать метки размеров 'CB' (канал, партия).

dlZValidation = dlarray(ZValidation,'CB');

Для обучения GPU преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Инициализируйте графики хода обучения. Создайте фигуру и измените ее размер, чтобы она имела удвоенную ширину.

f = figure;
f.Position(3) = 2*f.Position(3);

Создайте вложенный график для созданных изображений и оценок сети.

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

Инициализируйте анимированные линии для графика показателей.

lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]);
lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
legend('Generator','Discriminator');
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

Тренируйте ГАН. Для каждой эпохи тасуйте хранилище данных и закольцовывайте мини-пакеты данных.

Для каждой мини-партии:

  • Оценка градиентов модели с помощью dlfeval и modelGradients функция.

  • Обновление параметров сети с помощью adamupdate функция.

  • Постройте графики двух сетей.

  • После каждого validationFrequency итерации, отображение пакета сгенерированных изображений для фиксированного ввода генератора с задержкой.

Обучение может занять некоторое время.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Reset and shuffle datastore.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        dlX = next(mbq);
        
        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the dimension labels 'CB' (channel, batch).
        % If training on a GPU, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,'single');
        dlZ = dlarray(Z,'CB');        
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every validationFrequency iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            
            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end
        
        % Update the scores plot.
        subplot(1,2,2)
        addpoints(lineScoreGenerator,iteration,...
            double(gather(extractdata(scoreGenerator))));
        
        addpoints(lineScoreDiscriminator,iteration,...
            double(gather(extractdata(scoreDiscriminator))));
        
        % Update the title with training progress information.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))
        
        drawnow
    end
end

Здесь дискриминатор усвоил сильное характерное представление, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор усвоил такое же сильное представление признаков, которое позволяет ему генерировать изображения, подобные обучающим данным.

На обучающем графике показаны баллы сетей генератора и дискриминатора. Дополнительные сведения о том, как интерпретировать оценки сети, см. в разделах Мониторинг хода обучения GAN и Определение общих видов отказов.

Создать новые изображения

Для создания новых изображений используйте predict функция на генераторе с dlarray объект, содержащий партию случайных векторов. Чтобы отобразить изображения вместе, используйте imtile и масштабировать изображения с помощью rescale функция.

Создать dlarray объект, содержащий партию из 25 случайных векторов для ввода в генераторную сеть.

numObservations = 25;
ZNew = randn(numLatentInputs,numObservations,'single');
dlZNew = dlarray(ZNew,'CB');

Чтобы создать изображения с помощью графического процессора, также преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Создание новых изображений с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

Отображение изображений.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Функция градиентов модели

Функция modelGradients принимает за вход генератор и дискриминатор dlnetwork объекты dlnetGenerator и dlnetDiscriminator, мини-пакет входных данных dlX, массив случайных значений dlZи процент реальных меток, которые нужно перевернуть flipFactorи возвращает градиенты потерь относительно обучаемых параметров в сетях, состояния генератора и оценок двух сетей. Поскольку выходной сигнал дискриминатора не находится в диапазоне [0,1], modelGradients функция применяет сигмоидальную функцию для преобразования ее в вероятности.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Convert the discriminator outputs to probabilities.
probGenerated = sigmoid(dlYPredGenerated);
probReal = sigmoid(dlYPred);

% Calculate the score of the discriminator.
scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2;

% Calculate the score of the generator.
scoreGenerator = mean(probGenerated);

% Randomly flip a fraction of the labels of the real images.
numObservations = size(probReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Flip the labels.
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь GAN и баллы

Целью генератора является генерирование данных, которые дискриминатор классифицирует как "real". Чтобы максимизировать вероятность того, что изображения из генератора классифицируются дискриминатором как реальные, минимизируйте функцию отрицательного логарифмического правдоподобия.

Учитывая выход Y дискриминатора:

  • Yˆ=σ (Y) - вероятность того, что входное изображение принадлежит классу "real".

  • 1-Yˆ - вероятность того, что входное изображение принадлежит классу "generated".

Обратите внимание, что сигмоидальная операция λ происходит в modelGradients функция. Функция потерь для генератора задается

lossGenerator = -mean (log (YˆGenerated)),

где YˆGenerated содержит вероятности вывода дискриминатора для сгенерированных изображений.

Цель дискриминатора - не быть «обманутым» генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает действительное и сгенерированное изображения, минимизируйте сумму соответствующих функций отрицательного логарифмического правдоподобия.

Функция потерь для дискриминатора задается

lossDiscriminator = -mean (log (YˆReal)) -mean (log (1-YˆGenerated)),

где YˆReal содержит вероятности вывода дискриминатора для реальных изображений.

Чтобы измерить по шкале от 0 до 1, насколько хорошо генератор и дискриминатор достигают своих соответствующих целей, можно использовать концепцию оценки.

Оценка генератора представляет собой среднее значение вероятностей, соответствующих выходному сигналу дискриминатора для сгенерированных изображений:

scoreGenerator=mean (YˆGenerated).

Оценка дискриминатора представляет собой среднее значение вероятностей, соответствующих выходному сигналу дискриминатора как для реального, так и для сгенерированного изображения:

scoreDiscriminator=12mean (YˆReal) +12mean (1-YˆGenerated).

Оценка обратно пропорциональна потере, но фактически содержит ту же самую информацию.

function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)

% Calculate the loss for the discriminator network.
lossDiscriminator = -mean(log(probReal)) - mean(log(1-probGenerated));

% Calculate the loss for the generator network.
lossGenerator = -mean(log(probGenerated));

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в числовой массив.

  2. Масштабировать изображения в диапазоне [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,'InputMin',0,'InputMax',255);

end

Ссылки

  1. Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Рэдфорд, Алек, Люк Мец и Сомит Чинтала. «Обучение неподконтрольному представлению с глубокими сверточными генеративными состязательными сетями». Препринт, представлен 19 ноября 2015 года. http://arxiv.org/abs/1511.06434.

См. также

| | | | | | |

Связанные темы