В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифр. VAE генерирует нарисованные вручную цифры в стиле набора данных MNIST.
VAE отличаются от обычных автоэнкодеров тем, что они не используют процесс декодирования-кодирования для восстановления входа. Вместо этого они накладывают распределение вероятностей на скрытое пространство и изучают распределение так, чтобы распределение выходов от декодера совпадало с распределением наблюдаемых данных. Затем они получают выборку из этого распределения, чтобы сгенерировать новые данные.
В этом примере вы создаете сеть VAE, обучаете ее на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают изображения в наборе данных.
Загрузите файлы MNIST из http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Вызовите processImagesMNIST
и processLabelsMNIST
вспомогательные функции, присоединенные к этому примеру, чтобы загрузить данные из файлов в массивы MATLAB.
Поскольку VAE сравнивает восстановленные цифры с входами, а не с категориальными метками, вам не нужно использовать обучающие метки в наборе данных MNIST.
trainImagesFile = 'train-images-idx3-ubyte.gz'; testImagesFile = 't10k-images-idx3-ubyte.gz'; testLabelsFile = 't10k-labels-idx1-ubyte.gz'; XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data... Number of images in the dataset: 60000 ...
numTrainImages = size(XTrain,4); XTest = processImagesMNIST(testImagesFile);
Read MNIST image data... Number of images in the dataset: 10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data... Number of labels in the dataset: 10000 ...
Автоэнкодеры имеют две части: энкодер и декодер. Энкодер принимает изображение на вход и выводит сжатое представление (кодирование), которое является вектором размера latentDim
, равное 20 в этом примере. Декодер принимает сжатое представление, декодирует его и воссоздает оригинальное изображение.
Чтобы сделать вычисления более численно стабильными, увеличьте область значений возможных значений с [0,1] до [-inf, 0], заставив сеть учиться на логарифме отклонений. Задайте два вектора размера latent_dim
: один для средств и один для логарифма отклонений . Затем используйте эти два вектора, чтобы создать распределение для выборки из.
Используйте 2-е скручивания, сопровождаемые полносвязным слоем, чтобы субдискретизировать от 28 28 1 изображением MNIST к кодированию в скрытом пространстве. Затем используйте перемещенные 2-е скручивания, чтобы расширить кодирование 1 на 1 на 20 назад в изображение 28 на 28 на 1.
latentDim = 20; imageSize = [28 28 1]; encoderLG = layerGraph([ imageInputLayer(imageSize,'Name','input_encoder','Normalization','none') convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1') reluLayer('Name','relu1') convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2') reluLayer('Name','relu2') fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder') ]); decoderLG = layerGraph([ imageInputLayer([1 1 latentDim],'Name','i','Normalization','none') transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1') reluLayer('Name','relu1') transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2') reluLayer('Name','relu2') transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3') reluLayer('Name','relu3') transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4') ]);
Чтобы обучить обе сети с помощью пользовательского цикла обучения и включить автоматическую дифференциацию, преобразуйте графики слоев в dlnetwork
объекты.
encoderNet = dlnetwork(encoderLG); decoderNet = dlnetwork(decoderLG);
Вспомогательная функция modelGradients принимает в энкодере и декодере dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потерь относительно настраиваемых параметров в сетях. Эта вспомогательная функция определяется в конце этого примера.
Функция выполняет этот процесс в два этапа: дискретизация и потеря. Этап дискретизации дискретизирует среднее значение и векторы отклонения, чтобы создать окончательное кодирование, которое должно быть передано в сеть декодера. Однако, поскольку обратное распространение через операцию случайной выборки невозможно, необходимо использовать трюк репараметризации. Этот трюк перемещает операцию случайной выборки к вспомогательной переменной , который затем сдвигается средним и масштабируется стандартным отклонением . Идея в том, что выборка из является тем же самым, что и выборка из , где . Следующий рисунок изображает эту идею графически.
Этап потерь пропускает кодирование, сгенерированное шагом дискретизации, через сеть декодера и определяет потерю, которая затем используется для вычисления градиентов. Потеря в VAE, также названная доказательством нижней границы (ELBO), определяется как сумма двух отдельных терминов потерь:
.
Потеря реконструкции измеряет, насколько близок выход декодера к исходному входу, используя среднюю квадратную ошибку (MSE):
.
KL loss, или Kullback-Leibler divergence, измеряет различие между двумя распределениями вероятностей. Минимизация потерь KL в этом случае означает, что выученные средства и отклонения максимально близки к таковым целевого (нормального) распределения. Для скрытой размерности размера , KL потеря получена как
.
Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, выученные из-за потерь восстановления, плотно вокруг центра скрытого пространства, образуя непрерывное пространство для выборки.
Обучите на графическом процессоре, если он доступен (требуется Parallel Computing Toolbox™).
executionEnvironment = "auto";
Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети конечный средний градиент и конечный средний коэффициент распада градиента-квадрата с пустыми массивами .
numEpochs = 50; miniBatchSize = 512; lr = 1e-3; numIterations = floor(numTrainImages/miniBatchSize); iteration = 0; avgGradientsEncoder = []; avgGradientsSquaredEncoder = []; avgGradientsDecoder = []; avgGradientsSquaredDecoder = [];
Обучите модель с помощью пользовательского цикла обучения.
Для каждой итерации в эпоху:
Получите следующий мини-пакет из набора обучающих данных.
Преобразуйте мини-пакет в dlarray
объект, следя за тем, чтобы задать метки размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный).
Для обучения графический процессор преобразуйте dlarray
в gpuArray
объект.
Оцените градиенты модели с помощью dlfeval
и modelGradients
функций.
Обновляйте обучаемые возможности сети и средние градиенты для обеих сетей, используя adamupdate
функция.
В конце каждой эпохи передайте изображения тестового набора через автоэнкодер и отобразите потерю и время обучения для этой эпохи.
for epoch = 1:numEpochs tic; for i = 1:numIterations iteration = iteration + 1; idx = (i-1)*miniBatchSize+1:i*miniBatchSize; XBatch = XTrain(:,:,:,idx); XBatch = dlarray(single(XBatch), 'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" XBatch = gpuArray(XBatch); end [infGrad, genGrad] = dlfeval(... @modelGradients, encoderNet, decoderNet, XBatch); [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ... adamupdate(decoderNet.Learnables, ... genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr); [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ... adamupdate(encoderNet.Learnables, ... infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr); end elapsedTime = toc; [z, zMean, zLogvar] = sampling(encoderNet, XTest); xPred = sigmoid(forward(decoderNet, z)); elbo = ELBOloss(XTest, xPred, zMean, zLogvar); disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+... ". Time taken for epoch = "+ elapsedTime + "s") end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s
Чтобы визуализировать и интерпретировать результаты, используйте вспомогательные функции визуализации. Эти вспомогательные функции определены в конце этого примера.
The VisualizeReconstruction
функция показывает случайным образом выбранную цифру из каждого класса, сопровождаемую его восстановлением после прохождения через автоэнкодер.
The VisualizeLatentSpace
функция принимает среднее значение и кодировки отклонения (каждый из размерности 20), сгенерированные после передачи тестовых изображений через сеть энкодера, и выполняет анализ основного компонента (PCA) на матрице, содержащей кодировки для каждого из изображений. Затем можно визуализировать латентное пространство, заданное средствами и отклонениями в двух размерностях, характеризующихся двумя первыми главными компонентами.
The Generate
функция инициализирует новые кодирования, дискретизированные из нормального распределения, и выводит изображения, сгенерированные, когда эти кодирования проходят через сеть декодера.
visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)
visualizeLatentSpace(XTest, YTest, encoderNet)
generate(decoderNet, latentDim)
Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых для выполнения генеративных задач. Они хорошо работают с наборами данных, где изображения являются маленькими и имеют четко определенные функции (такие как MNIST). Для более сложных наборов данных с большими изображениями генеративные состязательные сети (GANs), как правило, выполняют лучше и генерируют изображения с меньшим шумом. Для примера, показывающего, как реализовать GAN для генерации изображений RGB 64 на 64, смотрите Обучите Генеративную Состязательную Сеть (GAN).
LeCun, Y., C. Cortes, and C. J. C. Burges. «База данных MNIST рукописных цифр». http://yann.lecun.com/exdb/mnist/.
The modelGradients
функция принимает энкодер и декодер dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потерь относительно настраиваемых параметров в сетях. Функция выполняет три операции:
Получите кодировки путем вызова sampling
функция на мини-пакете изображений, который проходит через сеть энкодера.
Получите потерю, передав кодировки через сеть декодера и вызвав ELBOloss
функция.
Вычислите градиенты потерь относительно настраиваемых параметров обеих сетей путем вызова dlgradient
функция.
function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x) [z, zMean, zLogvar] = sampling(encoderNet, x); xPred = sigmoid(forward(decoderNet, z)); loss = ELBOloss(x, xPred, zMean, zLogvar); [genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ... encoderNet.Learnables); end
The sampling
функция получает кодировки из входных изображений. Первоначально он пропускает мини-пакет изображений через сеть энкодеров и разделяет выход размера (2*latentDim)*miniBatchSize
в матрицу средств и матрицу отклонений, каждый из которых имеет размер latentDim*batchSize
. Затем он использует эти матрицы, чтобы реализовать трюк репараметризации и вычислить кодирование. Наконец, это кодирование преобразуется в dlarray
объект в формате SSCB.
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x); d = size(compressed,1)/2; zMean = compressed(1:d,:); zLogvar = compressed(1+d:end,:); sz = size(zMean); epsilon = randn(sz); sigma = exp(.5 * zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z, 'SSCB'); end
The ELBOloss
функция принимает кодировки средств и отклонений, возвращаемых sampling
function, и использует их, чтобы вычислить потери ELBO.
function elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2; reconstructionLoss = sum(squares, [1,2,3]); KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL); end
The VisualizeReconstruction
функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит график реконструкции один за другим с исходным входом. Обратите внимание, что чтобы построить график информации, содержащейся в dlarray
объект, вам нужно сначала извлечь его с помощью extractdata
и gather
функций.
function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet) f = figure; figure(f) title("Example ground truth image vs. reconstructed image") for i = 1:2 for c=0:9 idx = iRandomIdxOfClass(YTest,c); X = XTest(:,:,:,idx); [z, ~, ~] = sampling(encoderNet, X); XPred = sigmoid(forward(decoderNet, z)); X = gather(extractdata(X)); XPred = gather(extractdata(XPred)); comparison = [X, ones(size(X,1),1), XPred]; subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]), end end end function idx = iRandomIdxOfClass(T,c) idx = T == categorical(c); idx = find(idx); idx = idx(randi(numel(idx),1)); end
The VisualizeLatentSpace
функция визуализирует латентное пространство, заданное матрицами среднего и отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, формируемых представлениями латентного пространства каждой цифры.
Функция начинается с извлечения матриц среднего и отклонения из dlarray
объекты. Поскольку транспонирование матрицы с размерностями канала/пакета (C и B) невозможно, функция вызывает stripdims
перед транспонированием матриц. Затем он проводит основной анализ компонента (PCA) на обеих матрицах. Чтобы визуализировать латентное пространство в двух размерностях, функция сохраняет первые два главных компонента и строит их графики друг против друга. Наконец, функция окрашивает классы цифр так, чтобы можно было наблюдать кластеры.
function visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest); zMean = stripdims(zMean)'; zMean = gather(extractdata(zMean)); zLogvar = stripdims(zLogvar)'; zLogvar = gather(extractdata(zLogvar)); [~,scoreMean] = pca(zMean); [~,scoreLogvar] = pca(zLogvar); c = parula(10); f1 = figure; figure(f1) title("Latent space") ah = subplot(1,2,1); scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; axis equal xlabel("Z_m_u(2)") ylabel("Z_m_u(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); ah = subplot(1,2,2); scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; xlabel("Z_v_a_r(2)") ylabel("Z_v_a_r(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); axis equal end
The generate
функция проверяет генеративные возможности VAE. Он инициализирует dlarray
объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит графики выходов.
function generate(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB'); generatedImage = sigmoid(predict(decoderNet, randomNoise)); generatedImage = extractdata(generatedImage); f3 = figure; figure(f3) imshow(imtile(generatedImage, "ThumbnailSize", [100,100])) title("Generated samples of digits") drawnow end
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| layerGraph
| sigmoid