Обучите вариационный автоэнкодер (VAE) генерировать изображения

В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифры. VAE генерирует нарисованные от руки цифры в стиле набора данных MNIST.

VAEs отличаются от обычных автоэнкодеров в этом, они не используют декодирующий кодирование процесс, чтобы восстановить вход. Вместо этого они налагают вероятностное распределение на скрытый пробел и изучают распределение так, чтобы распределение выходных параметров от соответствий декодера те из наблюдаемых данных. Затем они производят от этого распределения, чтобы сгенерировать новые данные.

В этом примере вы создаете сеть VAE, обучаете его на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают тех в наборе данных.

Загрузка данных

Загрузите файлы MNIST с http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Извлеките и поместите файлы в рабочую директорию, затем вызовите processImagesMNIST и processLabelsMNIST помощник функционирует присоединенный к этому примеру, чтобы загрузить данные из файлов в массивы MATLAB.

Поскольку VAE сравнивает восстановленные цифры с входными параметрами а не с категориальными метками, вы не должны использовать учебные метки в наборе данных MNIST.

trainImagesFile = 'train-images.idx3-ubyte';
testImagesFile = 't10k-images.idx3-ubyte';
testLabelsFile = 't10k-labels.idx1-ubyte';

XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data...
Number of images in the dataset:  60000 ...
numTrainImages = size(XTrain,4);
XTest = processImagesMNIST(testImagesFile);
Read MNIST image data...
Number of images in the dataset:  10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data...
Number of labels in the dataset:  10000 ...

Создайте сеть

Автоэнкодеры имеют две части: энкодер и декодер. Энкодер берет вход изображений и выводит сжатое представление (кодирование), который является вектором размера latent_dim, равняйтесь 20 в этом примере. Декодер берет сжатое представление, декодирует его и воссоздает оригинальное изображение.

Чтобы сделать вычисления более численно устойчивыми, увеличьте область значений возможных значений от [0,1] до [-inf, 0], заставив сеть извлечь уроки из логарифма отклонений. Задайте два вектора размера latent_dim: один для средних значений μ и один для логарифма отклонений log(σ2). Затем используйте эти два вектора, чтобы создать распределение к выборке от.

Используйте 2D свертки, сопровождаемые полносвязным слоем, чтобы проредить от 28 28 1 изображением MNIST к кодированию на скрытом пробеле. Затем используйте транспонированные 2D свертки, чтобы увеличить 1 1 20 кодированием назад в 28 28 1 изображением.

latentDim = 20;
imageSize = [28 28 1];

encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
    ]);

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

Чтобы обучить обе сети с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте графики слоев в dlnetwork объекты.

encoderNet = dlnetwork(encoderLG);
decoderNet = dlnetwork(decoderLG);

Функция градиентов модели Define

Функция помощника modelGradients берет в энкодере и декодере dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Эта функция помощника задана в конце этого примера.

Функция выполняет этот процесс на двух шагах: выборка и потеря. Шаг выборки производит среднее значение и векторы отклонения, чтобы создать кодирование финала, которое будет передано сети декодера. Однако, потому что обратная связь посредством случайной операции выборки не возможна, необходимо использовать прием репараметризации. Этот прием перемещает случайную операцию выборки во вспомогательную переменную ε, который затем смещен средним значением μi и масштабируемый стандартным отклонением σi. Идея является той выборкой от N(μi,σi2) совпадает с выборкой от μi+εσi, где εN(0,1). Следующая фигура изображает эту идею графически.

Шаг потерь передает кодирование, сгенерированное шагом выборки через сеть декодера, и определяет потерю, которая затем используется для расчета градиенты. Потеря в VAEs, также названном нижней границей доказательства (ELBO) потеря, задана как сумма двух отдельных условий потерь:

ELBOloss=reconstructionloss+KLloss.

Потеря реконструкции измеряется, как близко декодер выход к исходному входу при помощи среднеквадратической ошибки (MSE):

reconstructionloss=MSE(decoderoutput,originalimage).

Потеря KL или расхождение Kullback–Leibler, измеряет различие между двумя вероятностными распределениями. Минимизация потери KL в этом, случай означает гарантировать, что изученные средние значения и отклонения максимально близки к тем из целевого (нормального) распределения. Для скрытой размерности размера n, потеря KL получена как

KLloss=-0.5i=1n(1+log(σi)-μi2-σi2).

Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, изученные из-за потери реконструкции плотно вокруг центра скрытого пробела, формируя непрерывный пробел к выборке от.

Задайте опции обучения

Обучайтесь на графическом процессоре (требует Parallel Computing Toolbox™). Если у вас нет графического процессора, установите executionEnvironment к "cpu".

executionEnvironment = "auto";

Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с пустым arrays.

numEpochs = 50;
miniBatchSize = 512;
lr = 1e-3;
numIterations = floor(numTrainImages/miniBatchSize);
iteration = 0;

avgGradientsEncoder = [];
avgGradientsSquaredEncoder = [];
avgGradientsDecoder = [];
avgGradientsSquaredDecoder = [];

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

Для каждой итерации в эпоху:

  • Получите следующий мини-пакет из набора обучающих данных.

  • Преобразуйте мини-пакет в dlarray объект, убеждаясь задавать размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Для обучения графического процессора преобразуйте dlarray к gpuArray объект.

  • Оцените градиенты модели с помощью dlfeval and modelGradients функции.

  • Обновите сеть learnables и средние градиенты для обеих сетей, с помощью adamupdate функция.

В конце каждой эпохи передайте изображения набора тестов через автоэнкодер и отобразите потерю и учебное время в течение той эпохи.

for epoch = 1:numEpochs
    tic;
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
            
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch);
        
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
    end
    elapsedTime = toc;
    
    [z, zMean, zLogvar] = sampling(encoderNet, XTest);
    xPred = sigmoid(forward(decoderNet, z));
    elbo = ELBOloss(XTest, xPred, zMean, zLogvar);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")    
end
Epoch : 1 Test ELBO loss = 27.0561. Time taken for epoch = 33.0037s
Epoch : 2 Test ELBO loss = 24.414. Time taken for epoch = 32.4167s
Epoch : 3 Test ELBO loss = 23.0166. Time taken for epoch = 32.3244s
Epoch : 4 Test ELBO loss = 20.9078. Time taken for epoch = 32.1268s
Epoch : 5 Test ELBO loss = 20.6519. Time taken for epoch = 32.3451s
Epoch : 6 Test ELBO loss = 20.3201. Time taken for epoch = 32.4371s
Epoch : 7 Test ELBO loss = 19.9266. Time taken for epoch = 32.4551s
Epoch : 8 Test ELBO loss = 19.8448. Time taken for epoch = 32.9919s
Epoch : 9 Test ELBO loss = 19.7485. Time taken for epoch = 33.1783s
Epoch : 10 Test ELBO loss = 19.6295. Time taken for epoch = 33.1623s
Epoch : 11 Test ELBO loss = 19.539. Time taken for epoch = 32.4781s
Epoch : 12 Test ELBO loss = 19.4682. Time taken for epoch = 32.5094s
Epoch : 13 Test ELBO loss = 19.3577. Time taken for epoch = 32.5996s
Epoch : 14 Test ELBO loss = 19.3247. Time taken for epoch = 32.6447s
Epoch : 15 Test ELBO loss = 19.3043. Time taken for epoch = 32.2494s
Epoch : 16 Test ELBO loss = 19.2948. Time taken for epoch = 32.5408s
Epoch : 17 Test ELBO loss = 19.191. Time taken for epoch = 32.8177s
Epoch : 18 Test ELBO loss = 19.1075. Time taken for epoch = 32.5982s
Epoch : 19 Test ELBO loss = 19.0606. Time taken for epoch = 33.7771s
Epoch : 20 Test ELBO loss = 19.0298. Time taken for epoch = 33.6249s
Epoch : 21 Test ELBO loss = 19.0534. Time taken for epoch = 33.4906s
Epoch : 22 Test ELBO loss = 18.9859. Time taken for epoch = 33.1101s
Epoch : 23 Test ELBO loss = 19.0077. Time taken for epoch = 32.7345s
Epoch : 24 Test ELBO loss = 18.9963. Time taken for epoch = 33.0067s
Epoch : 25 Test ELBO loss = 18.9189. Time taken for epoch = 32.891s
Epoch : 26 Test ELBO loss = 18.8925. Time taken for epoch = 33.0905s
Epoch : 27 Test ELBO loss = 18.9182. Time taken for epoch = 32.6203s
Epoch : 28 Test ELBO loss = 18.8664. Time taken for epoch = 32.4095s
Epoch : 29 Test ELBO loss = 18.8512. Time taken for epoch = 32.4317s
Epoch : 30 Test ELBO loss = 18.7983. Time taken for epoch = 32.4s
Epoch : 31 Test ELBO loss = 18.7971. Time taken for epoch = 32.4902s
Epoch : 32 Test ELBO loss = 18.7888. Time taken for epoch = 32.2591s
Epoch : 33 Test ELBO loss = 18.7811. Time taken for epoch = 32.4291s
Epoch : 34 Test ELBO loss = 18.7804. Time taken for epoch = 32.5968s
Epoch : 35 Test ELBO loss = 18.7839. Time taken for epoch = 32.3787s
Epoch : 36 Test ELBO loss = 18.7045. Time taken for epoch = 32.6078s
Epoch : 37 Test ELBO loss = 18.7783. Time taken for epoch = 32.6429s
Epoch : 38 Test ELBO loss = 18.7068. Time taken for epoch = 32.7032s
Epoch : 39 Test ELBO loss = 18.6822. Time taken for epoch = 32.3438s
Epoch : 40 Test ELBO loss = 18.7155. Time taken for epoch = 32.6521s
Epoch : 41 Test ELBO loss = 18.7161. Time taken for epoch = 32.5532s
Epoch : 42 Test ELBO loss = 18.6597. Time taken for epoch = 32.6419s
Epoch : 43 Test ELBO loss = 18.6657. Time taken for epoch = 32.4558s
Epoch : 44 Test ELBO loss = 18.5996. Time taken for epoch = 32.5503s
Epoch : 45 Test ELBO loss = 18.6666. Time taken for epoch = 32.5503s
Epoch : 46 Test ELBO loss = 18.6449. Time taken for epoch = 32.2981s
Epoch : 47 Test ELBO loss = 18.6107. Time taken for epoch = 32.3152s
Epoch : 48 Test ELBO loss = 18.6393. Time taken for epoch = 32.7135s
Epoch : 49 Test ELBO loss = 18.6351. Time taken for epoch = 32.3859s
Epoch : 50 Test ELBO loss = 18.5955. Time taken for epoch = 32.6549s

Визуализация результатов

Чтобы визуализировать и интерпретировать результаты, используйте функции Визуализации помощника. Эти функции помощника заданы в конце этого примера.

VisualizeReconstruction функция показывает случайным образом выбранную цифру от каждого класса, сопровождаемого его реконструкцией после прохождения через автоэнкодер.

VisualizeLatentSpace функционируйте берет среднее значение и кодировку отклонения (каждая размерность 20) сгенерированный после передачи тестовых изображений через сеть энкодера, и выполняет анализ главных компонентов (PCA) матрицы, содержащей кодировку для каждого из изображений. Можно затем визуализировать скрытый пробел, заданный средними значениями и отклонениями в этих двух размерностях, охарактеризованных двумя первыми основными компонентами.

Generate функция инициализирует новую кодировку, произведенную от нормального распределения, и выводит изображения, сгенерированные, когда эта кодировка проходит через сеть декодера.

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

Следующие шаги

Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых, чтобы выполнить порождающие задачи. Они работают хорошо над наборами данных, где изображения малы и ясно задали функции (такие как MNIST). Для наборов более комплексных данных с увеличенными изображениями порождающие соперничающие сети (GANs) имеют тенденцию выполнять лучше и генерировать изображения с меньшим количеством шума. Для примера, показывающего, как реализовать GANs, чтобы сгенерировать 64 64 изображения RGB, смотрите, Обучают Порождающую соперничающую сеть (GAN).

Ссылки

  1. LeCun, Y., К. Кортес и К. Дж. К. Берджес. "База данных MNIST Рукописных Цифр". http://yann.lecun.com/exdb/mnist/.

Функции помощника

Функция градиентов модели

modelGradients функционируйте берет энкодер и декодер dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Функция выполняет три операции:

  1. Получите кодировку путем вызова sampling функция на мини-пакете изображений, который проходит через сеть энкодера.

  2. Получите потерю путем передачи кодировки через сеть декодера и вызова ELBOloss функция.

  3. Вычислите градиенты потери относительно настраиваемых параметров обеих сетей путем вызова dlgradient функция.

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

Выборка и функции потерь

sampling функция получает кодировку из входных изображений. Первоначально, это передает мини-пакет изображений через сеть энкодера и разделяет выход размера (2*latentDim)*miniBatchSize в матрицу средних значений и матрицу отклонений, каждый размер latentDim*batchSize. Затем это использует эти матрицы, чтобы реализовать прием репараметризации и вычислить кодирование. Наконец, это преобразует это кодирование в dlarray объект в формате SSCB.

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

ELBOloss функционируйте берет кодировку средних значений и дисперсий, возвращенных sampling функция, и использует их, чтобы вычислить потерю ELBO.

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

Функции визуализации

VisualizeReconstruction функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит реконструкцию бок о бок с исходным входом. Обратите внимание на то, что построить информацию, содержавшую в dlarray объект, необходимо извлечь его сначала использование extractdata and gather функции.

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

VisualizeLatentSpace функция визуализирует скрытый пробел, заданный средним значением и матрицами отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, сформированных скрытыми представлениями пробела каждой цифры.

Функция запускается путем извлечения среднего значения и матриц отклонения от dlarray объекты. Поскольку перемещение матрицы с размерностями канала/пакета (C и B) не возможно, вызовы функции stripdims прежде, чем транспонировать матрицы. Затем это выполняет анализ главных компонентов (PCA) обеих матриц. Чтобы визуализировать скрытый пробел в двух измерениях, функция сохраняет первые два основных компонента и строит их друг против друга. Наконец, функция окрашивает классы цифры так, чтобы можно было наблюдать кластеры.

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

generate функционируйте тестирует порождающие возможности VAE. Это инициализирует dlarray объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит выходные параметры.

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

Смотрите также

| | | | | |

Похожие темы