Обучите вариационный автоэнкодер (VAE) генерировать изображения

В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифр. VAE генерирует нарисованные вручную цифры в стиле набора данных MNIST.

VAE отличаются от обычных автоэнкодеров тем, что они не используют процесс декодирования-кодирования для восстановления входа. Вместо этого они накладывают распределение вероятностей на скрытое пространство и изучают распределение так, чтобы распределение выходов от декодера совпадало с распределением наблюдаемых данных. Затем они получают выборку из этого распределения, чтобы сгенерировать новые данные.

В этом примере вы создаете сеть VAE, обучаете ее на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают изображения в наборе данных.

Загрузка данных

Загрузите файлы MNIST из http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Вызовите processImagesMNIST и processLabelsMNIST вспомогательные функции, присоединенные к этому примеру, чтобы загрузить данные из файлов в массивы MATLAB.

Поскольку VAE сравнивает восстановленные цифры с входами, а не с категориальными метками, вам не нужно использовать обучающие метки в наборе данных MNIST.

trainImagesFile = 'train-images-idx3-ubyte.gz';
testImagesFile = 't10k-images-idx3-ubyte.gz';
testLabelsFile = 't10k-labels-idx1-ubyte.gz';

XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data...
Number of images in the dataset:  60000 ...
numTrainImages = size(XTrain,4);
XTest = processImagesMNIST(testImagesFile);
Read MNIST image data...
Number of images in the dataset:  10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data...
Number of labels in the dataset:  10000 ...

Конструкция сети

Автоэнкодеры имеют две части: энкодер и декодер. Энкодер принимает изображение на вход и выводит сжатое представление (кодирование), которое является вектором размера latentDim, равное 20 в этом примере. Декодер принимает сжатое представление, декодирует его и воссоздает оригинальное изображение.

Чтобы сделать вычисления более численно стабильными, увеличьте область значений возможных значений с [0,1] до [-inf, 0], заставив сеть учиться на логарифме отклонений. Задайте два вектора размера latent_dim: один для средств μ и один для логарифма отклонений log(σ2). Затем используйте эти два вектора, чтобы создать распределение для выборки из.

Используйте 2-е скручивания, сопровождаемые полносвязным слоем, чтобы субдискретизировать от 28 28 1 изображением MNIST к кодированию в скрытом пространстве. Затем используйте перемещенные 2-е скручивания, чтобы расширить кодирование 1 на 1 на 20 назад в изображение 28 на 28 на 1.

latentDim = 20;
imageSize = [28 28 1];

encoderLG = layerGraph([
    imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
    convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
    ]);

decoderLG = layerGraph([
    imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
    transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
    reluLayer('Name','relu1')
    transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
    reluLayer('Name','relu2')
    transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
    reluLayer('Name','relu3')
    transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
    ]);

Чтобы обучить обе сети с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте графики слоев в dlnetwork объекты.

encoderNet = dlnetwork(encoderLG);
decoderNet = dlnetwork(decoderLG);

Задайте функцию градиентов модели

Вспомогательная функция modelGradients принимает в энкодере и декодере dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потерь относительно настраиваемых параметров в сетях. Эта вспомогательная функция определяется в конце этого примера.

Функция выполняет этот процесс в два этапа: дискретизация и потеря. Этап дискретизации дискретизирует среднее значение и векторы отклонения, чтобы создать окончательное кодирование, которое должно быть передано в сеть декодера. Однако, поскольку обратное распространение через операцию случайной выборки невозможно, необходимо использовать трюк репараметризации. Этот трюк перемещает операцию случайной выборки к вспомогательной переменной ε, который затем сдвигается средним μi и масштабируется стандартным отклонением σi. Идея в том, что выборка из N(μi,σi2) является тем же самым, что и выборка из μi+εσi, где εN(0,1). Следующий рисунок изображает эту идею графически.

Этап потерь пропускает кодирование, сгенерированное шагом дискретизации, через сеть декодера и определяет потерю, которая затем используется для вычисления градиентов. Потеря в VAE, также названная доказательством нижней границы (ELBO), определяется как сумма двух отдельных терминов потерь:

ELBOloss=reconstructionloss+KLloss.

Потеря реконструкции измеряет, насколько близок выход декодера к исходному входу, используя среднюю квадратную ошибку (MSE):

reconstructionloss=MSE(decoderoutput,originalimage).

KL loss, или Kullback-Leibler divergence, измеряет различие между двумя распределениями вероятностей. Минимизация потерь KL в этом случае означает, что выученные средства и отклонения максимально близки к таковым целевого (нормального) распределения. Для скрытой размерности размера n, KL потеря получена как

KLloss=-0.5i=1n(1+log(σi2)-μi2-σi2).

Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, выученные из-за потерь восстановления, плотно вокруг центра скрытого пространства, образуя непрерывное пространство для выборки.

Настройка опций обучения

Обучите на графическом процессоре, если он доступен (требуется Parallel Computing Toolbox™).

executionEnvironment = "auto";

Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети конечный средний градиент и конечный средний коэффициент распада градиента-квадрата с пустыми массивами .

numEpochs = 50;
miniBatchSize = 512;
lr = 1e-3;
numIterations = floor(numTrainImages/miniBatchSize);
iteration = 0;

avgGradientsEncoder = [];
avgGradientsSquaredEncoder = [];
avgGradientsDecoder = [];
avgGradientsSquaredDecoder = [];

Обучите модель

Обучите модель с помощью пользовательского цикла обучения.

Для каждой итерации в эпоху:

  • Получите следующий мини-пакет из набора обучающих данных.

  • Преобразуйте мини-пакет в dlarray объект, следя за тем, чтобы задать метки размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Для обучения графический процессор преобразуйте dlarray в gpuArray объект.

  • Оцените градиенты модели с помощью dlfeval и modelGradientsфункций.

  • Обновляйте обучаемые возможности сети и средние градиенты для обеих сетей, используя adamupdate функция.

В конце каждой эпохи передайте изображения тестового набора через автоэнкодер и отобразите потерю и время обучения для этой эпохи.

for epoch = 1:numEpochs
    tic;
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XBatch = gpuArray(XBatch);           
        end 
            
        [infGrad, genGrad] = dlfeval(...
            @modelGradients, encoderNet, decoderNet, XBatch);
        
        [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ...
            adamupdate(decoderNet.Learnables, ...
                genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr);
        [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ...
            adamupdate(encoderNet.Learnables, ...
                infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr);
    end
    elapsedTime = toc;
    
    [z, zMean, zLogvar] = sampling(encoderNet, XTest);
    xPred = sigmoid(forward(decoderNet, z));
    elbo = ELBOloss(XTest, xPred, zMean, zLogvar);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")    
end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s
Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s
Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s
Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s
Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s
Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s
Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s
Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s
Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s
Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s
Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s
Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s
Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s
Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s
Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s
Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s
Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s
Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s
Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s
Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s
Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s
Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s
Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s
Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s
Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s
Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s
Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s
Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s
Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s
Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s
Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s
Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s
Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s
Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s
Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s
Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s
Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s
Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s
Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s
Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s
Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s
Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s
Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s
Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s
Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s
Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s
Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s
Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s
Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s
Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

Визуализация результатов

Чтобы визуализировать и интерпретировать результаты, используйте вспомогательные функции визуализации. Эти вспомогательные функции определены в конце этого примера.

The VisualizeReconstruction функция показывает случайным образом выбранную цифру из каждого класса, сопровождаемую его восстановлением после прохождения через автоэнкодер.

The VisualizeLatentSpace функция принимает среднее значение и кодировки отклонения (каждый из размерности 20), сгенерированные после передачи тестовых изображений через сеть энкодера, и выполняет анализ основного компонента (PCA) на матрице, содержащей кодировки для каждого из изображений. Затем можно визуализировать латентное пространство, заданное средствами и отклонениями в двух размерностях, характеризующихся двумя первыми главными компонентами.

The Generate функция инициализирует новые кодирования, дискретизированные из нормального распределения, и выводит изображения, сгенерированные, когда эти кодирования проходят через сеть декодера.

visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

Следующие шаги

Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых для выполнения генеративных задач. Они хорошо работают с наборами данных, где изображения являются маленькими и имеют четко определенные функции (такие как MNIST). Для более сложных наборов данных с большими изображениями генеративные состязательные сети (GANs), как правило, выполняют лучше и генерируют изображения с меньшим шумом. Для примера, показывающего, как реализовать GAN для генерации изображений RGB 64 на 64, смотрите Обучите Генеративную Состязательную Сеть (GAN).

Ссылки

  1. LeCun, Y., C. Cortes, and C. J. C. Burges. «База данных MNIST рукописных цифр». http://yann.lecun.com/exdb/mnist/.

Вспомогательные функции

Функция градиентов модели

The modelGradients функция принимает энкодер и декодер dlnetwork объекты и мини-пакет входных данных X, и возвращает градиенты потерь относительно настраиваемых параметров в сетях. Функция выполняет три операции:

  1. Получите кодировки путем вызова sampling функция на мини-пакете изображений, который проходит через сеть энкодера.

  2. Получите потерю, передав кодировки через сеть декодера и вызвав ELBOloss функция.

  3. Вычислите градиенты потерь относительно настраиваемых параметров обеих сетей путем вызова dlgradient функция.

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

Функции дискретизации и потерь

The sampling функция получает кодировки из входных изображений. Первоначально он пропускает мини-пакет изображений через сеть энкодеров и разделяет выход размера (2*latentDim)*miniBatchSize в матрицу средств и матрицу отклонений, каждый из которых имеет размер latentDim*batchSize. Затем он использует эти матрицы, чтобы реализовать трюк репараметризации и вычислить кодирование. Наконец, это кодирование преобразуется в dlarray объект в формате SSCB.

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

The ELBOloss функция принимает кодировки средств и отклонений, возвращаемых sampling function, и использует их, чтобы вычислить потери ELBO.

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

Функции визуализации

The VisualizeReconstruction функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит график реконструкции один за другим с исходным входом. Обратите внимание, что чтобы построить график информации, содержащейся в dlarray объект, вам нужно сначала извлечь его с помощью extractdata и gather функций.

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

The VisualizeLatentSpace функция визуализирует латентное пространство, заданное матрицами среднего и отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, формируемых представлениями латентного пространства каждой цифры.

Функция начинается с извлечения матриц среднего и отклонения из dlarray объекты. Поскольку транспонирование матрицы с размерностями канала/пакета (C и B) невозможно, функция вызывает stripdims перед транспонированием матриц. Затем он проводит основной анализ компонента (PCA) на обеих матрицах. Чтобы визуализировать латентное пространство в двух размерностях, функция сохраняет первые два главных компонента и строит их графики друг против друга. Наконец, функция окрашивает классы цифр так, чтобы можно было наблюдать кластеры.

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

The generate функция проверяет генеративные возможности VAE. Он инициализирует dlarray объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит графики выходов.

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end

См. также

| | | | | |

Похожие темы