Этот пример показывает, как обучить генеративную состязательную сеть Вассерштейна с градиентным штрафом (WGAN-GP) для генерации изображений.
Генеративная состязательная сеть (GAN) является типом нейронной сети для глубокого обучения, которая может генерировать данные с такими же характеристиками, как входные реальные данные.
GAN состоит из двух сетей, которые обучаются вместе:
Генератор - Учитывая вектор случайных значений (латентные входы) в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные.
Дискриминатор - Учитывая пакеты данных, содержащих наблюдения как от обучающих данных, так и от сгенерированных данных от генератора, эта сеть пытается классифицировать наблюдения как «реальные» или «сгенерированные».

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих:
Обучите генератор генерировать данные, которые «дурачат» дискриминатор.
Обучите дискриминатор различать реальные и сгенерированные данные.
Чтобы оптимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные данные. То есть цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как «действительные». Чтобы оптимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных данных. То есть цель дискриминатора состоит в том, чтобы генератор не «обманул».
В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, и дискриминатору, который научился сильным представлениям функций, которые характерны для обучающих данных. Однако [2] утверждает, что расхождения, которые GAN обычно минимизируют, потенциально не непрерывны относительно параметров генератора, что приводит к сложности обучения, и представляет модель Wasserstein GAN (WGAN), которая использует потерю Вассерштейна, чтобы помочь стабилизировать обучение. Модель WGAN может все еще производить плохие выборки или не сходиться, потому что взаимодействия между ограничением веса и функцией стоимости могут привести к исчезновению или взрыванию градиентов. Для решения этих проблем [3] вводит градиентный штраф, который улучшает стабильность путем наказания градиентов с большими нормальными значениями за счет более длительного вычислительного времени. Этот тип модели известен как модель WGAN-GP.
В этом примере показано, как обучить модель WGAN-GP, которая может генерировать изображения с аналогичными характеристиками для набора обучающих данных изображений.
Загрузите и извлеките набор данных Flowers [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте изображение datastore, содержащее фотографии цветов.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true);
Увеличьте данные, чтобы включить случайное горизонтальное отражение и измените размер изображений, чтобы иметь размер 64 на 64.
augmenter = imageDataAugmenter('RandXReflection',true); augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);
Задайте следующую сеть, которая классифицирует реальные и сгенерированные изображения 64 на 64.

Создайте сеть, которая берет изображения 64 на 64 на 3 и возвращает скалярный счет предсказания, используя серию слоев скручивания с нормализацией партии. и прохудившихся слоев ReLU. Для вывода вероятностей в области значений [0,1] используйте сигмоидный слой.
Для слоев свертки задайте фильтры 5 на 5 с увеличением количества фильтров для каждого слоя. Также задайте шаг 2 и заполнение выхода.
Для утечек слоев ReLU задайте шкалу 0,2.
Для последнего слоя свертки задайте один фильтр 4 на 4.
numFilters = 64;
scale = 0.2;
inputSize = [64 64 3];
filterSize = 5;
layersD = [
imageInputLayer(inputSize,'Normalization','none','Name','in')
convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1')
leakyReluLayer(scale,'Name','lrelu1')
convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2')
layerNormalizationLayer('Name','bn2')
leakyReluLayer(scale,'Name','lrelu2')
convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3')
layerNormalizationLayer('Name','bn3')
leakyReluLayer(scale,'Name','lrelu3')
convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4')
layerNormalizationLayer('Name','bn4')
leakyReluLayer(scale,'Name','lrelu4')
convolution2dLayer(4,1,'Name','conv5')
sigmoidLayer('Name','sigmoid')];
lgraphD = layerGraph(layersD);Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.
dlnetD = dlnetwork(lgraphD);
Определите следующую сетевую архитектуру, которая производит изображения от массивов случайных значений 1 на 1 на 100:

Эта сеть:
Преобразовывает случайные векторы размера 100 к массивам 7 на 7 на 128, используя проект, и измените слой.
Upscales получающиеся массивы к массивам 64 на 64 на 3, используя серию перемещенных слоев скручивания и слоев ReLU.
Определите эту сетевую архитектуру как график слоев и задайте следующие свойства сети.
Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, полосой 2 и обрезкой выхода на каждом ребре.
Для последнего транспонированного слоя свертки задайте три фильтра 5 на 5, соответствующих трем каналам RGB сгенерированных изображений, и выходной размер предыдущего слоя.
В конце сети включают слой танха.
Чтобы проецировать и изменить форму входного сигнала шума, используйте пользовательский слой projectAndReshapeLayer, присоединенный к этому примеру как вспомогательный файл. The projectAndReshape слой увеличивает масштаб входа с помощью операции с полным подключением и изменяет форму выхода на заданный размер.
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;
projectionSize = [4 4 512];
layersG = [
featureInputLayer(numLatentInputs,'Normalization','none','Name','in')
projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj');
transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1')
reluLayer('Name','relu1')
transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2')
reluLayer('Name','relu2')
transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3')
reluLayer('Name','relu3')
transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4')
tanhLayer('Name','tanh')];
lgraphG = layerGraph(layersG);Чтобы обучить сеть с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте график слоев в dlnetwork объект.
dlnetG = dlnetwork(lgraphG);
Создайте функции modelGradientsD и modelGradientsG перечисленные в разделе Model Gradients Function примера, которые вычисляют градиенты потерь дискриминатора и генератора относительно настраиваемых параметров сетей дискриминатора и генератора, соответственно.
Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetG и dlnetDмини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.
Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetG и dlnetDи массив случайных значений dlZ, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе и потерь.
Чтобы обучить модель WGAN-GP, вы должны обучить дискриминатор для большего количества итераций, чем генератор. Другими словами, для каждой итерации генератора необходимо обучить дискриминатор для нескольких итераций.
Обучите с мини-пакетом размером 64 для 10 000 итераций генератора. Для больших наборов данных вам, возможно, потребуется обучить для дополнительных итераций.
miniBatchSize = 64; numIterationsG = 10000;
Для каждой итерации генератора обучите дискриминатор для 5 итераций.
numIterationsDPerG = 5;
Для потерь WGAN-GP укажите значение лямбды 10. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.
lambda = 10;
Задайте опции для оптимизации Адама:
Для сети дискриминатора задайте скорость обучения 0,0002.
Для сети генератора задайте скорость обучения 0,001.
Для обеих сетей задайте коэффициент градиентного распада 0 и квадратный коэффициент градиентного распада 0,9.
learnRateD = 2e-4; learnRateG = 1e-3; gradientDecayFactor = 0; squaredGradientDecayFactor = 0.9;
Отображать сгенерированные изображения валидации каждые 20 итераций генератора.
validationFrequency = 20;
Использование minibatchqueue для обработки и управления мини-пакетами изображений. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы переформулировать изображения в области значений [-1,1].
Отменить любые частичные мини-пакеты.
Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).
Обучите на графическом процессоре, если он доступен. Когда 'OutputEnvironment' опция minibatchqueue является 'auto', minibatchqueue преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox) (Parallel Computing Toolbox).
The minibatchqueue объект по умолчанию преобразует данные в dlarray объекты с базовым типом single.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds,... 'MiniBatchSize',miniBatchSize,... 'PartialMiniBatch','discard',... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat','SSCB',... 'OutputEnvironment',executionEnvironment);
Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений, используя удерживаемый массив случайных значений для ввода в генератор, а также график счетов.
Инициализируйте параметры для Adam.
trailingAvgD = []; trailingAvgSqD = []; trailingAvgG = []; trailingAvgSqG = [];
Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью задержанного пакета фиксированных массивов случайных значений, подаваемых в генератор, и постройте графики счетов сети.
Создайте массив удерживаемых случайных значений.
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');Преобразуйте данные в dlarray Объекты и задайте метки размерности 'SSCB' (пространственный, пространственный, канальный, пакетный).
dlZValidation = dlarray(ZValidation,'CB'); Для обучения графический процессор преобразуйте данные в gpuArray объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); end
Инициализируйте графики процесса обучения. Создайте рисунок и измените ее размер, чтобы она имела вдвое большую ширину.
f = figure; f.Position(3) = 2*f.Position(3);
Создайте подграфик для сгенерированных изображений и сетевых счетов.
imageAxes = subplot(1,2,1); scoreAxes = subplot(1,2,2);
Инициализируйте анимированные линии для графика потерь.
C = colororder; lineLossD = animatedline(scoreAxes,'Color',C(1,:)); lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:)); legend('With Gradient Penanlty','Unregularized') xlabel("Generator Iteration") ylabel("Discriminator Loss") grid on
Обучите модель WGAN-GP путем закольцовывания по мини-пакетам данных.
Для numIterationsDPerG итерации, обучите только дискриминатор. Для каждого мини-пакета:
Оцените градиенты модели дискриминатора с помощью dlfeval и modelGradientsD функция.
Обновите параметры сети дискриминатора, используя adamupdate функция.
После обучения дискриминатора на numIterationsDPerG итерации, обучите генератор на одном мини-пакете.
Оцените градиенты модели генератора с помощью dlfeval и modelGradientsG функция.
Обновите параметры сети генератора с помощью adamupdate функция.
После обновления сети генератора:
Постройте график потерь двух сетей.
После каждого validationFrequency итерации генератора, отображают пакет сгенерированных изображений для фиксированного входа задерживаемого генератора.
Пройдя через набор данных, перетащите мини-очередь пакетов.
Обучение может занять некоторое время, и может потребовать многих итераций для вывода хороших изображений.
iterationG = 0; iterationD = 0; start = tic; % Loop over mini-batches while iterationG < numIterationsG iterationG = iterationG + 1; % Train discriminator only for n = 1:numIterationsDPerG iterationD = iterationD + 1; % Reset and shuffle mini-batch queue when there is no more data. if ~hasdata(mbq) shuffle(mbq); end % Read mini-batch of data. dlX = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the discriminator model gradients using dlfeval and the % modelGradientsD function listed at the end of the example. [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda); % Update the discriminator network parameters. [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ... trailingAvgD, trailingAvgSqD, iterationD, ... learnRateD, gradientDecayFactor, squaredGradientDecayFactor); end % Generate latent inputs for the generator network. Convert to dlarray % and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the generator model gradients using dlfeval and the % modelGradientsG function listed at the end of the example. gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ); % Update the generator network parameters. [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ... trailingAvgG, trailingAvgSqG, iterationG, ... learnRateG, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency generator iterations, display batch of % generated images using the held-out generator input if mod(iterationG,validationFrequency) == 0 || iterationG == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetG,dlZValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(dlXGeneratedValidation)); I = rescale(I); % Display the images. subplot(1,2,1); image(imageAxes,I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the scores plot subplot(1,2,2) lossD = double(gather(extractdata(lossD))); lossDUnregularized = double(gather(extractdata(lossDUnregularized))); addpoints(lineLossD,iterationG,lossD); addpoints(lineLossDUnregularized,iterationG,lossDUnregularized); D = duration(0,0,toc(start),'Format','hh:mm:ss'); title( ... "Iteration: " + iterationG + ", " + ... "Elapsed: " + string(D)) drawnow end

Здесь дискриминатор выучил сильное представление функций, которое идентифицирует реальные изображения среди сгенерированных изображений. В свою очередь, генератор выучил аналогичное сильное представление функций, которое позволяет ему генерировать изображения, подобные обучающим данным.
Чтобы сгенерировать новые изображения, используйте predict функция на генераторе с dlarray объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile функциональность и изменение масштаба изображений с помощью rescale функция.
Создайте dlarray объект, содержащий пакет из 25 случайных векторов для ввода в сеть генератора.
ZNew = randn(numLatentInputs,25,'single'); dlZNew = dlarray(ZNew,'CB');
Чтобы сгенерировать изображения с помощью графический процессор, также преобразуйте данные в gpuArray объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZNew = gpuArray(dlZNew); end
Сгенерируйте новые изображения с помощью predict функция с генератором и входными данными.
dlXGeneratedNew = predict(dlnetG,dlZNew);
Отобразите изображения.
I = imtile(extractdata(dlXGeneratedNew)); I = rescale(I); figure image(I) axis off title("Generated Images")

Функция modelGradientsD принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetDмини-пакет входных данных dlX, массив случайных значений dlZ, и lambda значение, используемое для градиентного штрафа, и возвращает градиенты потерь относительно настраиваемых параметров в дискриминаторе и потерь.
Заданное изображение , сгенерированное изображение , задайте для некоторых случайных .
Для модели WGAN-GP, учитывая значение лямбды , потеря дискриминатора определяется
где , , и обозначить выход дискриминатора для входов , , и , соответственно, и обозначает градиенты выхода по отношению к . Для мини-пакета данных используйте другое значение для каждого оберсвата и вычислить среднюю потерю.
Градиентный штраф повышает стабильность за счет наказания градиентов с большими нормальными значениями. Значение лямбды контролирует величину градиентного штрафа, добавленного к потере дискриминатора.
function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetD, dlX); % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossDUnregularized = mean(dlYPredGenerated - dlYPred); % Calculate and add the gradient penalty. epsilon = rand([1 1 1 size(dlX,4)],'like',dlX); dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated; dlYHat = forward(dlnetD, dlXHat); % Calculate gradients. To enable computing higher-order derivatives, set % 'EnableHigherDerivatives' to true. gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true); gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10); gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2); % Penalize loss. lossD = lossDUnregularized + gradientPenalty; % Calculate the gradients of the penalized loss with respect to the % learnable parameters. gradientsD = dlgradient(lossD, dlnetD.Learnables); end
Функция modelGradientsG принимает за вход генератор и дискриминатор dlnetwork объекты dlnetG и dlnetDи массив случайных значений dlZ, и возвращает градиенты потерь относительно настраиваемых параметров в генераторе.
Учитывая сгенерированное изображение , потери для сети генератора даются
где обозначает выход дискриминатора для сгенерированного изображения . Для мини-пакета сгенерированных изображений вычислите среднюю потерю.
function gradientsG = modelGradientsG(dlnetG, dlnetD, dlZ) % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossG = -mean(dlYPredGenerated); % Calculate the gradients of the loss with respect to the learnable % parameters. gradientsG = dlgradient(lossG, dlnetG.Learnables); end
The preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входного массива ячеек и соедините в числовой массив.
Перерассчитайте изображения, которые будут находиться в области значений [-1,1].
function X = preprocessMiniBatch(data) % Concatenate mini-batch X = cat(4,data{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,'InputMin',0,'InputMax',255); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
Арьовский, Мартин, Сомит Чинтала и Леон Ботту. «Wasserstein GAN». arXiv preprint arXiv:1701.07875 (2017).
Гюльраджани, Ишаан, Фарук Ахмед, Мартин Арьовский, Винсент Дюмулен и Аарон К. Курвиль. «Улучшенная подготовка Wasserstein GANs». В усовершенствованиях в системах нейронной обработки информации, стр. 5767-5777. 2017.
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | forward | predict