Обучите вариационный автоэнкодер (VAE) генерировать изображения

В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифры. VAE генерирует нарисованные от руки цифры в стиле набора данных MNIST.

VAEs отличаются от обычных автоэнкодеров в этом, они не используют декодирующий кодирование процесс, чтобы восстановить вход. Вместо этого они налагают вероятностное распределение на скрытый пробел и изучают распределение так, чтобы распределение выходных параметров от соответствий декодера те из наблюдаемых данных. Затем они производят от этого распределения, чтобы сгенерировать новые данные.

В этом примере вы создаете сеть VAE, обучаете его на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают тех в наборе данных.

Загрузка данных

Загрузите файлы MNIST с http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Вызовите processImagesMNIST и processLabelsMNIST помощник функционирует присоединенный к этому примеру, чтобы загрузить данные из файлов в массивы MATLAB.

Поскольку VAE сравнивает восстановленные цифры с входными параметрами а не с категориальными метками, вы не должны использовать учебные метки в наборе данных MNIST.

trainImagesFile = 'train-images-idx3-ubyte.gz';
testImagesFile = 't10k-images-idx3-ubyte.gz';
testLabelsFile = 't10k-labels-idx1-ubyte.gz';

XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data...
Number of images in the dataset:  60000 ...
numTrainImages = size(XTrain,4);
XTest = processImagesMNIST(testImagesFile);
Read MNIST image data...
Number of images in the dataset:  10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data...
Number of labels in the dataset:  10000 ...

Создайте сеть

Автоэнкодеры имеют две части: энкодер и декодер. Энкодер берет вход изображений и выводит сжатое представление (кодирование), который является вектором из размера latentDim, равняйтесь 20 в этом примере. Декодер берет сжатое представление, декодирует его и воссоздает оригинальное изображение.

Чтобы сделать вычисления более численно устойчивыми, увеличьте область значений возможных значений от [0,1] до [-inf, 0], заставив сеть извлечь уроки из логарифма отклонений. Задайте два вектора из размера latent_dim: один для средних значений μ и один для логарифма отклонений log(σ2). Затем используйте эти два вектора, чтобы создать распределение к выборке от.

Используйте 2D свертки, сопровождаемые полносвязным слоем, чтобы проредить от 28 28 1 изображением MNIST к кодированию на скрытом пробеле. Затем используйте транспонированные 2D свертки, чтобы увеличить 1 1 20 кодированием назад в 28 28 1 изображением.

latentDim = 20;
imageSize = [28 28 1];

encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
    ]);

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

Чтобы обучить обе сети с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте графики слоев в dlnetwork объекты.

encoderNet = dlnetwork(encoderLG);
decoderNet = dlnetwork(decoderLG);

Функция градиентов модели Define

Функция помощника modelGradients берет в энкодере и декодере dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Эта функция помощника задана в конце этого примера.

Функция выполняет этот процесс на двух шагах: выборка и потеря. Шаг выборки производит среднее значение и векторы отклонения, чтобы создать кодирование финала, которое будет передано сети декодера. Однако, потому что обратная связь посредством случайной операции выборки не возможна, необходимо использовать прием репараметризации. Этот прием перемещает случайную операцию выборки во вспомогательную переменную ε, который затем смещен средним значением μi и масштабируемый стандартным отклонением σi. Идея является той выборкой от N(μi,σi2) совпадает с выборкой от μi+εσi, где εN(0,1). Следующая фигура изображает эту идею графически.

Шаг потерь передает кодирование, сгенерированное шагом выборки через сеть декодера, и определяет потерю, которая затем используется для расчета градиенты. Потеря в VAEs, также названном нижней границей доказательства (ELBO) потеря, задана как сумма двух отдельных условий потерь:

ELBOloss=reconstructionloss+KLloss.

Потеря реконструкции измеряется, как близко декодер выход к исходному входу при помощи среднеквадратической ошибки (MSE):

reconstructionloss=MSE(decoderoutput,originalimage).

Потеря KL или расхождение Kullback–Leibler, измеряет различие между двумя вероятностными распределениями. Минимизация потери KL в этом, случай означает гарантировать, что изученные средние значения и отклонения максимально близки к тем из целевого (нормального) распределения. Для скрытой размерности размера n, потеря KL получена как

KLloss=-0.5i=1n(1+log(σi2)-μi2-σi2).

Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, изученные из-за потери реконструкции плотно вокруг центра скрытого пробела, формируя непрерывный пробел к выборке от.

Задайте опции обучения

Обучайтесь на графическом процессоре, если вы доступны (требует Parallel Computing Toolbox™).

executionEnvironment = "auto";

Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с пустым arrays.

numEpochs = 50;
miniBatchSize = 512;
lr = 1e-3;
numIterations = floor(numTrainImages/miniBatchSize);
iteration = 0;

avgGradientsEncoder = [];
avgGradientsSquaredEncoder = [];
avgGradientsDecoder = [];
avgGradientsSquaredDecoder = [];

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

Для каждой итерации в эпоху:

  • Получите следующий мини-пакет из набора обучающих данных.

  • Преобразуйте мини-пакет в dlarray объект, убеждаясь задавать размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Для обучения графического процессора преобразуйте dlarray к gpuArray объект.

  • Оцените градиенты модели с помощью dlfeval and modelGradients функции.

  • Обновите сеть learnables и средние градиенты для обеих сетей, с помощью adamupdate функция.

В конце каждой эпохи передайте изображения набора тестов через автоэнкодер и отобразите потерю и учебное время в течение той эпохи.

for epoch = 1:numEpochs
    tic;
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
            
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch);
        
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
    end
    elapsedTime = toc;
    
    [z, zMean, zLogvar] = sampling(encoderNet, XTest);
    xPred = sigmoid(forward(decoderNet, z));
    elbo = ELBOloss(XTest, xPred, zMean, zLogvar);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")    
end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s
Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s
Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s
Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s
Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s
Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s
Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s
Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s
Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s
Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s
Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s
Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s
Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s
Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s
Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s
Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s
Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s
Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s
Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s
Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s
Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s
Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s
Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s
Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s
Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s
Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s
Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s
Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s
Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s
Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s
Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s
Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s
Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s
Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s
Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s
Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s
Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s
Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s
Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s
Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s
Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s
Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s
Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s
Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s
Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s
Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s
Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s
Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s
Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s
Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

Визуализация результатов

Чтобы визуализировать и интерпретировать результаты, используйте функции Визуализации помощника. Эти функции помощника заданы в конце этого примера.

VisualizeReconstruction функция показывает случайным образом выбранную цифру от каждого класса, сопровождаемого его реконструкцией после прохождения через автоэнкодер.

VisualizeLatentSpace функционируйте берет среднее значение и кодировку отклонения (каждая размерность 20) сгенерированный после передачи тестовых изображений через сеть энкодера, и выполняет анализ главных компонентов (PCA) матрицы, содержащей кодировку для каждого из изображений. Можно затем визуализировать скрытый пробел, заданный средними значениями и отклонениями в этих двух размерностях, охарактеризованных двумя первыми основными компонентами.

Generate функция инициализирует новую кодировку, произведенную от нормального распределения, и выводит изображения, сгенерированные, когда эта кодировка проходит через сеть декодера.

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

Следующие шаги

Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых, чтобы выполнить порождающие задачи. Они работают хорошо над наборами данных, где изображения малы и ясно задали функции (такие как MNIST). Для наборов более комплексных данных с увеличенными изображениями порождающие соперничающие сети (GANs) имеют тенденцию выполнять лучше и генерировать изображения с меньшим количеством шума. Для примера, показывающего, как реализовать GANs, чтобы сгенерировать 64 64 изображения RGB, смотрите, Обучают Порождающую соперничающую сеть (GAN).

Ссылки

  1. LeCun, Y., К. Кортес и К. Дж. К. Берджес. "База данных MNIST Рукописных Цифр". http://yann.lecun.com/exdb/mnist/.

Функции помощника

Функция градиентов модели

modelGradients функционируйте берет энкодер и декодер dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Функция выполняет три операции:

  1. Получите кодировку путем вызова sampling функция на мини-пакете изображений, который проходит через сеть энкодера.

  2. Получите потерю путем передачи кодировки через сеть декодера и вызова ELBOloss функция.

  3. Вычислите градиенты потери относительно настраиваемых параметров обеих сетей путем вызова dlgradient функция.

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

Выборка и функции потерь

sampling функция получает кодировку из входных изображений. Первоначально, это передает мини-пакет изображений через сеть энкодера и разделяет выход размера (2*latentDim)*miniBatchSize в матрицу средних значений и матрицу отклонений, каждый размер latentDim*batchSize. Затем это использует эти матрицы, чтобы реализовать прием репараметризации и вычислить кодирование. Наконец, это преобразует это кодирование в dlarray объект в формате SSCB.

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

ELBOloss функционируйте берет кодировку средних значений и дисперсий, возвращенных sampling функция, и использует их, чтобы вычислить потерю ELBO.

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

Функции визуализации

VisualizeReconstruction функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит реконструкцию бок о бок с исходным входом. Обратите внимание на то, что построить информацию, содержавшую в dlarray объект, необходимо извлечь его сначала использование extractdata и gather функции.

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

VisualizeLatentSpace функция визуализирует скрытый пробел, заданный средним значением и матрицами отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, сформированных скрытыми представлениями пробела каждой цифры.

Функция запускается путем извлечения среднего значения и матриц отклонения от dlarray объекты. Поскольку перемещение матрицы с размерностями канала/пакета (C и B) не возможно, вызовы функции stripdims прежде, чем транспонировать матрицы. Затем это выполняет анализ главных компонентов (PCA) обеих матриц. Чтобы визуализировать скрытый пробел в двух измерениях, функция сохраняет первые два основных компонента и строит их друг против друга. Наконец, функция окрашивает классы цифры так, чтобы можно было наблюдать кластеры.

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

generate функционируйте тестирует порождающие возможности VAE. Это инициализирует dlarray объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит выходные параметры.

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

Смотрите также

| | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте