Этот пример показывает, как обучить условную генеративную состязательную сеть для генерации изображений.
Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные обучающие данные.
GAN состоит из двух сетей, которые обучаются вместе:
Генератор - учитывая вектор случайных значений как вход, эта сеть генерирует данные с той же структурой, что и обучающие данные.
Дискриминатор - Учитывая пакеты данных, содержащие наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как "real"
или "generated"
.
Условная генеративная состязательная сеть (CGAN) является типом GAN, который также использует преимущества меток в процессе обучения.
Генератор - учитывая метку и случайный массив как вход, эта сеть генерирует данные с той же структурой, что и наблюдения обучающих данных, соответствующие той же метке.
Дискриминатор - Учитывая пакеты маркированных данных, содержащих наблюдения как из обучающих данных, так и из сгенерированных данных генератора, эта сеть пытается классифицировать наблюдения как "real"
или "generated"
.
Чтобы обучить условную GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:
Обучите генератор генерировать данные, которые «дурачат» дискриминатор.
Обучите дискриминатор различать реальные и сгенерированные данные.
Чтобы максимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные маркированные данные. То есть цель генератора состоит в том, чтобы сгенерировать маркированные данные, которые дискриминатор классифицирует как "real"
.
Чтобы максимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных маркированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».
В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, которые соответствуют входным меткам, и дискриминатору, который выучил сильные представления функций, которые характерны для обучающих данных для каждой метки.
Загрузите и извлеките набор данных Flowers [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте изображение datastore, содержащее фотографии цветов.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
Просмотрите количество классов.
classes = categories(imds.Labels); numClasses = numel(classes)
numClasses = 5
Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.
augmenter = imageDataAugmenter('RandXReflection',true); augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);
Задайте следующую сеть с двумя входами, которая генерирует изображения, заданные случайными векторами размера 100 и соответствующими метками.
Эта сеть:
Преобразовывает случайные векторы размера 100 к 4 на 4 1 024 массивами.
Преобразует категориальные метки во внедряющие векторы и изменяет их форму в массив 4 на 4.
Конкатенирует получившиеся изображения из двух входов по размерности канала. Выход составляет 4 на 4 1 025 массивами.
Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания с нормализацией партии. и слоев ReLU.
Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.
Для категориальных входов используйте размерность вложения 50.
Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и "same"
обрезка выхода.
Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений.
В конце сети включают слой танха.
Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer
, присоединенный к этому примеру как вспомогательный файл. The projectAndReshapeLayer
объект увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.
Чтобы ввести метки в сеть, используйте featureInputLayer
и задайте одну функцию. Для встраивания и изменения формы входов меток используйте пользовательский слой embedAndReshapeLayer
, присоединенный к этому примеру как вспомогательный файл. The embedAndReshapeLayer
преобразует категориальную метку в одноканальное изображение заданного размера с помощью встраивания и полносвязной операции.
numLatentInputs = 100; embeddingDimension = 50; numFilters = 64; filterSize = 5; projectionSize = [4 4 1024]; layersGenerator = [ featureInputLayer(numLatentInputs,'Name','noise') projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj'); concatenationLayer(3,2,'Name','cat'); transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu3') transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4') tanhLayer('Name','tanh')]; lgraphGenerator = layerGraph(layersGenerator); layers = [ featureInputLayer(1,'Name','labels') embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'Name','emb')]; lgraphGenerator = addLayers(lgraphGenerator,layers); lgraphGenerator = connectLayers(lgraphGenerator,'emb','cat/in2');
Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork
объект.
dlnetGenerator = dlnetwork(lgraphGenerator)
dlnetGenerator = dlnetwork with properties: Layers: [16×1 nnet.cnn.layer.Layer] Connections: [15×2 table] Learnables: [19×3 table] State: [6×3 table] InputNames: {'noise' 'labels'} OutputNames: {'tanh'}
Задайте следующую сеть с двумя входами, которая классифицирует реальные и сгенерированные изображения 64 на 64 с учетом набора изображений и соответствующих меток.
Создайте сеть, которая берет в качестве входа изображения 64 на 64 на 1 и соответствующие метки и производит скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Добавьте шум к входу изображениям с помощью отсева.
Для выпадающего слоя задайте вероятность выпадения 0,75.
Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода на каждом ребре.
Для утечек слоев ReLU задайте шкалу 0,2.
Для последнего слоя задайте слой свертки с одним фильтром 4 на 4.
dropoutProb = 0.75; numFilters = 64; scale = 0.2; inputSize = [64 64 3]; filterSize = 5; layersDiscriminator = [ imageInputLayer(inputSize,'Normalization','none','Name','images') dropoutLayer(dropoutProb,'Name','dropout') concatenationLayer(3,2,'Name','cat') convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1') leakyReluLayer(scale,'Name','lrelu1') convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2') batchNormalizationLayer('Name','bn2') leakyReluLayer(scale,'Name','lrelu2') convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3') batchNormalizationLayer('Name','bn3') leakyReluLayer(scale,'Name','lrelu3') convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4') batchNormalizationLayer('Name','bn4') leakyReluLayer(scale,'Name','lrelu4') convolution2dLayer(4,1,'Name','conv5')]; lgraphDiscriminator = layerGraph(layersDiscriminator); layers = [ featureInputLayer(1,'Name','labels') embedAndReshapeLayer(inputSize(1:2),embeddingDimension,numClasses,'Name','emb')]; lgraphDiscriminator = addLayers(lgraphDiscriminator,layers); lgraphDiscriminator = connectLayers(lgraphDiscriminator,'emb','cat/in2');
Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork
объект.
dlnetDiscriminator = dlnetwork(lgraphDiscriminator)
dlnetDiscriminator = dlnetwork with properties: Layers: [17×1 nnet.cnn.layer.Layer] Connections: [16×2 table] Learnables: [19×3 table] State: [6×3 table] InputNames: {'images' 'labels'} OutputNames: {'conv5'}
Создайте функцию modelGradients
, перечисленный в разделе Model Gradients Function примера, который принимает за вход сети генератора и дискриминатора, мини-пакет входных данных и массив случайных значений и возвращает градиенты потерь относительно настраиваемых параметров в сетях и массива сгенерированных изображений.
Train с мини-партией размером 128 на 500 эпох.
numEpochs = 500; miniBatchSize = 128;
Задайте опции для оптимизации Adam. Для обеих сетей используйте:
A скорости обучения 0,0002
Коэффициент градиентного распада 0,5
Квадратный коэффициент распада градиента 0,999
learnRate = 0.0002; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Обновляйте графики процесса обучения каждые 100 итераций.
validationFrequency = 100;
Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не обучаться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом разверните метки части реальных изображений. Задайте коэффициент отражения 0,5.
flipFactor = 0.5;
Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью удерживаемого массива случайных значений для ввода в генератор и сетевые счета.
Использование minibatchqueue
обрабатывать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1]
.
Отбрасывайте любые частичные мини-пакеты с менее чем 128 наблюдениями.
Форматируйте данные изображения с помощью меток размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный).
Форматируйте данные метки с помощью меток размерности 'BC'
(пакет, канал).
Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment'
опция minibatchqueue
является 'auto'
, minibatchqueue
преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
The minibatchqueue
объект по умолчанию преобразует данные в dlarray
объекты с базовым типом single
.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds,... 'MiniBatchSize',miniBatchSize,... 'PartialMiniBatch','discard',... 'MiniBatchFcn', @preprocessData,... 'MiniBatchFormat',{'SSCB','BC'},... 'OutputEnvironment',executionEnvironment);
Инициализируйте параметры для оптимизатора Адама.
velocityDiscriminator = []; trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminator = []; trailingAvgSqDiscriminator = [];
Инициализируйте график процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.
f = figure; f.Position(3) = 2*f.Position(3);
Создайте подграфики сгенерированных изображений и графика счетов.
imageAxes = subplot(1,2,1); scoreAxes = subplot(1,2,2);
Инициализируйте анимированные линии для графика счетов.
lineScoreGenerator = animatedline(scoreAxes,'Color',[0 0.447 0.741]); lineScoreDiscriminator = animatedline(scoreAxes, 'Color', [0.85 0.325 0.098]);
Настройка внешнего вида графиков.
legend('Generator','Discriminator'); ylim([0 1]) xlabel("Iteration") ylabel("Score") grid on
Чтобы контролировать процесс обучения, создайте задержанный пакет из 25 случайных векторов и соответствующий набор меток с 1 по 5 (соответствующий классам), повторенных пять раз.
numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,'single');
TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));
Преобразуйте данные в dlarray
Объекты и задайте метки размерности 'CB'
(канал, пакет).
dlZValidation = dlarray(ZValidation,'CB'); dlTValidation = dlarray(TValidation,'CB');
Для обучения графический процессор преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); dlTValidation = gpuArray(dlTValidation); end
Обучите условную GAN. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных.
Для каждого мини-пакета:
Оцените градиенты модели с помощью dlfeval
и modelGradients
функция.
Обновляйте параметры сети с помощью adamupdate
функция.
Постройте график счетов двух сетей.
После каждого validationFrequency
итерации, отображают пакет сгенерированных изображений для фиксированного входа задержанного генератора.
Обучение может занять некоторое время.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Reset and shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX,dlT] = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels 'CB' (channel, batch). % If training on a GPU, then convert latent inputs to gpuArray. Z = randn(numLatentInputs,miniBatchSize,'single'); dlZ = dlarray(Z,'CB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZ = gpuArray(dlZ); end % Evaluate the model gradients and the generator state using % dlfeval and the modelGradients function listed at the end of the % example. [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ... dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor); dlnetGenerator.State = stateGenerator; % Update the discriminator network parameters. [dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ... adamupdate(dlnetDiscriminator, gradientsDiscriminator, ... trailingAvgDiscriminator, trailingAvgSqDiscriminator, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Update the generator network parameters. [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = ... adamupdate(dlnetGenerator, gradientsGenerator, ... trailingAvgGenerator, trailingAvgSqGenerator, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency iterations, display batch of generated images using the % held-out generator input. if mod(iteration,validationFrequency) == 0 || iteration == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetGenerator,dlZValidation,dlTValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(dlXGeneratedValidation), ... 'GridSize',[numValidationImagesPerClass numClasses]); I = rescale(I); % Display the images. subplot(1,2,1); image(imageAxes,I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the scores plot. subplot(1,2,2) addpoints(lineScoreGenerator,iteration,... double(gather(extractdata(scoreGenerator)))); addpoints(lineScoreDiscriminator,iteration,... double(gather(extractdata(scoreDiscriminator)))); % Update the title with training progress information. D = duration(0,0,toc(start),'Format','hh:mm:ss'); title(... "Epoch: " + epoch + ", " + ... "Iteration: " + iteration + ", " + ... "Elapsed: " + string(D)) drawnow end end
Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным. Каждый столбец соответствует одному классу.
График обучения показывает счета сетей генератора и дискриминатора. Дополнительные сведения о том, как интерпретировать счета сети, см. в разделах Мониторинг процесса обучения GAN и Идентификация распространенных типах отказа.
Чтобы сгенерировать новые изображения конкретного класса, используйте predict
функция на генераторе с dlarray
объект, содержащий пакет случайных векторов и массив меток, соответствующих желаемым классам. Преобразуйте данные в dlarray
Объекты и задайте метки размерности 'CB'
(канал, пакет). Для предсказания графический процессор преобразуйте данные в gpuArray
объекты. Чтобы отобразить изображения вместе, используйте imtile
функциональность и изменение масштаба изображений с помощью rescale
функция.
Создайте массив из 36 векторов случайных значений, соответствующих первому классу.
numObservationsNew = 36;
idxClass = 1;
Z = randn(numLatentInputs,numObservationsNew,'single');
T = repmat(single(idxClass),[1 numObservationsNew]);
Преобразуйте данные в dlarray
объекты с метками размерностей 'SSCB'
(пространственный, пространственный, каналы, пакет).
dlZ = dlarray(Z,'CB'); dlT = dlarray(T,'CB');
Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZ = gpuArray(dlZ); dlT = gpuArray(dlT); end
Сгенерируйте изображения с помощью predict
функция с сетью генератора.
dlXGenerated = predict(dlnetGenerator,dlZ,dlT);
Отобразите сгенерированные изображения на графике.
figure
I = imtile(extractdata(dlXGenerated));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))
Здесь сеть генератора генерирует изображения, обусловленные заданным классом.
Функция modelGradients
принимает за вход генератор и дискриминатор dlnetwork
объекты dlnetGenerator
и dlnetDiscriminator
мини-пакет входных данных dlX
, соответствующие метки dlT
и массив случайных значений dlZ
, и возвращает градиенты потерь относительно настраиваемых параметров в сетях, состоянии генератора и счетах сети.
Если дискриминатор учится слишком быстро различать реальные и сгенерированные изображения, то генератор может не обучаться. Чтобы лучше сбалансировать обучение дискриминатора и генератора, случайным образом разверните метки части реальных изображений.
function [gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] = ... modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlT, dlZ, flipFactor) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetDiscriminator, dlX, dlT); % Calculate the predictions for generated data with the discriminator network. [dlXGenerated,stateGenerator] = forward(dlnetGenerator, dlZ, dlT); dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated, dlT); % Calculate probabilities. probGenerated = sigmoid(dlYPredGenerated); probReal = sigmoid(dlYPred); % Calculate the generator and discriminator scores. scoreGenerator = mean(probGenerated); scoreDiscriminator = (mean(probReal) + mean(1-probGenerated)) / 2; % Flip labels. numObservations = size(dlYPred,4); idx = randperm(numObservations,floor(flipFactor * numObservations)); probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx); % Calculate the GAN loss. [lossGenerator, lossDiscriminator] = ganLoss(probReal, probGenerated); % For each network, calculate the gradients with respect to the loss. gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true); gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables); end
Цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как "real"
. Чтобы максимизировать вероятность того, что изображения с генератора классифицируются дискриминатором как вещественные, минимизируйте отрицательную функцию журнала правдоподобие.
Учитывая выход дискриминатора:
- вероятность того, что вход изображение принадлежит классу "real"
.
- вероятность того, что вход изображение принадлежит классу "generated"
.
Обратите внимание на операцию сигмоида происходит в modelGradients
функция. Функция потерь для генератора задается как
где содержит выходные вероятности дискриминатора для сгенерированных изображений.
Цель дискриминатора - не быть «обманутым» генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает действительное и сгенерированное изображения, минимизируйте сумму соответствующих отрицательных функций журнала функций правдоподобия. Функция потерь для дискриминатора определяется
где содержит выходные вероятности дискриминатора для действительных изображений.
function [lossGenerator, lossDiscriminator] = ganLoss(scoresReal,scoresGenerated) % Calculate losses for the discriminator network. lossGenerated = -mean(log(1 - scoresGenerated)); lossReal = -mean(log(scoresReal)); % Combine the losses for the discriminator network. lossDiscriminator = lossReal + lossGenerated; % Calculate the loss for the generator network. lossGenerator = -mean(log(scoresGenerated)); end
The preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките изображение и пометьте данные из входных массивов ячеек и соедините их в числовые массивы.
Перерассчитайте изображения, которые будут находиться в области значений [-1,1]
.
function [X,T] = preprocessData(XCell,TCell) % Extract image data from cell and concatenate X = cat(4,XCell{:}); % Extract label data from cell and concatenate T = cat(1,TCell{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,'InputMin',0,'InputMax',255); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| predict