В этом примере показано, как обучить порождающую соперничающую сеть (GAN) генерировать изображения.
Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как входные обучающие данные.
ГАНЬ состоит из двух сетей, которые обучаются вместе:
Генератор - Учитывая векторные или случайные значения, как введено, эта сеть генерирует данные с той же структурой как обучающие данные.
Различитель - Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".
Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать производительность обоих:
Обучите генератор генерировать данные, которые "дурачат" различитель.
Обучите различитель различать действительные и сгенерированные данные.
Чтобы максимизировать производительность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные".
Чтобы максимизировать производительность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.
Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.
Загрузите и извлеките набор данных Flowers [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flower Dataset (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте datastore изображений, содержащий фотографии подсолнечников только.
datasetFolder = fullfile(imageFolder,"sunflowers"); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение, масштабирование, и изменить размер изображений, чтобы иметь размер 64 64.
augmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandScale',[1 2]); augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);
Задайте сеть, которая генерирует изображения от 1 1 100 массивами случайных значений. Создайте сеть, которая увеличивает масштаб 1 1 100 массивами к 64 64 3 массивами с помощью серии транспонированных слоев свертки с пакетной нормализацией и слоев ReLU.
Для транспонированных слоев свертки задайте фильтры 4 на 4 с уменьшающимся количеством фильтров для каждого слоя.
Для второго транспонированного слоя свертки вперед, задайте шаг 2 и обрезать выход на один пиксель на каждом ребре.
Поскольку финал транспонировал слой свертки, задайте 3 фильтра, которые соответствуют каналам RGB сгенерированных изображений.
В конце сети включайте tanh слой.
filterSize = [4 4]; numFilters = 64; numLatentInputs = 100; layersGenerator = [ imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','in') transposedConv2dLayer(filterSize,8*numFilters,'Name','tconv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') transposedConv2dLayer(filterSize,4*numFilters,'Stride',2,'Cropping',1,'Name','tconv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping',1,'Name','tconv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu3') transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping',1,'Name','tconv4') batchNormalizationLayer('Name','bn4') reluLayer('Name','relu4') transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping',1,'Name','tconv5') tanhLayer('Name','tanh')]; lgraphGenerator = layerGraph(layersGenerator);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоя в dlnetwork
объект.
dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = dlnetwork with properties: Layers: [15×1 nnet.cnn.layer.Layer] Connections: [14×2 table] Learnables: [18×3 table] State: [8×3 table]
Задайте сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.
Создайте сеть, которая берет 64 64 3 изображениями и вводами и выводами скалярный счет прогноза с помощью серии слоев свертки с пакетной нормализацией и текучих слоев ReLU.
Для слоев свертки задайте фильтры 4 на 4 с растущим числом фильтров для каждого слоя.
Для второго слоя свертки вперед, задайте шаг 2 и заполнять выход на один пиксель на каждом ребре.
Для итогового слоя свертки задайте один фильтр 4 на 4 так, чтобы сетевые выходные параметры скалярный прогноз.
scale = 0.2; layersDiscriminator = [ imageInputLayer([64 64 3],'Normalization','none','Name','in') convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding',1,'Name','conv1') leakyReluLayer(scale,'Name','lrelu1') convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding',1,'Name','conv2') batchNormalizationLayer('Name','bn2') leakyReluLayer(scale,'Name','lrelu2') convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding',1,'Name','conv3') batchNormalizationLayer('Name','bn3') leakyReluLayer(scale,'Name','lrelu3') convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding',1,'Name','conv4') batchNormalizationLayer('Name','bn4') leakyReluLayer(scale,'Name','lrelu4') convolution2dLayer(filterSize,1,'Name','conv5')]; lgraphDiscriminator = layerGraph(layersDiscriminator);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоя в dlnetwork
объект.
dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = dlnetwork with properties: Layers: [13×1 nnet.cnn.layer.Layer] Connections: [12×2 table] Learnables: [16×3 table] State: [6×3 table]
Визуализируйте генератор и сети различителя в графике.
figure subplot(1,2,1) plot(lgraphGenerator) title("Generator") subplot(1,2,2) plot(lgraphDiscriminator) title("Discriminator")
Создайте функциональный modelGradients
, перечисленный в конце примера, который берет генератор и различитель dlnetwork
объекты dlnetGenerator
и dlnetDiscrimintor
, мини-пакет входных данных X
, и массив случайных значений Z
, и возвращает градиенты потери относительно learnable параметров в сетях и массиве сгенерированных изображений.
Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия. Функцией потерь для генератора дают
где обозначает сигмоидальную функцию, и обозначает выход различителя со сгенерированным вводом данных.
Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия. Выход различителя соответствует вероятностям, вход принадлежит "действительному" классу. Для сгенерированных данных, чтобы использовать вероятности, соответствующие "сгенерированному" классу, используют значения . Функцией потерь для различителя дают
где обозначает выход различителя с действительным вводом данных.
Обучайтесь с мини-пакетным размером 128 в течение 1 000 эпох. Для больших наборов данных вы не можете должны быть обучаться для как много эпох. Установите размер чтения увеличенного datastore изображений к мини-пакетному размеру.
numEpochs = 1000; miniBatchSize = 128; augimds.MiniBatchSize = miniBatchSize;
Задайте опции для оптимизации ADAM:
Если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, установите изучить уровень генератора к 0,0002 и изучить уровень различителя к 0,0001.
Для каждой сети инициализируйте запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с []
.
Для обеих сетей используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.
learnRateGenerator = 0.0002; learnRateDiscriminator = 0.0001; trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminator = []; trailingAvgSqDiscriminator = []; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.
executionEnvironment = "auto";
Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор.
В течение каждой эпохи переставьте datastore и цикл по мини-пакетам данных.
Для каждого мини-пакета:
Нормируйте данные так, чтобы пиксели приняли значения в области значений [-1, 1].
Преобразуйте данные в dlarray
объекты с базовым типом single
и укажите, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Сгенерируйте dlarray
объект, содержащий массив случайных значений для сети генератора.
Для обучения графического процессора преобразуйте данные в gpuArray
объекты.
Оцените градиенты модели с помощью dlfeval
и modelGradients
функция.
Обновите сетевые параметры с помощью adamupdate
функция.
После каждых 100 итераций отобразитесь, пакет сгенерированных изображений для фиксированного протянул вход генератора.
Чтобы контролировать процесс обучения, создайте протянутый пакет фиксированных 64 1 1 100 массивами случайных значений, чтобы ввести в генератор. Укажите, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет). Для обучения графического процессора преобразуйте данные в gpuArray
.
ZValidation = randn(1,1,numLatentInputs,64,'single'); dlZValidation = dlarray(ZValidation,'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); end
Обучите GAN. Это может занять время, чтобы запуститься.
figure iteration = 0; start = tic; % Loop over epochs. for i = 1:numEpochs % Reset and shuffle datastore. reset(augimds); augimds = shuffle(augimds); % Loop over mini-batches. while hasdata(augimds) iteration = iteration + 1; % Read mini-batch of data. data = read(augimds); % Ignore last partial mini-batch of epoch. if size(data,1) < miniBatchSize continue end % Concatenate mini-batch of data and generate latent inputs for the % generator network. X = cat(4,data{:,1}{:}); Z = randn(1,1,numLatentInputs,size(X,4),'single'); % Normalize the images X = (single(X)/255)*2 - 1; % Convert mini-batch of data to dlarray specify the dimension labels % 'SSCB' (spatial, spatial, channel, batch). dlX = dlarray(X, 'SSCB'); dlZ = dlarray(Z, 'SSCB'); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); dlZ = gpuArray(dlZ); end % Evaluate the model gradients and the generator state using % dlfeval and the modelGradients function listed at the end of the % example. [gradientsGenerator, gradientsDiscriminator, stateGenerator] = ... dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ); dlnetGenerator.State = stateGenerator; % Update the discriminator network parameters. [dlnetDiscriminator.Learnables,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ... adamupdate(dlnetDiscriminator.Learnables, gradientsDiscriminator, ... trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ... learnRateDiscriminator, gradientDecayFactor, squaredGradientDecayFactor); % Update the generator network parameters. [dlnetGenerator.Learnables,trailingAvgGenerator,trailingAvgSqGenerator] = ... adamupdate(dlnetGenerator.Learnables, gradientsGenerator, ... trailingAvgGenerator, trailingAvgSqGenerator, iteration, ... learnRateGenerator, gradientDecayFactor, squaredGradientDecayFactor); % Every 100 iterations, display batch of generated images using the % held-out generator input. if mod(iteration,100) == 0 || iteration == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation); % Rescale the images in the range [0 1] and display the images. I = imtile(extractdata(dlXGeneratedValidation)); I = rescale(I); image(I) % Update the title with training progress information. D = duration(0,0,toc(start),'Format','hh:mm:ss'); title(... "Epoch: " + i + ", " + ... "Iteration: " + iteration + ", " + ... "Elapsed: " + string(D)) drawnow end end end
Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений и в свою очередь, генератор изучил представление столь же сильной черты, которое позволяет ему генерировать реалистически выглядящие данные.
Чтобы сгенерировать новые изображения, используйте predict
функция на генераторе с dlarray
объект, содержащий пакет 1 1 100 массивами случайных значений. Чтобы отобразить изображения вместе, используйте imtile
функционируйте и повторно масштабируйте изображения с помощью rescale
функция.
Создайте dlarray
объект, содержащий пакет 16 1 1 100 массивами случайных значений, чтобы ввести в сеть генератора.
ZNew = randn(1,1,numLatentInputs,16,'single'); dlZNew = dlarray(ZNew,'SSCB');
Для вывода графического процессора преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZNew = gpuArray(dlZNew); end
Сгенерируйте новые изображения с помощью predict
функция с генератором и входными данными.
dlXGeneratedNew = predict(dlnetGenerator,dlZNew);
Отобразите изображения.
I = imtile(extractdata(dlXGeneratedNew));
I = rescale(I);
image(I)
title("Generated Images")
Функциональный modelGradients
берет генератор и различитель dlnetwork
объекты dlnetGenerator
и dlnetDiscrimintor
, мини-пакет входных данных X
, и массив случайных значений Z
, и возвращает градиенты потери относительно learnable параметров в сетях и массиве сгенерированных изображений.
function [gradientsGenerator, gradientsDiscriminator, stateGenerator] = ... modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetDiscriminator, dlX); % Calculate the predictions for generated data with the discriminator network. [dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ); dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated); % Calculate the GAN loss [lossGenerator, lossDiscriminator] = ganLoss(dlYPred,dlYPredGenerated); % For each network, calculate the gradients with respect to the loss. gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true); gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables); end
Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия. Функцией потерь для генератора дают
где обозначает сигмоидальную функцию, и обозначает выход различителя со сгенерированным вводом данных.
Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия. Выход различителя соответствует вероятностям, вход принадлежит "действительному" классу. Для сгенерированных данных, чтобы использовать вероятности, соответствующие "сгенерированному" классу, используют значения . Функцией потерь для различителя дают
где обозначает выход различителя с действительным вводом данных.
function [lossGenerator, lossDiscriminator] = ganLoss(dlYPred,dlYPredGenerated) % Calculate losses for the discriminator network. lossGenerated = -mean(log(1-sigmoid(dlYPredGenerated))); lossReal = -mean(log(sigmoid(dlYPred))); % Combine the losses for the discriminator network. lossDiscriminator = lossReal + lossGenerated; % Calculate the loss for the generator network. lossGenerator = -mean(log(sigmoid(dlYPredGenerated))); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| predict