В этом примере показано, как обучить Вассерштейна порождающая соперничающая сеть со штрафом градиента (WGAN-GP) генерировать изображения.
Порождающая соперничающая сеть (GAN) является типом нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как вход действительные данные.
ГАНЬ состоит из двух сетей, которые обучаются вместе:
Генератор —, Учитывая вектор из случайных значений (скрытые входные параметры), как введено, эта сеть генерирует данные с той же структурой как обучающие данные.
Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".
Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обоих:
Обучите генератор генерировать данные, которые "дурачат" различитель.
Обучите различитель различать действительные и сгенерированные данные.
Чтобы оптимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы оптимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Таким образом, цель различителя не состоит в том, чтобы "дурачить" генератор.
Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных. Однако [2] утверждает, что расхождения, которые обычно минимизируют GANs, потенциально не непрерывны относительно параметров генератора, ведя к учебной трудности, и вводит модель Wasserstein GAN (WGAN), которая использует утрату Вассерштейна, чтобы помочь стабилизировать обучение. Модель WGAN может все еще произвести плохие выборки или может не сходиться, потому что взаимодействия между ограничением веса и функцией стоимости могут привести к исчезновению или взрыву градиентов. Решать эти проблемы, [3] вводит штраф градиента, который улучшает устойчивость путем наложения штрафа на градиенты с большими значениями нормы за счет более длительного вычислительного времени. Этот тип модели известен как модель WGAN-GP.
В этом примере показано, как обучить модель WGAN-GP, которая может сгенерировать изображения с подобными характеристиками к набору обучающих данных изображений.
Загрузите и извлеките Цветочный набор данных [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте datastore изображений, содержащий фотографии цветов.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true);
Увеличьте данные, чтобы включать случайное горизонтальное зеркальное отражение и изменить размер изображений, чтобы иметь размер 64 64.
augmenter = imageDataAugmenter('RandXReflection',true); augimds = augmentedImageDatastore([64 64],imds,'DataAugmentation',augmenter);
Задайте следующую сеть, которая классифицирует действительный, и сгенерированный 64 64 отображает.
Создайте сеть, которая берет 64 64 3 изображениями и возвращает скалярный счет предсказания с помощью серии слоев свертки с нормализацией партии. и текучих слоев ReLU. Чтобы вывести вероятности в области значений [0,1], используйте сигмоидальный слой.
Для слоев свертки задайте фильтры 5 на 5 с растущим числом фильтров для каждого слоя. Также задайте шаг 2 и дополнение выхода.
Для текучих слоев ReLU задайте шкалу 0,2.
Для итогового слоя свертки задайте один фильтр 4 на 4.
numFilters = 64; scale = 0.2; inputSize = [64 64 3]; filterSize = 5; layersD = [ imageInputLayer(inputSize,'Normalization','none','Name','in') convolution2dLayer(filterSize,numFilters,'Stride',2,'Padding','same','Name','conv1') leakyReluLayer(scale,'Name','lrelu1') convolution2dLayer(filterSize,2*numFilters,'Stride',2,'Padding','same','Name','conv2') layerNormalizationLayer('Name','bn2') leakyReluLayer(scale,'Name','lrelu2') convolution2dLayer(filterSize,4*numFilters,'Stride',2,'Padding','same','Name','conv3') layerNormalizationLayer('Name','bn3') leakyReluLayer(scale,'Name','lrelu3') convolution2dLayer(filterSize,8*numFilters,'Stride',2,'Padding','same','Name','conv4') layerNormalizationLayer('Name','bn4') leakyReluLayer(scale,'Name','lrelu4') convolution2dLayer(4,1,'Name','conv5') sigmoidLayer('Name','sigmoid')]; lgraphD = layerGraph(layersD);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork
объект.
dlnetD = dlnetwork(lgraphD);
Задайте следующую сетевую архитектуру, которая генерирует изображения от 1 1 100 массивами случайных значений:
Эта сеть:
Преобразует случайные векторы из размера от 100 до 7 7 128 массивами с помощью проекта, и измените слой.
Увеличивает масштаб полученные массивы к 64 64 3 массивами с помощью серии транспонированных слоев свертки и слоев ReLU.
Задайте эту сетевую архитектуру как график слоев и задайте следующие сетевые свойства.
Для транспонированных слоев свертки задайте фильтры 5 на 5 с уменьшающимся количеством фильтров для каждого слоя, шага 2 и обрезки выхода на каждом ребре.
Поскольку финал транспонировал слой свертки, задайте три фильтра 5 на 5, соответствующие трем каналам RGB сгенерированных изображений и выходному размеру предыдущего слоя.
В конце сети включайте tanh слой.
Чтобы спроектировать и изменить шумовой вход, используйте пользовательский слой projectAndReshapeLayer
, присоединенный к этому примеру как вспомогательный файл. projectAndReshape
слой увеличивает масштаб вход с помощью полностью связанной операции и изменяет выход к заданному размеру.
filterSize = 5; numFilters = 64; numLatentInputs = 100; projectionSize = [4 4 512]; layersG = [ featureInputLayer(numLatentInputs,'Normalization','none','Name','in') projectAndReshapeLayer(projectionSize,numLatentInputs,'Name','proj'); transposedConv2dLayer(filterSize,4*numFilters,'Name','tconv1') reluLayer('Name','relu1') transposedConv2dLayer(filterSize,2*numFilters,'Stride',2,'Cropping','same','Name','tconv2') reluLayer('Name','relu2') transposedConv2dLayer(filterSize,numFilters,'Stride',2,'Cropping','same','Name','tconv3') reluLayer('Name','relu3') transposedConv2dLayer(filterSize,3,'Stride',2,'Cropping','same','Name','tconv4') tanhLayer('Name','tanh')]; lgraphG = layerGraph(layersG);
Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork
объект.
dlnetG = dlnetwork(lgraphG);
Создайте функции modelGradientsD
и modelGradientsG
перечисленный в разделе Model Gradients Function примера, которые вычисляют градиенты различителя и потери генератора относительно настраиваемых параметров различителя и сетей генератора, соответственно.
Функциональный modelGradientsD
берет в качестве входа генератор и различитель dlnetG
и dlnetD
, мини-пакет входных данных dlX
, массив случайных значений dlZ
, и lambda
значение, используемое для штрафа градиента, и, возвращает градиенты потери относительно настраиваемых параметров в различителе и потери.
Функциональный modelGradientsG
берет в качестве входа генератор и различитель dlnetG
и dlnetD
, и массив случайных значений dlZ
, и возвращает градиенты потери относительно настраиваемых параметров в генераторе и потери.
Чтобы обучить модель WGAN-GP, необходимо обучить различитель большему количеству итераций, чем генератор. Другими словами, для каждой итерации генератора, необходимо обучить различитель нескольким итерациям.
Обучайтесь с мини-пакетным размером 64 для 10 000 итераций генератора. Для больших наборов данных вы можете должны быть обучаться для большего количества итераций.
miniBatchSize = 64; numIterationsG = 10000;
Для каждой итерации генератора обучите различитель 5 итерациям.
numIterationsDPerG = 5;
За потерю WGAN-GP задайте значение lambda 10. Значение lambda управляет величиной штрафа градиента, добавленного к потере различителя.
lambda = 10;
Задайте опции для оптимизации Адама:
Для сети различителя задайте скорость обучения 0,0002.
Для сети генератора задайте скорость обучения 0,001.
Для обеих сетей задайте фактор затухания градиента 0 и фактор затухания градиента в квадрате 0,9.
learnRateD = 2e-4; learnRateG = 1e-3; gradientDecayFactor = 0; squaredGradientDecayFactor = 0.9;
Отобразитесь сгенерированная валидация отображает каждые 20 итераций генератора.
validationFrequency = 20;
Используйте minibatchqueue
обработать и управлять мини-пакетами изображений. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера), чтобы перемасштабировать изображения в области значений [-1,1]
.
Отбросьте любые частичные мини-пакеты.
Формат данные изображения с размерностью маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Обучайтесь на графическом процессоре, если вы доступны. Когда 'OutputEnvironment'
опция minibatchqueue
'auto'
, minibatchqueue
преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) (Parallel Computing Toolbox).
minibatchqueue
объект, по умолчанию, преобразует данные в dlarray
объекты с базовым типом single
.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds,... 'MiniBatchSize',miniBatchSize,... 'PartialMiniBatch','discard',... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat','SSCB',... 'OutputEnvironment',executionEnvironment);
Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации. Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого массива случайных значений, чтобы ввести в генератор, а также график баллов.
Инициализируйте параметры для Адама.
trailingAvgD = []; trailingAvgSqD = []; trailingAvgG = []; trailingAvgSqG = [];
Чтобы контролировать процесс обучения, отобразите пакет сгенерированных изображений с помощью протянутого пакета фиксированных массивов случайных значений, поданных в генератор, и постройте сетевые баллы.
Создайте массив протянутых случайных значений.
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,'single');
Преобразуйте данные в dlarray
объекты и указывают, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
dlZValidation = dlarray(ZValidation,'CB');
Для обучения графического процессора преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZValidation = gpuArray(dlZValidation); end
Инициализируйте графики процесса обучения. Создайте фигуру и измените размер его, чтобы иметь дважды ширину.
f = figure; f.Position(3) = 2*f.Position(3);
Создайте подграфик для сгенерированных изображений и сетевых баллов.
imageAxes = subplot(1,2,1); scoreAxes = subplot(1,2,2);
Инициализируйте анимированные линии для графика потерь.
C = colororder; lineLossD = animatedline(scoreAxes,'Color',C(1,:)); lineLossDUnregularized = animatedline(scoreAxes,'Color',C(2,:)); legend('With Gradient Penanlty','Unregularized') xlabel("Generator Iteration") ylabel("Discriminator Loss") grid on
Обучите модель WGAN-GP цикличным выполнением по мини-пакетам данных.
Для numIterationsDPerG
итерации, обучите различитель только. Для каждого мини-пакета:
Оцените градиенты модели различителя с помощью dlfeval
и modelGradientsD
функция.
Обновите параметры сети различителя с помощью adamupdate
функция.
После обучения различитель для numIterationsDPerG
итерации, обучите генератор на одном мини-пакете.
Оцените градиенты модели генератора с помощью dlfeval
и modelGradientsG
функция.
Обновите параметры сети генератора с помощью adamupdate
функция.
После обновления сети генератора:
Постройте потери этих двух сетей.
После каждого validationFrequency
итерации генератора, отображение пакет сгенерированных изображений для фиксированного протянул вход генератора.
После прохождения через набор данных переставьте мини-пакетную очередь.
Обучение может занять время, чтобы запуститься и может потребовать, чтобы много итераций вывели хорошие изображения.
iterationG = 0; iterationD = 0; start = tic; % Loop over mini-batches while iterationG < numIterationsG iterationG = iterationG + 1; % Train discriminator only for n = 1:numIterationsDPerG iterationD = iterationD + 1; % Reset and shuffle mini-batch queue when there is no more data. if ~hasdata(mbq) shuffle(mbq); end % Read mini-batch of data. dlX = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the discriminator model gradients using dlfeval and the % modelGradientsD function listed at the end of the example. [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda); % Update the discriminator network parameters. [dlnetD,trailingAvgD,trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ... trailingAvgD, trailingAvgSqD, iterationD, ... learnRateD, gradientDecayFactor, squaredGradientDecayFactor); end % Generate latent inputs for the generator network. Convert to dlarray % and specify the dimension labels 'CB' (channel, batch). Z = randn([numLatentInputs size(dlX,4)],'like',dlX); dlZ = dlarray(Z,'CB'); % Evaluate the generator model gradients using dlfeval and the % modelGradientsG function listed at the end of the example. gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ); % Update the generator network parameters. [dlnetG,trailingAvgG,trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ... trailingAvgG, trailingAvgSqG, iterationG, ... learnRateG, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency generator iterations, display batch of % generated images using the held-out generator input if mod(iterationG,validationFrequency) == 0 || iterationG == 1 % Generate images using the held-out generator input. dlXGeneratedValidation = predict(dlnetG,dlZValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(dlXGeneratedValidation)); I = rescale(I); % Display the images. subplot(1,2,1); image(imageAxes,I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the scores plot subplot(1,2,2) lossD = double(gather(extractdata(lossD))); lossDUnregularized = double(gather(extractdata(lossDUnregularized))); addpoints(lineLossD,iterationG,lossD); addpoints(lineLossDUnregularized,iterationG,lossDUnregularized); D = duration(0,0,toc(start),'Format','hh:mm:ss'); title( ... "Iteration: " + iterationG + ", " + ... "Elapsed: " + string(D)) drawnow end
Здесь, различитель изучил представление сильной черты, которое идентифицирует действительные изображения среди сгенерированных изображений. В свою очередь генератор изучил представление столь же сильной черты, которое позволяет ему генерировать изображения, похожие на обучающие данные.
Чтобы сгенерировать новые изображения, используйте predict
функция на генераторе с dlarray
объект, содержащий пакет случайных векторов. Чтобы отобразить изображения вместе, используйте imtile
функционируйте и перемасштабируйте изображения с помощью rescale
функция.
Создайте dlarray
объект, содержащий пакет 25 случайных векторов, чтобы ввести к сети генератора.
ZNew = randn(numLatentInputs,25,'single'); dlZNew = dlarray(ZNew,'CB');
Чтобы сгенерировать изображения с помощью графического процессора, также преобразуйте данные в gpuArray
объекты.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZNew = gpuArray(dlZNew); end
Сгенерируйте новые изображения с помощью predict
функция с генератором и входными данными.
dlXGeneratedNew = predict(dlnetG,dlZNew);
Отобразите изображения.
I = imtile(extractdata(dlXGeneratedNew)); I = rescale(I); figure image(I) axis off title("Generated Images")
Функциональный modelGradientsD
берет в качестве входа генератор и различитель dlnetwork
объекты dlnetG
и dlnetD
, мини-пакет входных данных dlX
, массив случайных значений dlZ
, и lambda
значение, используемое для штрафа градиента, и, возвращает градиенты потери относительно настраиваемых параметров в различителе и потери.
Учитывая изображение , сгенерированное изображение Define для некоторых случайных .
Для модели WGAN-GP, учитывая значение lambda , потерей различителя дают
где , , и обозначьте выход различителя для входных параметров , , и , соответственно, и обозначает градиенты выхода относительно . Для мини-пакета данных используйте различное значение для каждого obersvation и вычисляют среднюю потерю.
Штраф градиента улучшает устойчивость путем наложения штрафа на градиенты с большими значениями нормы. Значение lambda управляет величиной штрафа градиента, добавленного к потере различителя.
function [gradientsD, lossD, lossDUnregularized] = modelGradientsD(dlnetD, dlnetG, dlX, dlZ, lambda) % Calculate the predictions for real data with the discriminator network. dlYPred = forward(dlnetD, dlX); % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossDUnregularized = mean(dlYPredGenerated - dlYPred); % Calculate and add the gradient penalty. epsilon = rand([1 1 1 size(dlX,4)],'like',dlX); dlXHat = epsilon.*dlX + (1-epsilon).*dlXGenerated; dlYHat = forward(dlnetD, dlXHat); % Calculate gradients. To enable computing higher-order derivatives, set % 'EnableHigherDerivatives' to true. gradientsHat = dlgradient(sum(dlYHat),dlXHat,'EnableHigherDerivatives',true); gradientsHatNorm = sqrt(sum(gradientsHat.^2,1:3) + 1e-10); gradientPenalty = lambda.*mean((gradientsHatNorm - 1).^2); % Penalize loss. lossD = lossDUnregularized + gradientPenalty; % Calculate the gradients of the penalized loss with respect to the % learnable parameters. gradientsD = dlgradient(lossD, dlnetD.Learnables); end
Функциональный modelGradientsG
берет в качестве входа генератор и различитель dlnetwork
объекты dlnetG
и dlnetD
, и массив случайных значений dlZ
, и возвращает градиенты потери относительно настраиваемых параметров в генераторе.
Учитывая сгенерированное изображение , потерей для сети генератора дают
где обозначает выход различителя для сгенерированного изображения . Для мини-пакета сгенерированных изображений вычислите среднюю потерю.
function gradientsG = modelGradientsG(dlnetG, dlnetD, dlZ) % Calculate the predictions for generated data with the discriminator % network. dlXGenerated = forward(dlnetG,dlZ); dlYPredGenerated = forward(dlnetD, dlXGenerated); % Calculate the loss. lossG = -mean(dlYPredGenerated); % Calculate the gradients of the loss with respect to the learnable % parameters. gradientsG = dlgradient(lossG, dlnetG.Learnables); end
preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входного массива ячеек и конкатенируйте в числовой массив.
Перемасштабируйте изображения, чтобы быть в области значений [-1,1]
.
function X = preprocessMiniBatch(data) % Concatenate mini-batch X = cat(4,data{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,'InputMin',0,'InputMax',255); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
Arjovsky, Мартин, Сумит Чинтэла и Леон Ботту. "Вассерштейн GAN". arXiv предварительно распечатывают arXiv:1701.07875 (2017).
Gulrajani, Ishaan, Фэрук Ахмед, Мартин Арйовский, Винсент Думулин и Аарон К. Коервилл. "Улучшенное обучение Вассерштейна GANs". В Усовершенствованиях в нейронных системах обработки информации, стр 5767-5777. 2017.
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| forward
| predict