Обучите Порождающую соперничающую сеть (GAN)

В этом примере показано, как обучить порождающую соперничающую сеть (GAN) генерировать изображения.

Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как входные обучающие данные.

ГАНЬ состоит из двух сетей, которые обучаются вместе:

  1. Генератор - Учитывая векторные или случайные значения, как введено, эта сеть генерирует данные с той же структурой как обучающие данные.

  2. Различитель - Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать производительность обоих:

  • Обучите генератор генерировать данные, которые "дурачат" различитель.

  • Обучите различитель различать действительные и сгенерированные данные.

Чтобы максимизировать производительность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные".

Чтобы максимизировать производительность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.

Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.

Загрузите обучающие данные

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flower Dataset (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте datastore изображений, содержащий фотографии подсолнечников только.

datasetFolder = fullfile(imageFolder,"sunflowers");

imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение, масштабирование, и изменить размер изображений, чтобы иметь размер 64 64.

augmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandScale',[1 2]);
augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);

Задайте сеть генератора

Задайте сеть, которая генерирует изображения от 1 1 100 массивами случайных значений. Создайте сеть, которая увеличивает масштаб 1 1 100 массивами к 64 64 3 массивами с помощью серии транспонированных слоев свертки с пакетной нормализацией и слоев ReLU.

  • Для транспонированных слоев свертки задайте фильтры 4 на 4 с уменьшающимся количеством фильтров для каждого слоя.

  • Для второго транспонированного слоя свертки вперед, задайте шаг 2 и обрезать выход на один пиксель на каждом ребре.

  • Поскольку финал транспонировал слой свертки, задайте 3 фильтра, которые соответствуют каналам RGB сгенерированных изображений.

  • В конце сети включайте tanh слой.

filterSize = [4 4];
numFilters = 64;
numLatentInputs = 100;

layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','in')
    transposedConv2dLayer(filterSize,8*numFilters,'Name','tconv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(filterSize,4*numFilters,'Stride',2,'Cropping',1,'Name','tconv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping',1,'Name','tconv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping',1,'Name','tconv4')
    batchNormalizationLayer('Name','bn4')
    reluLayer('Name','relu4')
    transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping',1,'Name','tconv5')
    tanhLayer('Name','tanh')];

lgraphGenerator = layerGraph(layersGenerator);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоя в dlnetwork объект.

dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = 
  dlnetwork with properties:

         Layers: [15×1 nnet.cnn.layer.Layer]
    Connections: [14×2 table]
     Learnables: [18×3 table]
          State: [8×3 table]

Задайте сеть различителя

Задайте сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.

Создайте сеть, которая берет 64 64 3 изображениями и вводами и выводами скалярный счет прогноза с помощью серии слоев свертки с пакетной нормализацией и текучих слоев ReLU.

  • Для слоев свертки задайте фильтры 4 на 4 с растущим числом фильтров для каждого слоя.

  • Для второго слоя свертки вперед, задайте шаг 2 и заполнять выход на один пиксель на каждом ребре.

  • Для итогового слоя свертки задайте один фильтр 4 на 4 так, чтобы сетевые выходные параметры скалярный прогноз.

scale = 0.2;

layersDiscriminator = [
    imageInputLayer([64 64 3],'Normalization','none','Name','in')
    convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding',1,'Name','conv1')
    leakyReluLayer(scale,'Name','lrelu1')
    convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    leakyReluLayer(scale,'Name','lrelu2')
    convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    leakyReluLayer(scale,'Name','lrelu3')
    convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding',1,'Name','conv4')
    batchNormalizationLayer('Name','bn4')
    leakyReluLayer(scale,'Name','lrelu4')
    convolution2dLayer(filterSize,1,'Name','conv5')];

lgraphDiscriminator = layerGraph(layersDiscriminator);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоя в dlnetwork объект.

dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = 
  dlnetwork with properties:

         Layers: [13×1 nnet.cnn.layer.Layer]
    Connections: [12×2 table]
     Learnables: [16×3 table]
          State: [6×3 table]

Визуализируйте генератор и сети различителя в графике.

figure
subplot(1,2,1)
plot(lgraphGenerator)
title("Generator")

subplot(1,2,2)
plot(lgraphDiscriminator)
title("Discriminator")

Градиенты модели Define и функции потерь

Создайте функциональный modelGradients, перечисленный в конце примера, который берет генератор и различитель dlnetwork объекты dlnetGenerator и dlnetDiscrimintor, мини-пакет входных данных X, и массив случайных значений Z, и возвращает градиенты потери относительно learnable параметров в сетях и массиве сгенерированных изображений.

Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия. Функцией потерь для генератора дают

lossGenerator=-среднее значение(журнал(σ(YˆСгенерированный))),

где σ обозначает сигмоидальную функцию, и YˆGenerated обозначает выход различителя со сгенерированным вводом данных.

Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия. Выход различителя соответствует вероятностям, вход принадлежит "действительному" классу. Для сгенерированных данных, чтобы использовать вероятности, соответствующие "сгенерированному" классу, используют значения 1-σ(YˆGenerated). Функцией потерь для различителя дают

lossDiscriminator=-среднее значение(журнал(σ(YˆДействительный)))-среднее значение(журнал(1-σ(YˆСгенерированный))),

где YˆReal обозначает выход различителя с действительным вводом данных.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 128 в течение 1 000 эпох. Для больших наборов данных вы не можете должны быть обучаться для как много эпох. Установите размер чтения увеличенного datastore изображений к мини-пакетному размеру.

numEpochs = 1000;
miniBatchSize = 128;
augimds.MiniBatchSize = miniBatchSize;

Задайте опции для оптимизации ADAM:

  • Если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, установите изучить уровень генератора к 0,0002 и изучить уровень различителя к 0,0001.

  • Для каждой сети инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с [].

  • Для обеих сетей используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.

learnRateGenerator = 0.0002;
learnRateDiscriminator = 0.0001;

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор.

В течение каждой эпохи переставьте datastore и цикл по мини-пакетам данных.

Для каждого мини-пакета:

  • Нормируйте данные так, чтобы пиксели приняли значения в области значений [-1, 1].

  • Преобразуйте данные в dlarray объекты с базовым типом single и укажите, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Сгенерируйте dlarray объект, содержащий массив случайных значений для сети генератора.

  • Для обучения графического процессора преобразуйте данные в gpuArray объекты.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • После каждых 100 итераций отобразитесь, пакет сгенерированных изображений для фиксированного протянул вход генератора.

Чтобы контролировать процесс обучения, создайте протянутый пакет фиксированных 64 1 1 100 массивами случайных значений, чтобы ввести в генератор. Укажите, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет). Для обучения графического процессора преобразуйте данные в gpuArray.

ZValidation = randn(1,1,numLatentInputs,64,'single');
dlZValidation = dlarray(ZValidation,'SSCB');

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZValidation = gpuArray(dlZValidation);
end

Обучите GAN. Это может занять время, чтобы запуститься.

figure
iteration = 0;
start = tic;

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Concatenate mini-batch of data and generate latent inputs for the
        % generator network.
        X = cat(4,data{:,1}{:});
        Z = randn(1,1,numLatentInputs,size(X,4),'single');
        
        % Normalize the images
        X = (single(X)/255)*2 - 1;
        
        % Convert mini-batch of data to dlarray specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        dlZ = dlarray(Z, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlZ = gpuArray(dlZ);
        end
        
        % Evaluate the model gradients and the generator state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradientsGenerator, gradientsDiscriminator, stateGenerator] = ...
            dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ);
        dlnetGenerator.State = stateGenerator;
        
        % Update the discriminator network parameters.
        [dlnetDiscriminator.Learnables,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(dlnetDiscriminator.Learnables, gradientsDiscriminator, ...
            trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ...
            learnRateDiscriminator, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Update the generator network parameters.
        [dlnetGenerator.Learnables,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(dlnetGenerator.Learnables, gradientsGenerator, ...
            trailingAvgGenerator, trailingAvgSqGenerator, iteration, ...
            learnRateGenerator, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Every 100 iterations, display batch of generated images using the
        % held-out generator input.
        if mod(iteration,100) == 0 || iteration == 1
            
            % Generate images using the held-out generator input.
            dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation);
            
            % Rescale the images in the range [0 1] and display the images.
            I = imtile(extractdata(dlXGeneratedValidation));
            I = rescale(I);
            image(I)
            
            % Update the title with training progress information.
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            title(...
                "Epoch: " + i + ", " + ...
                "Iteration: " + iteration + ", " + ...
                "Elapsed: " + string(D))
            
            drawnow
        end
    end
end

Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений и в свою очередь, генератор изучил представление столь же сильной черты, которое позволяет ему генерировать реалистически выглядящие данные.

Сгенерируйте новые изображения

Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет 1 1 100 массивами случайных значений. Чтобы отобразить изображения вместе, используйте imtile функционируйте и повторно масштабируйте изображения с помощью rescale функция.

Создайте dlarray объект, содержащий пакет 16 1 1 100 массивами случайных значений, чтобы ввести в сеть генератора.

ZNew = randn(1,1,numLatentInputs,16,'single');
dlZNew = dlarray(ZNew,'SSCB');

Для вывода графического процессора преобразуйте данные в gpuArray объекты.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end

Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.

dlXGeneratedNew = predict(dlnetGenerator,dlZNew);

Отобразите изображения.

I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
image(I)
title("Generated Images")

Функция градиентов модели

Функциональный modelGradients берет генератор и различитель dlnetwork объекты dlnetGenerator и dlnetDiscrimintor, мини-пакет входных данных X, и массив случайных значений Z, и возвращает градиенты потери относительно learnable параметров в сетях и массиве сгенерированных изображений.

function [gradientsGenerator, gradientsDiscriminator, stateGenerator] = ...
    modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ)

% Calculate the predictions for real data with the discriminator network.
dlYPred = forward(dlnetDiscriminator, dlX);

% Calculate the predictions for generated data with the discriminator network.
[dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

% Calculate the GAN loss
[lossGenerator, lossDiscriminator] = ganLoss(dlYPred,dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);
gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

end

Функция потерь ГАНЯ

Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия. Функцией потерь для генератора дают

lossGenerator=-среднее значение(журнал(σ(YˆСгенерированный))),

где σ обозначает сигмоидальную функцию, и YˆGenerated обозначает выход различителя со сгенерированным вводом данных.

Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия. Выход различителя соответствует вероятностям, вход принадлежит "действительному" классу. Для сгенерированных данных, чтобы использовать вероятности, соответствующие "сгенерированному" классу, используют значения 1-σ(YˆGenerated). Функцией потерь для различителя дают

lossDiscriminator=-среднее значение(журнал(σ(YˆДействительный)))-среднее значение(журнал(1-σ(YˆСгенерированный))),

где YˆReal обозначает выход различителя с действительным вводом данных.

function [lossGenerator, lossDiscriminator] = ganLoss(dlYPred,dlYPredGenerated)

% Calculate losses for the discriminator network.
lossGenerated = -mean(log(1-sigmoid(dlYPredGenerated)));
lossReal = -mean(log(sigmoid(dlYPred)));

% Combine the losses for the discriminator network.
lossDiscriminator = lossReal + lossGenerated;

% Calculate the loss for the generator network.
lossGenerator = -mean(log(sigmoid(dlYPredGenerated)));

end

Ссылки

Смотрите также

| | | | | |

Похожие темы