В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифры. VAE генерирует нарисованные от руки цифры в стиле набора данных MNIST.
VAEs отличаются от обычных автоэнкодеров в этом, они не используют декодирующий кодирование процесс, чтобы восстановить вход. Вместо этого они налагают вероятностное распределение на скрытый пробел и изучают распределение так, чтобы распределение выходных параметров от соответствий декодера те из наблюдаемых данных. Затем они производят от этого распределения, чтобы сгенерировать новые данные.
В этом примере вы создаете сеть VAE, обучаете его на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают тех в наборе данных.
Загрузите файлы MNIST с http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Вызовите processImagesMNIST
и processLabelsMNIST
помощник функционирует присоединенный к этому примеру, чтобы загрузить данные из файлов в массивы MATLAB.
Поскольку VAE сравнивает восстановленные цифры с входными параметрами а не с категориальными метками, вы не должны использовать учебные метки в наборе данных MNIST.
trainImagesFile = 'train-images-idx3-ubyte.gz'; testImagesFile = 't10k-images-idx3-ubyte.gz'; testLabelsFile = 't10k-labels-idx1-ubyte.gz'; XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data... Number of images in the dataset: 60000 ...
numTrainImages = size(XTrain,4); XTest = processImagesMNIST(testImagesFile);
Read MNIST image data... Number of images in the dataset: 10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data... Number of labels in the dataset: 10000 ...
Автоэнкодеры имеют две части: энкодер и декодер. Энкодер берет вход изображений и выводит сжатое представление (кодирование), который является вектором из размера latentDim
, равняйтесь 20 в этом примере. Декодер берет сжатое представление, декодирует его и воссоздает оригинальное изображение.
Чтобы сделать вычисления более численно устойчивыми, увеличьте область значений возможных значений от [0,1] до [-inf, 0], заставив сеть извлечь уроки из логарифма отклонений. Задайте два вектора из размера latent_dim
: один для средних значений и один для логарифма отклонений . Затем используйте эти два вектора, чтобы создать распределение к выборке от.
Используйте 2D свертки, сопровождаемые полносвязным слоем, чтобы проредить от 28 28 1 изображением MNIST к кодированию на скрытом пробеле. Затем используйте транспонированные 2D свертки, чтобы увеличить 1 1 20 кодированием назад в 28 28 1 изображением.
latentDim = 20; imageSize = [28 28 1]; encoderLG = layerGraph([ imageInputLayer(imageSize,'Name','input_encoder','Normalization','none') convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1') reluLayer('Name','relu1') convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2') reluLayer('Name','relu2') fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder') ]); decoderLG = layerGraph([ imageInputLayer([1 1 latentDim],'Name','i','Normalization','none') transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1') reluLayer('Name','relu1') transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2') reluLayer('Name','relu2') transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3') reluLayer('Name','relu3') transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4') ]);
Чтобы обучить обе сети с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте графики слоев в dlnetwork
объекты.
encoderNet = dlnetwork(encoderLG); decoderNet = dlnetwork(decoderLG);
Функция помощника modelGradients берет в энкодере и декодере dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Эта функция помощника задана в конце этого примера.
Функция выполняет этот процесс на двух шагах: выборка и потеря. Шаг выборки производит среднее значение и векторы отклонения, чтобы создать кодирование финала, которое будет передано сети декодера. Однако, потому что обратная связь посредством случайной операции выборки не возможна, необходимо использовать прием репараметризации. Этот прием перемещает случайную операцию выборки во вспомогательную переменную , который затем смещен средним значением и масштабируемый стандартным отклонением . Идея является той выборкой от совпадает с выборкой от , где . Следующая фигура изображает эту идею графически.
Шаг потерь передает кодирование, сгенерированное шагом выборки через сеть декодера, и определяет потерю, которая затем используется для расчета градиенты. Потеря в VAEs, также названном нижней границей доказательства (ELBO) потеря, задана как сумма двух отдельных терминов потерь:
.
Потеря реконструкции измеряется, как близко декодер выход к исходному входу при помощи среднеквадратической ошибки (MSE):
.
Потеря KL или расхождение Kullback–Leibler, измеряет различие между двумя вероятностными распределениями. Минимизация потери KL в этом, случай означает гарантировать, что изученные средние значения и отклонения максимально близки к тем из целевого (нормального) распределения. Для скрытой размерности размера , потеря KL получена как
.
Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, изученные из-за потери реконструкции плотно вокруг центра скрытого пробела, формируя непрерывный пробел к выборке от.
Обучайтесь на графическом процессоре, если вы доступны (требует Parallel Computing Toolbox™).
executionEnvironment = "auto";
Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с пустым arrays.
numEpochs = 50; miniBatchSize = 512; lr = 1e-3; numIterations = floor(numTrainImages/miniBatchSize); iteration = 0; avgGradientsEncoder = []; avgGradientsSquaredEncoder = []; avgGradientsDecoder = []; avgGradientsSquaredDecoder = [];
Обучите модель с помощью пользовательского учебного цикла.
Для каждой итерации в эпоху:
Получите следующий мини-пакет из набора обучающих данных.
Преобразуйте мини-пакет в dlarray
объект, убеждаясь задавать размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Для обучения графического процессора преобразуйте dlarray
к gpuArray
объект.
Оцените градиенты модели с помощью dlfeval
and modelGradients
функции.
Обновите сеть learnables и средние градиенты для обеих сетей, с помощью adamupdate
функция.
В конце каждой эпохи передайте изображения набора тестов через автоэнкодер и отобразите потерю и учебное время в течение той эпохи.
for epoch = 1:numEpochs tic; for i = 1:numIterations iteration = iteration + 1; idx = (i-1)*miniBatchSize+1:i*miniBatchSize; XBatch = XTrain(:,:,:,idx); XBatch = dlarray(single(XBatch), 'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" XBatch = gpuArray(XBatch); end [infGrad, genGrad] = dlfeval(... @modelGradients, encoderNet, decoderNet, XBatch); [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ... adamupdate(decoderNet.Learnables, ... genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr); [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ... adamupdate(encoderNet.Learnables, ... infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr); end elapsedTime = toc; [z, zMean, zLogvar] = sampling(encoderNet, XTest); xPred = sigmoid(forward(decoderNet, z)); elbo = ELBOloss(XTest, xPred, zMean, zLogvar); disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+... ". Time taken for epoch = "+ elapsedTime + "s") end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s
Чтобы визуализировать и интерпретировать результаты, используйте функции Визуализации помощника. Эти функции помощника заданы в конце этого примера.
VisualizeReconstruction
функция показывает случайным образом выбранную цифру от каждого класса, сопровождаемого его реконструкцией после прохождения через автоэнкодер.
VisualizeLatentSpace
функционируйте берет среднее значение и кодировку отклонения (каждая размерность 20) сгенерированный после передачи тестовых изображений через сеть энкодера, и выполняет анализ главных компонентов (PCA) матрицы, содержащей кодировку для каждого из изображений. Можно затем визуализировать скрытый пробел, заданный средними значениями и отклонениями в этих двух размерностях, охарактеризованных двумя первыми основными компонентами.
Generate
функция инициализирует новую кодировку, произведенную от нормального распределения, и выводит изображения, сгенерированные, когда эта кодировка проходит через сеть декодера.
visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)
visualizeLatentSpace(XTest, YTest, encoderNet)
generate(decoderNet, latentDim)
Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых, чтобы выполнить порождающие задачи. Они работают хорошо над наборами данных, где изображения малы и ясно задали функции (такие как MNIST). Для наборов более комплексных данных с увеличенными изображениями порождающие соперничающие сети (GANs) имеют тенденцию выполнять лучше и генерировать изображения с меньшим количеством шума. Для примера, показывающего, как реализовать GANs, чтобы сгенерировать 64 64 изображения RGB, смотрите, Обучают Порождающую соперничающую сеть (GAN).
LeCun, Y., К. Кортес и К. Дж. К. Берджес. "База данных MNIST Рукописных Цифр". http://yann.lecun.com/exdb/mnist/.
modelGradients
функционируйте берет энкодер и декодер dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потери относительно настраиваемых параметров в сетях. Функция выполняет три операции:
Получите кодировку путем вызова sampling
функция на мини-пакете изображений, который проходит через сеть энкодера.
Получите потерю путем передачи кодировки через сеть декодера и вызова ELBOloss
функция.
Вычислите градиенты потери относительно настраиваемых параметров обеих сетей путем вызова dlgradient
функция.
function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x) [z, zMean, zLogvar] = sampling(encoderNet, x); xPred = sigmoid(forward(decoderNet, z)); loss = ELBOloss(x, xPred, zMean, zLogvar); [genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ... encoderNet.Learnables); end
sampling
функция получает кодировку из входных изображений. Первоначально, это передает мини-пакет изображений через сеть энкодера и разделяет выход размера (2*latentDim)*miniBatchSize
в матрицу средних значений и матрицу отклонений, каждый размер latentDim*batchSize
. Затем это использует эти матрицы, чтобы реализовать прием репараметризации и вычислить кодирование. Наконец, это преобразует это кодирование в dlarray
объект в формате SSCB.
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x); d = size(compressed,1)/2; zMean = compressed(1:d,:); zLogvar = compressed(1+d:end,:); sz = size(zMean); epsilon = randn(sz); sigma = exp(.5 * zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z, 'SSCB'); end
ELBOloss
функционируйте берет кодировку средних значений и дисперсий, возвращенных sampling
функция, и использует их, чтобы вычислить потерю ELBO.
function elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2; reconstructionLoss = sum(squares, [1,2,3]); KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL); end
VisualizeReconstruction
функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит реконструкцию бок о бок с исходным входом. Обратите внимание на то, что построить информацию, содержавшую в dlarray
объект, необходимо извлечь его сначала использование extractdata
и gather
функции.
function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet) f = figure; figure(f) title("Example ground truth image vs. reconstructed image") for i = 1:2 for c=0:9 idx = iRandomIdxOfClass(YTest,c); X = XTest(:,:,:,idx); [z, ~, ~] = sampling(encoderNet, X); XPred = sigmoid(forward(decoderNet, z)); X = gather(extractdata(X)); XPred = gather(extractdata(XPred)); comparison = [X, ones(size(X,1),1), XPred]; subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]), end end end function idx = iRandomIdxOfClass(T,c) idx = T == categorical(c); idx = find(idx); idx = idx(randi(numel(idx),1)); end
VisualizeLatentSpace
функция визуализирует скрытый пробел, заданный средним значением и матрицами отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, сформированных скрытыми представлениями пробела каждой цифры.
Функция запускается путем извлечения среднего значения и матриц отклонения от dlarray
объекты. Поскольку перемещение матрицы с размерностями канала/пакета (C и B) не возможно, вызовы функции stripdims
прежде, чем транспонировать матрицы. Затем это выполняет анализ главных компонентов (PCA) обеих матриц. Чтобы визуализировать скрытый пробел в двух измерениях, функция сохраняет первые два основных компонента и строит их друг против друга. Наконец, функция окрашивает классы цифры так, чтобы можно было наблюдать кластеры.
function visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest); zMean = stripdims(zMean)'; zMean = gather(extractdata(zMean)); zLogvar = stripdims(zLogvar)'; zLogvar = gather(extractdata(zLogvar)); [~,scoreMean] = pca(zMean); [~,scoreLogvar] = pca(zLogvar); c = parula(10); f1 = figure; figure(f1) title("Latent space") ah = subplot(1,2,1); scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; axis equal xlabel("Z_m_u(2)") ylabel("Z_m_u(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); ah = subplot(1,2,2); scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; xlabel("Z_v_a_r(2)") ylabel("Z_v_a_r(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); axis equal end
generate
функционируйте тестирует порождающие возможности VAE. Это инициализирует dlarray
объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит выходные параметры.
function generate(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB'); generatedImage = sigmoid(predict(decoderNet, randomNoise)); generatedImage = extractdata(generatedImage); f3 = figure; figure(f3) imshow(imtile(generatedImage, "ThumbnailSize", [100,100])) title("Generated samples of digits") drawnow end
dlnetwork
| layerGraph
| dlarray
| adamupdate
| dlfeval
| dlgradient
| sigmoid