В этом примере показано, как обучить порождающую соперничающую сеть, чтобы сгенерировать изображения.
Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как вход действительные данные.
ГАНЬ состоит из двух сетей, которые обучаются вместе:
Генератор —, Учитывая вектор из случайных значений (скрытые входные параметры), как введено, эта сеть генерирует данные с той же структурой как обучающие данные.
Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "real"
или "generated"
.
Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обоих:
Обучите генератор генерировать данные, которые "дурачат" различитель.
Обучите различитель различать действительные и сгенерированные данные.
Чтобы оптимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "real"
.
Чтобы оптимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.
Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.
Загрузите и извлеките Цветочный набор данных [1].
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~exist(imageFolder,"dir") disp("Downloading Flowers data set (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end
Создайте datastore изображений, содержащий фотографии цветов.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder,IncludeSubfolders=true);
Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение и изменить размер изображений, чтобы иметь размер 64 64.
augmenter = imageDataAugmenter(RandXReflection=true); augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
Задайте следующую сетевую архитектуру, которая генерирует изображения от случайных векторов из размера 100.
Эта сеть:
Преобразует случайные векторы из размера от 100 до 7 7 128 массивами с помощью полносвязного слоя, сопровождаемого изменять операцией.
Увеличивает масштаб полученные массивы к 64 64 3 массивами с помощью серии транспонированных слоев свертки со слоев ReLU и нормализацией партии.
Задайте эту сетевую архитектуру как график слоев и задайте следующие сетевые свойства.
Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, шага 2 и обрезки выхода на каждом ребре.
Поскольку финал транспонировал слой свертки, задайте три фильтра 5 на 5, соответствующие трем каналам RGB сгенерированных изображений и выходному размеру предыдущего слоя.
В конце сети включайте tanh слой.
Чтобы спроектировать и изменить шумовой вход, используйте полносвязный слой, сопровождаемый изменять операцией specifed как функциональный слой с функцией, данной feature2image
функция, присоединенная к этому примеру как вспомогательный файл. Чтобы получить доступ к этой функции, откройте этот пример как live скрипт.
filterSize = 5; numFilters = 64; numLatentInputs = 100; projectionSize = [4 4 512]; layersGenerator = [ featureInputLayer(numLatentInputs) fullyConnectedLayer(prod(projectionSize)) functionLayer(@(X) feature2image(X,projectionSize),Formattable=true) transposedConv2dLayer(filterSize,4*numFilters) batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same") batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same") batchNormalizationLayer reluLayer transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same") tanhLayer]; lgraphGenerator = layerGraph(layersGenerator);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork
объект.
dlnetGenerator = dlnetwork(lgraphGenerator);
Задайте следующую сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.
Создайте сеть, которая берет 64 64 3 изображениями и возвращает скалярный счет предсказания с помощью серии слоев свертки с нормализацией партии. и текучих слоев ReLU. Добавьте шум во входные изображения с помощью уволенного.
Для слоя уволенного задайте вероятность уволенного 0,5.
Для слоев свертки задайте фильтры 5 на 5 с растущим числом фильтров для каждого слоя. Также задайте шаг 2 и дополнение выхода.
Для текучих слоев ReLU задайте шкалу 0,2.
Для последнего слоя задайте сверточный слой с одним фильтром 4 на 4.
Чтобы вывести вероятности в области значений [0,1], используйте sigmoid
функция в modelGradients
функция, перечисленная в разделе Model Gradients Function примера.
dropoutProb = 0.5; numFilters = 64; scale = 0.2; inputSize = [64 64 3]; filterSize = 5; layersDiscriminator = [ imageInputLayer(inputSize,Normalization="none") dropoutLayer(dropoutProb) convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same") leakyReluLayer(scale) convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same") batchNormalizationLayer leakyReluLayer(scale) convolution2dLayer(4,1)]; lgraphDiscriminator = layerGraph(layersDiscriminator);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork
объект.
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
Создайте функциональный modelGradients
, перечисленный в разделе Model Gradients Function примера, который берет в качестве входа генератор и сети различителя, мини-пакет входных данных, массив случайных значений и зеркально отраженный фактор, и возвращает градиенты потери относительно настраиваемых параметров в сетях и множестве этих двух сетей.
Обучайтесь с мини-пакетным размером 128 в течение 500 эпох. Для больших наборов данных вы не можете должны быть обучаться для как много эпох.
numEpochs = 500; miniBatchSize = 128;
Задайте опции для оптимизации Адама. Для обеих сетей задайте:
Скорость обучения 0,0002
Фактор затухания градиента 0,5
Градиент в квадрате затухает фактор 0,999
learnRate = 0.0002; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Чтобы лучше сбалансировать приобретение знаний о различителе и генераторе, добавьте шум в действительные данные путем случайного зеркального отражения меток.
Задайте flipFactor
значение 0,3, чтобы инвертировать 30% действительных меток (15% общих меток). Обратите внимание на то, что это не повреждает генератор, когда все сгенерированные изображения все еще помечены правильно.
flipFactor = 0.3;
Отобразитесь сгенерированная валидация отображает каждые 100 итераций.
validationFrequency = 100;
Используйте minibatchqueue
обработать и управлять мини-пакетами изображений. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера), чтобы перемасштабировать изображения в области значений [-1,1]
.
Отбросьте любые частичные мини-пакеты меньше чем с 128 наблюдениями.
Формат данные изображения с размерностью маркирует "SSCB"
(пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
.
Обучайтесь на графическом процессоре, если вы доступны. Когда OutputEnvironment
опция minibatchqueue
"auto"
, minibatchqueue
преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds,... MiniBatchSize=miniBatchSize,... PartialMiniBatch="discard",... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat="SSCB",... OutputEnvironment=executionEnvironment);
Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести к генератору, а также графику баллов.
Инициализируйте параметры для Адама.
trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminator = []; trailingAvgSqDiscriminator = [];
Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого пакета фиксированных случайных векторов, поданных в генератор, и постройте сетевые баллы.
Создайте массив протянутых случайных значений.
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
Преобразуйте данные в dlarray
объекты и указывают, что размерность маркирует "CB"
(образуйте канал, пакет).
dlZValidation = dlarray(ZValidation,"CB");
Для обучения графического процессора преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); end
Инициализируйте графики процесса обучения. Создайте фигуру и измените размер его, чтобы иметь дважды ширину.
f = figure; f.Position(3) = 2*f.Position(3);
Создайте подграфик для сгенерированных изображений и сетевых баллов.
imageAxes = subplot(1,2,1); scoreAxes = subplot(1,2,2);
Инициализируйте анимированные линии для графика баллов.
lineScoreGenerator = animatedline(scoreAxes,Color=[0 0.447 0.741]); lineScoreDiscriminator = animatedline(scoreAxes,Color=[0.85 0.325 0.098]); legend("Generator","Discriminator"); ylim([0 1]) xlabel("Iteration") ylabel("Score") grid on
Обучите GAN. В течение каждой эпохи переставьте datastore и цикл по мини-пакетам данных.
Для каждого мини-пакета:
Оцените градиенты модели с помощью dlfeval
и modelGradients
функция.
Обновите сетевые параметры с помощью adamupdate
функция.
Постройте множество этих двух сетей.
После каждого validationFrequency
итерации, отобразитесь, пакет сгенерированных изображений для фиксированного протянул вход генератора.
Обучение может занять время, чтобы запуститься.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Reset and shuffle datastore. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. dlX = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels "CB" (channel, batch). % If training on a GPU, then convert latent inputs to gpuArray. Z = randn(numLatentInputs,miniBatchSize,"single"); dlZ = dlarray(Z,"CB"); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZ = gpuArray(dlZ); end % Evaluate the model gradients and the generator state using % dlfeval and the modelGradients function listed at the end of the % example. [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ... dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor); dlnetGenerator.State = stateGenerator; % Update the discriminator network parameters. [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ... adamupdate(dlnetDiscriminator, gradientsDiscriminator, ... trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Update the generator network parameters. [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ... adamupdate(dlnetGenerator, gradientsGenerator, ... trailingAvgGenerator, trailingAvgSqGenerator, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency iterations, display batch of generated images using the % held-out generator input. if mod(iteration,validationFrequency) == 0 || iteration == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(dlXGeneratedValidation)); I = rescale(I); % Display the images. subplot(1,2,1); image(imageAxes,I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the scores plot. subplot(1,2,2) addpoints(lineScoreGenerator,iteration,... double(gather(extractdata(scoreGenerator)))); addpoints(lineScoreDiscriminator,iteration,... double(gather(extractdata(scoreDiscriminator)))); % Update the title with training progress information. D = duration(0,0,toc(start),Format="hh:mm:ss"); title(... "Epoch: " + epoch + ", " + ... "Iteration: " + iteration + ", " + ... "Elapsed: " + string(D)) drawnow end end
Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений. В свою очередь генератор изучил представление столь же сильной черты, которое позволяет ему генерировать изображения, похожие на обучающие данные.
Учебный график показывает множество сетей различителя и генератора. Чтобы узнать больше, как интерпретировать сетевые баллы, смотрите Монитор Процесс обучения GAN и Идентифицируйте Общие Типы отказа.
Чтобы сгенерировать новые изображения, используйте predict
функция на генераторе с dlarray
объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile
функционируйте и перемасштабируйте изображения с помощью rescale
функция.
Создайте dlarray
объект, содержащий пакет 25 случайных векторов, чтобы ввести к сети генератора.
numObservations = 25; ZNew = randn(numLatentInputs,numObservations,"single"); dlZNew = dlarray(ZNew,"CB");
Чтобы сгенерировать изображения с помощью графического процессора, также преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZNew = gpuArray(dlZNew); end
Сгенерируйте новые изображения с помощью predict
функция с генератором и входными данными.
dlXGeneratedNew = predict(dlnetGenerator,dlZNew);
Отобразите изображения.
I = imtile(extractdata(dlXGeneratedNew)); I = rescale(I); figure image(I) axis off title("Generated Images")
Функциональный modelGradients
берет в качестве входа генератор и различитель dlnetwork
объекты dlnetGenerator
и dlnetDiscriminator
, мини-пакет входных данных dlX
, массив случайных значений dlZ
, и процент действительных меток, чтобы инвертировать flipFactor
, и возвращает градиенты потери относительно настраиваемых параметров в сетях, состоянии генератора и множестве этих двух сетей. Поскольку различитель выход не находится в области значений [0,1], modelGradients
функция применяет сигмоидальную функцию, чтобы преобразовать его в вероятности.
function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ... modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetDiscriminator, dlX); % Calculate the predictions for generated data with the discriminator network. [dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ); dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated); % Convert the discriminator outputs to probabilities. probGenerated = sigmoid(dlYPredGenerated); probReal = sigmoid(dlYPred); % Calculate the score of the discriminator. scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2; % Calculate the score of the generator. scoreGenerator = mean(probGenerated); % Randomly flip a fraction of the labels of the real images. numObservations = size(probReal,4); idx = randperm(numObservations,floor(flipFactor * numObservations)); % Flip the labels. probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx); % Calculate the GAN loss. [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated); % For each network, calculate the gradients with respect to the loss. gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,RetainData=true); gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables); end
Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "real"
. Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте отрицательную логарифмическую функцию правдоподобия.
Учитывая выход из различителя:
вероятность, что входное изображение принадлежит классу "real"
.
вероятность, что входное изображение принадлежит классу "generated"
.
Обратите внимание на то, что сигмоидальная операция происходит в modelGradients
функция. Функцией потерь для генератора дают
где содержит различитель выходные вероятности для сгенерированных изображений.
Цель различителя не состоит в том, чтобы "дурачить" генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте сумму соответствующих отрицательных логарифмических функций правдоподобия.
Функцией потерь для различителя дают
где содержит различитель выходные вероятности для действительных изображений.
Чтобы измериться по шкале от 0 до 1, как хорошо генератор и различитель достигают их соответствующих целей, можно использовать концепцию счета.
Счет генератора является средним значением вероятностей, соответствующих различителю выход для сгенерированных изображений:
Счет различителя является средним значением вероятностей, соответствующих различителю выход для обоих действительные и сгенерированные изображения:
Счет обратно пропорционален потере, но эффективно содержит ту же информацию.
function [lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated) % Calculate the loss for the discriminator network. lossDiscriminator = -mean(log(probReal)) - mean(log(1-probGenerated)); % Calculate the loss for the generator network. lossGenerator = -mean(log(probGenerated)); end
preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив.
Перемасштабируйте изображения, чтобы быть в области значений [-1,1]
.
function X = preprocessMiniBatch(data) % Concatenate mini-batch X = cat(4,data{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,InputMin=0,InputMax=255); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
Рэдфорд, Алек, Люк Мец и Сумит Чинтэла. “Безнадзорное Представление, Учащееся с Глубокими Сверточными Порождающими Соперничающими Сетями”. Предварительно распечатайте, представленный 19 ноября 2015. http://arxiv.org/abs/1511.06434.
dlnetwork
| forward
| predict
| dlarray
| dlgradient
| dlfeval
| adamupdate
| minibatchqueue