Этот пример показывает, как обучить генеративную состязательную сеть Вассерштейна с градиентным штрафом (WGAN-GP) для генерации изображений.
Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные реальные данные.
GAN состоит из двух сетей, которые обучаются вместе:
Генератор - Учитывая вектор случайных значений (латентные входы) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.
Дискриминатор - Учитывая пакеты данных, содержащих наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как «реальные» или «сгенерированные».
Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:
Обучите генератор генерировать данные, которые «дурачат» дискриминатор.
Обучите дискриминатор различать реальные и сгенерированные данные.
Чтобы оптимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные данные. То есть цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как «действительные». Чтобы оптимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».
В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, и дискриминатору, который научился сильным представлениям функций, которые характерны для обучающих данных. Однако [2] утверждает, что расхождения, которые GAN обычно минимизируют, потенциально не непрерывны относительно параметров генератора, что приводит к сложности обучения, и представляет модель Wasserstein GAN (WGAN), которая использует потерю Вассерштейна, чтобы помочь стабилизировать обучение. Модель WGAN может все еще производить плохие выборки или не сходиться, потому что взаимодействия между ограничением веса и функцией стоимости могут привести к исчезновению или взрыванию градиентов. Для решения этих проблем [3] вводит градиентный штраф, который улучшает стабильность путем наказания градиентов с большими нормальными значениями за счет более длительного вычислительного времени. Этот тип модели известен как модель WGAN-GP.
В этом примере показано, как обучить модель WGAN-GP, которая может генерировать изображения с аналогичными характеристиками для набора обучающих данных изображений.
Загрузите и извлеките набор данных Flowers [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте изображение datastore, содержащее фотографии цветов.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true);
Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.
augmenter = imageDataAugmenter('RandXReflection',true); augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);
Задайте следующую сеть, которая классифицирует реальные и сгенерированные изображения 64 на 64.
Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Для вывода вероятностей в области значений [0,1] используйте сигмоидный слой.
Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода.
Для утечек слоев ReLU задайте шкалу 0,2.
Для последнего слоя свертки задайте один фильтр 4 на 4.
numFilters = 64; scale = 0.2; inputSize = [64 64 3]; filterSize = 5; layersD = [ imageInputLayer(inputSize,'Normalization','none','Name','in') convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1') leakyReluLayer(scale,'Name','lrelu1') convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2') layerNormalizationLayer('Name','bn2') leakyReluLayer(scale,'Name','lrelu2') convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3') layerNormalizationLayer('Name','bn3') leakyReluLayer(scale,'Name','lrelu3') convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4') layerNormalizationLayer('Name','bn4') leakyReluLayer(scale,'Name','lrelu4') convolution2dLayer(4,1,'Name','conv5') sigmoidLayer('Name','sigmoid')]; lgraphD = layerGraph(layersD);
Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork
объект.
dlnetD = dlnetwork(lgraphD);
Определите следующую сетевую архитектуру, которая производит изображения от массивов случайных значений 1 на 1 на 100:
Эта сеть:
Преобразовывает случайные векторы размера 100 к массивам 7 на 7 на 128, используя проект, и измените слой.
Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания и слоев ReLU.
Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.
Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и обрезкой выхода на каждом ребре.
Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений, и выходной размер предыдущего слоя.
В конце сети включают слой танха.
Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer
, присоединенный к этому примеру как вспомогательный файл. The projectAndReshape
слой увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.
filterSize = 5; numFilters = 64; numLatentInputs = 100; projectionSize = [4 4 512]; layersG = [ featureInputLayer(numLatentInputs,'Normalization','none','Name','in') projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj'); transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1') reluLayer('Name','relu1') transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2') reluLayer('Name','relu2') transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3') reluLayer('Name','relu3') transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4') tanhLayer('Name','tanh')]; lgraphG = layerGraph(layersG);
Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork
объект.
dlnetG = dlnetwork(lgraphG);
Создайте функции modelGradientsD
и modelGradientsG
перечисленные в разделе Model Gradients Function примера, которые вычисляют градиенты потерь дискриминатора и генератора относительно настраиваемых параметров сетей дискриминатора и генератора, соответственно.
Функция modelGradientsD
принимает за вход генератор и дискриминатор dlnetG
и dlnetD
мини-пакет входных данных dlX
, массив случайных значений dlZ
, и lambda
значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.
Функция modelGradientsG
принимает за вход генератор и дискриминатор dlnetG
и dlnetD
и массив случайных значений dlZ
, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе и потерь.
Чтобы обучить модель WGAN-GP, вы должны обучить дискриминатор для большего количества итераций, чем генератор. Другими словами, для каждой итерации генератора необходимо обучить дискриминатор для нескольких итераций.
Обучите с мини-пакетом размером 64 для 10 000 итераций генератора. Для больших наборов данных вам, возможно, потребуется обучить для дополнительных итераций.
miniBatchSize = 64; numIterationsG = 10000;
Для каждой итерации генератора обучите дискриминатор для 5 итераций.
numIterationsDPerG = 5;
Для потерь WGAN-GP укажите значение лямбды 10. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.
lambda = 10;
Задайте опции для оптимизации Адама:
Для сети дискриминатора задайте скорость обучения 0,0002.
Для сети генератора задайте скорость обучения 0,001.
Для обеих сетей задайте коэффициент градиентного распада 0 и квадратный коэффициент градиентного распада 0,9.
learnRateD = 2e-4; learnRateG = 1e-3; gradientDecayFactor = 0; squaredGradientDecayFactor = 0.9;
Отображать сгенерированные изображения валидации каждые 20 итераций генератора.
validationFrequency = 20;
Использование minibatchqueue
для обработки и управления мини-пакетами изображений. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1]
.
Отменить любые частичные мини-пакеты.
Форматируйте данные изображения с помощью меток размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный).
Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment'
опция minibatchqueue
является 'auto'
, minibatchqueue
преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox) (Parallel Computing Toolbox).
The minibatchqueue
объект по умолчанию преобразует данные в dlarray
объекты с базовым типом single
.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds,... 'MiniBatchSize',miniBatchSize,... 'PartialMiniBatch','discard',... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat','SSCB',... 'OutputEnvironment',executionEnvironment);
Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений, используя удерживаемый массив случайных значений для ввода в генератор, а также график счетов.
Инициализируйте параметры для Adam.
trailingAvgD = []; trailingAvgSqD = []; trailingAvgG = []; trailingAvgSqG = [];
Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью задержанного пакета фиксированных массивов случайных значений, подаваемых в генератор, и постройте графики счетов сети.
Создайте массив удерживаемых случайных значений.
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');
Преобразуйте данные в dlarray
Объекты и задайте метки размерности 'SSCB'
(пространственный, пространственный, канальный, пакетный).
dlZValidation = dlarray(ZValidation,'CB');
Для обучения графический процессор преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); end
Инициализируйте графики процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.
f = figure; f.Position(3) = 2*f.Position(3);
Создайте подграфик для сгенерированных изображений и сетевых счетов.
imageAxes = subplot(1,2,1); scoreAxes = subplot(1,2,2);
Инициализируйте анимированные линии для графика потерь.
C = colororder; lineLossD = animatedline(scoreAxes,'Color',C(1,:)); lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:)); legend('With Gradient Penanlty','Unregularized') xlabel("Generator Iteration") ylabel("Discriminator Loss") grid on
Обучите модель WGAN-GP путем закольцовывания по мини-пакетам данных.
Для numIterationsDPerG
итерации, обучите только дискриминатор. Для каждого мини-пакета:
Оцените градиенты модели дискриминатора с помощью dlfeval
и modelGradientsD
функция.
Обновите параметры сети дискриминатора, используя adamupdate
функция.
После обучения дискриминатора на numIterationsDPerG
итерации, обучите генератор на одном мини-пакете.
Оцените градиенты модели генератора с помощью dlfeval
и modelGradientsG
функция.
Обновите параметры сети генератора с помощью adamupdate
функция.
После обновления сети генератора:
Постройте график потерь двух сетей.
После каждого validationFrequency
итерации генератора, отображают пакет сгенерированных изображений для фиксированного входа задерживаемого генератора.
Пройдя через набор данных, перетащите мини-очередь пакетов.
Обучение может занять некоторое время, и может потребовать многих итераций для вывода хороших изображений.
iterationG = 0; iterationD = 0; start = tic; % Loop over mini-batches while iterationG < numIterationsG iterationG = iterationG + 1; % Train discriminator only for n = 1:numIterationsDPerG iterationD = iterationD + 1; % Reset and shuffle mini-batch queue when there is no more data. if ~hasdata(mbq) shuffle(mbq); end % Read mini-batch of data. dlX = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the discriminator model gradients using dlfeval and the % modelGradientsD function listed at the end of the example. [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda); % Update the discriminator network parameters. [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ... trailingAvgD, trailingAvgSqD, iterationD, ... learnRateD, gradientDecayFactor, squaredGradientDecayFactor); end % Generate latent inputs for the generator network. Convert to dlarray % and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the generator model gradients using dlfeval and the % modelGradientsG function listed at the end of the example. gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ); % Update the generator network parameters. [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ... trailingAvgG, trailingAvgSqG, iterationG, ... learnRateG, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency generator iterations, display batch of % generated images using the held-out generator input if mod(iterationG,validationFrequency) == 0 || iterationG == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetG,dlZValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(dlXGeneratedValidation)); I = rescale(I); % Display the images. subplot(1,2,1); image(imageAxes,I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the scores plot subplot(1,2,2) lossD = double(gather(extractdata(lossD))); lossDUnregularized = double(gather(extractdata(lossDUnregularized))); addpoints(lineLossD,iterationG,lossD); addpoints(lineLossDUnregularized,iterationG,lossDUnregularized); D = duration(0,0,toc(start),'Format','hh:mm:ss'); title( ... "Iteration: " + iterationG + ", " + ... "Elapsed: " + string(D)) drawnow end
Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным.
Чтобы сгенерировать новые изображения, используйте predict
функция на генераторе с dlarray
объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile
функциональность и изменение масштаба изображений с помощью rescale
функция.
Создайте dlarray
объект, содержащий пакет из 25 случайных векторов для ввода в сеть генератора.
ZNew = randn(numLatentInputs,25,'single'); dlZNew = dlarray(ZNew,'CB');
Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZNew = gpuArray(dlZNew); end
Сгенерируйте новые изображения с помощью predict
функция с генератором и входными данными.
dlXGeneratedNew = predict(dlnetG,dlZNew);
Отобразите изображения.
I = imtile(extractdata(dlXGeneratedNew)); I = rescale(I); figure image(I) axis off title("Generated Images")
Функция modelGradientsD
принимает за вход генератор и дискриминатор dlnetwork
объекты dlnetG
и dlnetD
мини-пакет входных данных dlX
, массив случайных значений dlZ
, и lambda
значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.
Заданное изображение , сгенерированное изображение , задайте для некоторых случайных .
Для модели WGAN-GP, учитывая значение лямбды , потеря дискриминатора определяется
где , , и обозначить выход дискриминатора для входов , , и , соответственно, и обозначает градиенты выхода по отношению к . Для мини-пакета данных используйте другое значение для каждого оберсвата и вычислить среднюю потерю.
Градиентный штраф повышает стабильность за счет наказания градиентов с большими нормальными значениями. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.
function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetD, dlX); % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossDUnregularized = mean(dlYPredGenerated - dlYPred); % Calculate and add the gradient penalty. epsilon = rand([1 1 1 size(dlX,4)],'like',dlX); dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated; dlYHat = forward(dlnetD, dlXHat); % Calculate gradients. To enable computing higher-order derivatives, set % 'EnableHigherDerivatives' to true. gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true); gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10); gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2); % Penalize loss. lossD = lossDUnregularized + gradientPenalty; % Calculate the gradients of the penalized loss with respect to the % learnable parameters. gradientsD = dlgradient(lossD, dlnetD.Learnables); end
Функция modelGradientsG
принимает за вход генератор и дискриминатор dlnetwork
объекты dlnetG
и dlnetD
и массив случайных значений dlZ
, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе.
Учитывая сгенерированное изображение , потери для сети генератора даются
где обозначает выход дискриминатора для сгенерированного изображения . Для мини-пакета сгенерированных изображений вычислите среднюю потерю.
function gradientsG = modelGradientsG(dlnetG, dlnetD, dlZ) % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossG = -mean(dlYPredGenerated); % Calculate the gradients of the loss with respect to the learnable % parameters. gradientsG = dlgradient(lossG, dlnetG.Learnables); end
The preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входного массива ячеек и соедините в числовой массив.
Перерассчитайте изображения, которые будут находиться в области значений [-1,1]
.
function X = preprocessMiniBatch(data) % Concatenate mini-batch X = cat(4,data{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,'InputMin',0,'InputMax',255); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
Арьовский, Мартин, Сомит Чинтала и Леон Ботту. «Wasserstein GAN». arXiv preprint arXiv:1701.07875 (2017).
Гюльраджани, Ишаан, Фарук Ахмед, Мартин Арьовский, Винсент Дюмулен и Аарон К. Курвиль. «Улучшенная подготовка Wasserstein GANs». В усовершенствованиях в системах нейронной обработки информации, стр. 5767-5777. 2017.
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| predict