Используйте Experiment Manager, чтобы обучить порождающие соперничающие сети (GANs)

В этом примере показано, как создать пользовательский учебный эксперимент, чтобы обучить порождающую соперничающую сеть (GAN), которая генерирует изображения цветов. Для пользовательского учебного эксперимента вы явным образом задаете метод обучения, используемый Experiment Manager. В этом примере вы реализуете пользовательский учебный цикл, чтобы обучить GAN, тип нейронной сети для глубокого обучения, которая может сгенерировать данные с подобными характеристиками как вход действительные данные. ГАНЬ состоит из двух сетей, которые обучаются вместе:

  • Генератор —, Учитывая вектор из случайных значений (скрытые входные параметры), как введено, эта сеть генерирует данные с той же структурой как обучающие данные.

  • Различитель — Данный пакеты данных, содержащих наблюдения от обоих обучающие данные и сгенерированные данные из генератора, эта сеть пытается классифицировать наблюдения как "действительные" или "сгенерированные".

Чтобы обучить GAN, обучите обе сети одновременно, чтобы максимизировать эффективность обеих сетей:

  • Обучите генератор генерировать данные, которые "дурачат" различитель. Чтобы оптимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Другими словами, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные".

  • Обучите различитель различать действительные и сгенерированные данные. Чтобы оптимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Другими словами, цель различителя не состоит в том, чтобы "дурачить" генератор.

Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных. Для получения дополнительной информации смотрите, Обучают Порождающую соперничающую сеть (GAN).

Открытый эксперимент

Во-первых, откройте пример. Experiment Manager загружает проект с предварительно сконфигурированным экспериментом, который можно смотреть и запустить. Чтобы открыть эксперимент, в панели Браузера Эксперимента, дважды кликают имя эксперимента (ImageGenerationExperiment).

Пользовательские учебные эксперименты состоят из описания, таблицы гиперпараметров и учебной функции. Для получения дополнительной информации смотрите, Конфигурируют Пользовательский Учебный Эксперимент.

Поле Description содержит текстовое описание эксперимента. В данном примере описание:

Train a generative adversarial network (GAN) to generate images of flowers.
Use hyperparameters to specify:
* the probability of the dropout layer in the discriminator network
* the fraction of real labels to flip while training the discriminator network

Раздел Hyperparameters задает стратегию (Exhaustive Sweep) и гиперзначения параметров, чтобы использовать для эксперимента. Когда вы запускаете эксперимент, Experiment Manager обучает сеть с помощью каждой комбинации гиперзначений параметров, заданных в гипертаблице параметров. Этот пример использует два гиперпараметра:

  • dropoutProb устанавливает вероятность слоя уволенного в сети различителя. По умолчанию значения для этого гиперпараметра заданы как [0.25 0.5 0.75].

  • flipFactor устанавливает часть действительных меток инвертировать, когда вы обучаете сеть различителя. Эксперимент использует этот гиперпараметр, чтобы добавить шум в действительные данные и лучший баланс приобретение знаний о различителе и генераторе. В противном случае, если различитель учится различать между действительными и сгенерированными изображениями слишком быстро, то генератор может не обучаться. Значения для этого гиперпараметра заданы как [0.1 0.3 0.5].

Учебная Функция задает обучающие данные, сетевую архитектуру, опции обучения и метод обучения, используемый экспериментом. Вход к учебной функции является структурой с полями от гипертаблицы параметров и experiments.Monitor возразите, что можно использовать, чтобы отследить прогресс обучения, значения записи метрик, используемых обучением, и произвести учебные графики. Учебная функция возвращает структуру, которая содержит обученную сеть генератора, обученную сеть различителя и среду выполнения, используемую для обучения. Experiment Manager сохраняет этот выход, таким образом, можно экспортировать его в рабочее пространство MATLAB, когда обучение завершено. Учебная функция имеет шесть разделов.

  • Инициализируйте Выход, устанавливает начальное значение сетей к пустым массивам, чтобы указать, что обучение не запустилось. Эксперимент устанавливает среду выполнения на "auto", таким образом, это обучает сети на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

output.generator = [];
output.discriminator = [];
output.executionEnvironment = "auto";
  • Обучающие данные загрузки задают обучающие данные для эксперимента как imageDatastore объект. Эксперимент использует Цветочный набор данных, который содержит 3 670 изображений цветов и составляет приблизительно 218 Мбайт. Для получения дополнительной информации об этом наборе данных смотрите Наборы Данных изображения.

monitor.Status = "Loading Data";
url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");
imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    websave(filename,url);
    untar(filename,downloadFolder)
end
datasetFolder = fullfile(imageFolder);
imdsTrain = imageDatastore(datasetFolder, ...
    IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimdsTrain = augmentedImageDatastore([64 64],imdsTrain, ...
    DataAugmentation=augmenter);
  • Задайте Сеть Генератора, задает архитектуру для сети генератора как график слоев, который генерирует изображения от 1 1 100 массивами случайных значений. Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, учебная функция преобразует график слоев в dlnetwork объект.

monitor.Status = "Creating Generator";
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;
projectionSize = [4 4 512];
layersGenerator = [
    featureInputLayer(numLatentInputs,Name="in")
    projectAndReshapeLayer(projectionSize,numLatentInputs,Name="proj")
    transposedConv2dLayer(filterSize,4*numFilters,Name="tconv1")
    batchNormalizationLayer(Name="bnorm1")
    reluLayer(Name="relu1")
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same",Name="tconv2")
    batchNormalizationLayer(Name="bnorm2")
    reluLayer(Name="relu2")
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same",Name="tconv3")
    batchNormalizationLayer(Name="bnorm3")
    reluLayer(Name="relu3")
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same",Name="tconv4")
    tanhLayer(Name="tanh")];
lgraphGenerator = layerGraph(layersGenerator);
output.generator = dlnetwork(lgraphGenerator);
  • Задайте Сеть Различителя, задает архитектуру для сети различителя как график слоев, который классифицирует действительный и сгенерировал 64 64 3 изображениями. Слой уволенного использует вероятность уволенного, заданную в таблице гиперпараметра. Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, учебная функция преобразует график слоев в dlnetwork объект.

monitor.Status = "Creating Discriminator";
filterSize = 5;
numFilters = 64;
inputSize = [64 64 3];
dropoutProb = params.dropoutProb;
scale = 0.2;
layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none",Name="in")
    dropoutLayer(dropoutProb,Name="dropout")
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same",Name="conv1")
    leakyReluLayer(scale,Name="lrelu1")
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same",Name="conv2")
    batchNormalizationLayer(Name="bn2")
    leakyReluLayer(scale,Name="lrelu2")
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same",Name="conv3")
    batchNormalizationLayer(Name="bn3")
    leakyReluLayer(scale,Name="lrelu3")
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same",Name="conv4")
    batchNormalizationLayer(Name="bn4")
    leakyReluLayer(scale,Name="lrelu4")
    convolution2dLayer(4,1,Name="conv5")];
lgraphDiscriminator = layerGraph(layersDiscriminator);
output.discriminator = dlnetwork(lgraphDiscriminator);
  • Укажите, что Опции обучения задают опции обучения, используемые экспериментом. В этом примере Experiment Manager обучает сети с мини-пакетным размером 128 в течение 50 эпох с помощью начальной скорости обучения 0,0002, фактора затухания градиента 0,5 и фактора затухания градиента в квадрате 0,999.

numEpochs = 50;
miniBatchSize = 128;
initialLearnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];
flipFactor = params.flipFactor;
  • Обучайтесь Модель задает пользовательский учебный цикл, используемый экспериментом. Пользовательский учебный цикл использует minibatchqueue обработать и управлять мини-пакетами изображений. Для каждого мини-пакета, minibatchqueue объект перемасштабирует изображения в области значений [-1,1], отбрасывает любые частичные мини-пакеты меньше чем с 128 наблюдениями и форматирует данные изображения с размерностью, маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. В течение каждой эпохи пользовательский учебный цикл переставляет datastore и циклы по мини-пакетам данных. Если вы обучаетесь на графическом процессоре, данные преобразованы в gpuArray (Parallel Computing Toolbox) объекты. Затем учебная функция оценивает градиенты модели и обновляет различитель и параметры сети генератора. После каждой итерации пользовательского учебного цикла учебная функция сохраняет обучивший нейронные сети и обновляет процесс обучения.

monitor.Metrics = ["scoreGenerator","scoreDiscriminator","scoreCombined"];
monitor.XLabel = "Iteration";
groupSubPlot(monitor,"Combined Score","scoreCombined");
groupSubPlot(monitor,"Generator and Discriminator Scores", ...
    ["scoreGenerator","scoreDiscriminator"]);
monitor.Status = "Training";
augimdsTrain.MiniBatchSize = miniBatchSize;
mbq = minibatchqueue(augimdsTrain,...
    MiniBatchSize=miniBatchSize,...
    PartialMiniBatch="discard",...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat="SSCB",...
    OutputEnvironment=output.executionEnvironment);
iteration = 0;
for epoch = 1:numEpochs
    shuffle(mbq);
    while hasdata(mbq)
        iteration = iteration + 1;
        dlX = next(mbq);
        Z = randn(numLatentInputs,miniBatchSize,"single");
        dlZ = dlarray(Z,"CB");
        if (output.executionEnvironment == "auto" && canUseGPU) || ...
                output.executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        [gradientsGenerator,gradientsDiscriminator,stateGenerator,scoreGenerator,scoreDiscriminator] = ...
            dlfeval(@modelGradients,output.generator,output.discriminator,dlX,dlZ,flipFactor);
        output.generator.State = stateGenerator;
        [output.discriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(output.discriminator,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            initialLearnRate,gradientDecayFactor,squaredGradientDecayFactor);
        [output.generator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(output.generator,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            initialLearnRate,gradientDecayFactor,squaredGradientDecayFactor);
        scoreGeneratorValue = ...
            double(gather(extractdata(scoreGenerator)));
        scoreDiscriminatorValue = ...
            double(gather(extractdata(scoreDiscriminator)));
        scoreCombinedValue = 1-2*max(abs(scoreDiscriminatorValue-0.5),abs(scoreGeneratorValue-0.5));
        recordMetrics(monitor,iteration, ...
            scoreGenerator=scoreGeneratorValue, ...
            scoreDiscriminator=scoreDiscriminatorValue, ...
            scoreCombined=scoreCombinedValue);
        if monitor.Stop || isnan(scoreGeneratorValue) || isnan(scoreDiscriminatorValue)
            return;
        end
    end
    monitor.Progress = (epoch/numEpochs)*100;
end

Учебный GANs может быть сложной задачей, потому что генератор и сети различителя конкурируют друг против друга во время обучения. Если одна сеть учится слишком быстро, то другая сеть может не учиться. Чтобы помочь вам диагностировать проблемы и контролировать как хорошо, генератор и различитель достигают их соответствующих целей, этот эксперимент отображает пару баллов в учебном графике. Счет генератора scoreGenerator измеряет вероятность, что различитель может правильно отличить сгенерированные изображения. Счет различителя scoreDiscriminator измеряет вероятность, что различитель может правильно отличить все входные изображения, приняв, что количества действительных и сгенерированных изображений, переданных различителю, равны. В идеальном случае оба баллов 0.5. Баллы, которые слишком близки к нулю или можно указать, что одна сеть доминирует над другим. Смотрите Монитор Процесс обучения GAN и Идентифицируйте Общие Типы отказа.

Чтобы помочь вам решить, какое испытание приводит к лучшим результатам, этот эксперимент комбинирует счет генератора и баллы различителя в одно числовое значение, scoreCombined. Эта метрика использует L-∞ норма, чтобы определить, как близко эти две сети из идеального сценария. Это принимает значение того, если и сетевые баллы равняются 0.5, и нуль, если одни из сетевых баллов равняются нулю или один.

Чтобы смотреть учебную функцию, под Учебной Функцией, нажимают Edit. Учебная функция открывается в Редакторе MATLAB®. Кроме того, код для учебной функции появляется в Приложении 1 в конце этого примера.

Запустите эксперимент

Когда вы запускаете эксперимент, Experiment Manager обучает сеть, заданную учебной функцией многократно. Каждое испытание использует различную комбинацию гиперзначений параметров. По умолчанию Experiment Manager запускает одно испытание за один раз. Если у вас есть Parallel Computing Toolbox, можно запустить несколько испытаний одновременно. Для лучших результатов, прежде чем вы запустите свой эксперимент, начинают параллельный пул со стольких же рабочих сколько графические процессоры. Для получения дополнительной информации смотрите Использование Experiment Manager, чтобы Обучить нейронные сети параллельно.

  • Чтобы запустить один суд над экспериментом за один раз, на панели инструментов Experiment Manager, нажимают Run.

  • Чтобы запустить несколько испытаний одновременно, нажмите Use Parallel и затем Запуск. Если нет никакого текущего параллельного пула, Experiment Manager запускает тот с помощью кластерного профиля по умолчанию. Experiment Manager затем выполняет несколько одновременных испытаний, в зависимости от количества параллельных доступных рабочих.

Таблица результатов показывает учебную потерю и точность валидации для каждого испытания.

В то время как эксперимент запускается, нажмите Training Plot, чтобы отобразить учебный график и отследить прогресс каждого испытания.

Оцените результаты

Чтобы найти лучший результат для вашего эксперимента, отсортируйте таблицу результатов с помощью объединенного счета.

  1. Укажите на scoreCombined столбец.

  2. Кликните по треугольному значку.

  3. Выберите Sort в порядке убывания.

Испытание с самым высоким объединенным счетом появляется наверху таблицы результатов.

Оцените качество GAN путем генерации и осмотра изображений, произведенных обученным генератором.

  1. Выберите испытание с самым высоким объединенным счетом.

  2. На панели инструментов Experiment Manager нажмите Export.

  3. В диалоговом окне введите имя переменной рабочей области для экспортируемого учебного выхода. Именем по умолчанию является trainingOutput.

  4. Протестируйте обученную сеть генератора путем вызова generateTestImages функция, которая перечислена в Приложении 3 в конце этого примера. Используйте экспортируемый учебный выход в качестве входа к функции. Например, в командном окне MATLAB, введите:

generateTestImages(trainingOutput)

Функция создает пакет 25 случайных векторов, чтобы ввести к сети генератора и отображает получившиеся изображения.

Используя объединенный счет, чтобы отсортировать ваши результаты не может идентифицировать лучшее испытание во всех случаях. Для лучших результатов повторите этот процесс для каждого испытания с высоким объединенным счетом, визуально проверяя, что генератор производит множество изображений без многих копий. Если изображения имеют мало разнообразия, и некоторые из них почти идентичны, то ваш генератор, вероятно, затронут коллапсом режима. Для получения дополнительной информации смотрите Коллапс Режима.

Чтобы записать наблюдения о результатах вашего эксперимента, добавьте аннотацию.

  1. В таблице результатов щелкните правой кнопкой по scoreCombined ячейке для лучшего испытания.

  2. Выберите Add Annotation.

  3. В панели Аннотаций введите свои наблюдения в текстовое поле.

Для получения дополнительной информации смотрите сортировку, Фильтр, и Аннотируйте Результаты Эксперимента.

Повторно выполните эксперимент

После того, как вы идентифицируете комбинацию гиперпараметров, которая генерирует лучшие изображения, запустите эксперимент во второй раз, чтобы обучить сеть в течение более длительного промежутка времени.

  1. Возвратитесь к панели определения эксперимента.

  2. В гипертаблице параметров введите гиперзначения параметров от своего лучшего испытания. Например, чтобы использовать значения от испытания 3, измените значение dropoutProb к 0.75 и flipFactor к 0.1.

  3. Откройте учебную функцию и задайте более длительное учебное время. Под Задают Опции обучения, изменяют значение numEpochs к 500.

  4. Запустите эксперимент с помощью новых гиперзначений параметров и учебной функции. Experiment Manager запускает одно испытание. Обучение сопровождает в 10 раз дольше, чем предыдущие испытания.

  5. Когда эксперимент закончится, экспортируйте учебный выход и запустите generateTestImages функционируйте, чтобы протестировать новую сеть генератора. Как прежде, визуально проверяйте, что генератор производит множество изображений без многих копий.

Закройте эксперимент

В панели Браузера Эксперимента щелкните правой кнопкой по имени проекта и выберите Close Project. Experiment Manager закрывает все эксперименты и результаты, содержавшиеся в проекте.

Приложение 1: учебная функция

Эта функция задает обучающие данные, сетевую архитектуру, опции обучения и метод обучения, используемый экспериментом.

Входной параметр

  • params структура с полями от гипертаблицы параметров Experiment Manager.

  • monitor experiments.Monitor возразите, что можно использовать, чтобы отследить прогресс обучения, полей информации об обновлении в таблице результатов, значениях записи метрик, используемых обучением, и произвести учебные графики.

Вывод

  • output структура, которая содержит обученную сеть генератора, обученную сеть различителя и среду выполнения, используемую для обучения. Experiment Manager сохраняет этот выход, таким образом, можно экспортировать его в рабочее пространство MATLAB, когда обучение завершено.

function output = ImageGenerationExperiment_training1(params,monitor)

output.generator = [];
output.discriminator = [];
output.executionEnvironment = "auto";

monitor.Status = "Loading Data";

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    websave(filename,url);
    untar(filename,downloadFolder)
end

datasetFolder = fullfile(imageFolder);
imdsTrain = imageDatastore(datasetFolder, ...
    IncludeSubfolders=true);

augmenter = imageDataAugmenter(RandXReflection=true);
augimdsTrain = augmentedImageDatastore([64 64],imdsTrain, ...
    DataAugmentation=augmenter);

monitor.Status = "Creating Generator";

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;
projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs,Name="in")
    projectAndReshapeLayer(projectionSize,numLatentInputs,Name="proj")
    transposedConv2dLayer(filterSize,4*numFilters,Name="tconv1")
    batchNormalizationLayer(Name="bnorm1")
    reluLayer(Name="relu1")
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same",Name="tconv2")
    batchNormalizationLayer(Name="bnorm2")
    reluLayer(Name="relu2")
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same",Name="tconv3")
    batchNormalizationLayer(Name="bnorm3")
    reluLayer(Name="relu3")
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same",Name="tconv4")
    tanhLayer(Name="tanh")];

lgraphGenerator = layerGraph(layersGenerator);
output.generator = dlnetwork(lgraphGenerator);

monitor.Status = "Creating Discriminator";

filterSize = 5;
numFilters = 64;
inputSize = [64 64 3];
dropoutProb = params.dropoutProb;
scale = 0.2;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none",Name="in")
    dropoutLayer(dropoutProb,Name="dropout")
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same",Name="conv1")
    leakyReluLayer(scale,Name="lrelu1")
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same",Name="conv2")
    batchNormalizationLayer(Name="bn2")
    leakyReluLayer(scale,Name="lrelu2")
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same",Name="conv3")
    batchNormalizationLayer(Name="bn3")
    leakyReluLayer(scale,Name="lrelu3")
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same",Name="conv4")
    batchNormalizationLayer(Name="bn4")
    leakyReluLayer(scale,Name="lrelu4")
    convolution2dLayer(4,1,Name="conv5")];

lgraphDiscriminator = layerGraph(layersDiscriminator);
output.discriminator = dlnetwork(lgraphDiscriminator);

numEpochs = 50;
miniBatchSize = 128;
initialLearnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];
flipFactor = params.flipFactor;

monitor.Metrics = ["scoreGenerator","scoreDiscriminator","scoreCombined"];
monitor.XLabel = "Iteration";
groupSubPlot(monitor,"Combined Score","scoreCombined");
groupSubPlot(monitor,"Generator and Discriminator Scores", ...
    ["scoreGenerator","scoreDiscriminator"]);
monitor.Status = "Training";

augimdsTrain.MiniBatchSize = miniBatchSize;
mbq = minibatchqueue(augimdsTrain,...
    MiniBatchSize=miniBatchSize,...
    PartialMiniBatch="discard",...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat="SSCB",...
    OutputEnvironment=output.executionEnvironment);

iteration = 0;
for epoch = 1:numEpochs
    shuffle(mbq);
    while hasdata(mbq)
        iteration = iteration + 1;
        dlX = next(mbq);
        
        Z = randn(numLatentInputs,miniBatchSize,"single");
        dlZ = dlarray(Z,"CB");
        
        if (output.executionEnvironment == "auto" && canUseGPU) || ...
                output.executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end
        
        [gradientsGenerator,gradientsDiscriminator,stateGenerator,scoreGenerator,scoreDiscriminator] = ...
            dlfeval(@modelGradients,output.generator,output.discriminator,dlX,dlZ,flipFactor);
        output.generator.State = stateGenerator;
        
        [output.discriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(output.discriminator,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            initialLearnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        [output.generator,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(output.generator,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            initialLearnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        scoreGeneratorValue = ...
            double(gather(extractdata(scoreGenerator)));
        scoreDiscriminatorValue = ...
            double(gather(extractdata(scoreDiscriminator)));
        scoreCombinedValue = 1-2*max(abs(scoreDiscriminatorValue-0.5),abs(scoreGeneratorValue-0.5));
        
        recordMetrics(monitor,iteration, ...
            scoreGenerator=scoreGeneratorValue, ...
            scoreDiscriminator=scoreDiscriminatorValue, ...
            scoreCombined=scoreCombinedValue);
        
        if monitor.Stop || isnan(scoreGeneratorValue) || isnan(scoreDiscriminatorValue)
            return;
        end
    end
    monitor.Progress = (epoch/numEpochs)*100;
end
end

Приложение 2: пользовательские учебные функции помощника

Эта функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив.

  2. Перемасштабируйте изображения, чтобы быть в области значений [-1,1].

function X = preprocessMiniBatch(data)
    X = cat(4,data{:});
    X = rescale(X,-1,1,InputMin=0,InputMax=255);
end

Эта функция берет в качестве входа генератор и различитель dlnetwork объекты (dlnetGenerator и dlnetDiscriminator), мини-пакет входных данных (dlX), массив случайных значений (dlZ), и процент действительных меток, чтобы инвертировать (flipFactor), и возвращает градиенты потери относительно настраиваемых параметров в сетях, состоянии генератора и множестве этих двух сетей. Поскольку различитель выход не находится в области значений [0,1], modelGradients применяет sigmoid функционируйте, чтобы преобразовать это значение в вероятность.

function [gradientsGenerator,gradientsDiscriminator,stateGenerator,scoreGenerator,scoreDiscriminator] = ...
    modelGradients(dlnetGenerator,dlnetDiscriminator,dlX,dlZ,flipFactor)
    dlYPred = forward(dlnetDiscriminator,dlX);
    [dlXGenerated,stateGenerator] = forward(dlnetGenerator,dlZ);
    dlYPredGenerated = forward(dlnetDiscriminator,dlXGenerated);
    probGenerated = sigmoid(dlYPredGenerated);
    probReal = sigmoid(dlYPred);
    scoreDiscriminator = ((mean(probReal)+mean(1-probGenerated))/2);
    scoreGenerator = mean(probGenerated);
    numObservations = size(probReal,4);
    idx = randperm(numObservations,floor(flipFactor*numObservations));
    probReal(:,:,:,idx) = 1-probReal(:,:,:,idx);
    [lossGenerator,lossDiscriminator] = GANLoss(probReal,probGenerated);
    gradientsGenerator = dlgradient(lossGenerator, ...
        dlnetGenerator.Learnables,RetainData=true);
    gradientsDiscriminator = dlgradient(lossDiscriminator, ...
        dlnetDiscriminator.Learnables);
end

Эта функция возвращает потерю для сетей генератора и различителя.

function [lossGenerator,lossDiscriminator] = ...
    GANLoss(probReal,probGenerated)
    lossDiscriminator = -mean(log(probReal))-mean(log(1-probGenerated));
    lossGenerator = -mean(log(probGenerated));
end

Приложение 3: сгенерируйте тестовые изображения

Эта функция создает пакет 25 случайных векторов, чтобы ввести к сети генератора и отображает получившиеся изображения. Используйте эту функцию, чтобы визуально проверять, что генератор производит множество изображений без многих копий. Если изображения имеют мало разнообразия, и некоторые из них почти идентичны, то ваш генератор, вероятно, затронут коллапсом режима.

function generateTestImages(trainingOutput)

dlnetGenerator = trainingOutput.generator;
executionEnvironment = trainingOutput.executionEnvironment;

numLatentInputs = 100;
numTestImages = 25;

ZTest = randn(numLatentInputs,numTestImages,"single");
dlZTest = dlarray(ZTest,"CB");

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZTest = gpuArray(dlZTest);
end

dlXGeneratedTest = predict(dlnetGenerator,dlZTest);

I = imtile(extractdata(dlXGeneratedTest));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

end

Смотрите также

Приложения

Объекты

Похожие темы