В этом примере показано, как создать вариационный автоэнкодер (VAE) в MATLAB, чтобы сгенерировать изображения цифры. VAE генерирует нарисованные от руки цифры в стиле набора данных MNIST.
VAEs отличаются от обычных автоэнкодеров в этом, они не используют декодирующий кодирование процесс, чтобы восстановить вход. Вместо этого они налагают вероятностное распределение на скрытый пробел и изучают распределение так, чтобы распределение выходных параметров от соответствий декодера те из наблюдаемых данных. Затем они производят от этого распределения, чтобы сгенерировать новые данные.
В этом примере вы создаете сеть VAE, обучаете его на наборе данных MNIST и генерируете новые изображения, которые тесно напоминают тех в наборе данных.
Загрузите файлы MNIST с http://yann.lecun.com/exdb/mnist/и загрузите набор данных MNIST в рабочую область [1]. Извлеките и поместите файлы в рабочую директорию, затем вызовите функции обработки, чтобы загрузить данные из файлов в массивы MATLAB.
Поскольку VAE сравнивает восстановленные цифры с входными параметрами а не с категориальными метками, вы не должны использовать учебные метки в наборе данных MNIST.
trainImagesFile = 'train-images.idx3-ubyte'; testImagesFile = 't10k-images.idx3-ubyte'; testLabelsFile = 't10k-labels.idx1-ubyte'; XTrain = processMNISTimages(trainImagesFile);
Read MNIST image data... Number of images in the dataset: 60000 ...
numTrainImages = size(XTrain,4); XTest = processMNISTimages(testImagesFile);
Read MNIST image data... Number of images in the dataset: 10000 ...
YTest = processMNISTlabels(testLabelsFile);
Read MNIST label data... Number of labels in the dataset: 10000 ...
Автоэнкодеры имеют две части: энкодер и декодер. Энкодер берет вход изображений и выводит сжатое представление (кодирование), который является вектором размера latent_dim
, равняйтесь 20 в этом примере. Декодер берет сжатое представление, декодирует его и воссоздает оригинальное изображение.
Чтобы сделать вычисления более численно устойчивыми, увеличьте область значений возможных значений от [0,1] до [-inf, 0], заставив сеть извлечь уроки из логарифма отклонений. Задайте два вектора размера latent_dim
: один для средних значений и один для логарифма отклонений . Затем используйте эти два вектора, чтобы создать распределение к выборке от.
Используйте 2D свертки, сопровождаемые полносвязным слоем, чтобы проредить от 28 28 1 изображением MNIST к кодированию на скрытом пробеле. Затем используйте транспонированные 2D свертки, чтобы увеличить 1 1 20 кодированием назад в 28 28 1 изображением.
latentDim = 20; imageSize = [28 28 1]; encoderLG = layerGraph([ imageInputLayer(imageSize,'Name','input_encoder','Normalization','none') convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1') reluLayer('Name','relu1') convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2') reluLayer('Name','relu2') fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder') ]); decoderLG = layerGraph([ imageInputLayer([1 1 latentDim],'Name','i','Normalization','none') transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1') reluLayer('Name','relu1') transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2') reluLayer('Name','relu2') transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3') reluLayer('Name','relu3') transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4') ]);
Чтобы обучить обе сети с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте графики слоя в dlnetwork
объекты.
encoderNet = dlnetwork(encoderLG); decoderNet = dlnetwork(decoderLG);
Функция помощника modelGradients берет в энкодере и декодере dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потери относительно learnable параметров в сетях. Эта функция помощника задана в конце этого примера.
Функция выполняет этот процесс на двух шагах: выборка и потеря. Шаг выборки производит среднее значение и векторы отклонения, чтобы создать кодирование финала, которое будет передано сети декодера. Однако, потому что обратная связь посредством случайной операции выборки не возможна, необходимо использовать прием репараметризации. Этот прием перемещает случайную операцию выборки во вспомогательную переменную , который затем смещен средним значением и масштабируемый стандартным отклонением . Идея является той выборкой от совпадает с выборкой от , где . Следующая фигура изображает эту идею графически.
Шаг потерь передает кодирование, сгенерированное шагом выборки через сеть декодера, и определяет потерю, которая затем используется для расчета градиенты. Потеря в VAEs, также названном нижней границей доказательства (ELBO) потеря, задана как сумма двух отдельных условий потерь:
.
Потеря реконструкции измеряется, как близко декодер выход к исходному входу при помощи среднеквадратической ошибки (MSE):
.
Потеря KL или расхождение Kullback–Leibler, измеряет различие между двумя вероятностными распределениями. Минимизация потери KL в этом, случай означает гарантировать, что изученные средние значения и отклонения максимально близки к тем из целевого (нормального) распределения. Для скрытой размерности размера , потеря KL получена как
.
Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, изученные из-за потери реконструкции плотно вокруг центра скрытого пробела, формируя непрерывный пробел к выборке от.
Обучайтесь на графическом процессоре (требует Parallel Computing Toolbox™). Если у вас нет графического процессора, установите executionEnvironment
к "cpu"
.
executionEnvironment = "auto";
Установите опции обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети запаздывающий средний градиент и запаздывающие средние уровни затухания градиентного квадрата с пустым arrays.
numEpochs = 50; miniBatchSize = 512; lr = 1e-3; numIterations = floor(numTrainImages/miniBatchSize); iteration = 0; avgGradientsEncoder = []; avgGradientsSquaredEncoder = []; avgGradientsDecoder = []; avgGradientsSquaredDecoder = [];
Обучите модель с помощью пользовательского учебного цикла.
Для каждой итерации в эпоху:
Получите следующий мини-пакет из набора обучающих данных.
Преобразуйте мини-пакет в dlarray
объект, убеждаясь задавать размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Для обучения графического процессора преобразуйте dlarray
к gpuArray
объект.
Оцените градиенты модели с помощью dlfeval
and modelGradients
функции.
Обновите сеть learnables и средние градиенты для обеих сетей, с помощью adamupdate
функция.
В конце каждой эпохи передайте изображения набора тестов через автоэнкодер и отобразите потерю и учебное время в течение той эпохи.
for epoch = 1:numEpochs tic; for i = 1:numIterations iteration = iteration + 1; idx = (i-1)*miniBatchSize+1:i*miniBatchSize; XBatch = XTrain(:,:,:,idx); XBatch = dlarray(single(XBatch), 'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" XBatch = gpuArray(XBatch); end [infGrad, genGrad] = dlfeval(... @modelGradients, encoderNet, decoderNet, XBatch); [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ... adamupdate(decoderNet.Learnables, ... genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr); [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ... adamupdate(encoderNet.Learnables, ... infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr); end elapsedTime = toc; [z, zMean, zLogvar] = sampling(encoderNet, XTest); xPred = sigmoid(forward(decoderNet, z)); elbo = ELBOloss(XTest, xPred, zMean, zLogvar); disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+... ". Time taken for epoch = "+ elapsedTime + "s") end
Epoch : 1 Test ELBO loss = 27.0561. Time taken for epoch = 33.0037s Epoch : 2 Test ELBO loss = 24.414. Time taken for epoch = 32.4167s Epoch : 3 Test ELBO loss = 23.0166. Time taken for epoch = 32.3244s Epoch : 4 Test ELBO loss = 20.9078. Time taken for epoch = 32.1268s Epoch : 5 Test ELBO loss = 20.6519. Time taken for epoch = 32.3451s Epoch : 6 Test ELBO loss = 20.3201. Time taken for epoch = 32.4371s Epoch : 7 Test ELBO loss = 19.9266. Time taken for epoch = 32.4551s Epoch : 8 Test ELBO loss = 19.8448. Time taken for epoch = 32.9919s Epoch : 9 Test ELBO loss = 19.7485. Time taken for epoch = 33.1783s Epoch : 10 Test ELBO loss = 19.6295. Time taken for epoch = 33.1623s Epoch : 11 Test ELBO loss = 19.539. Time taken for epoch = 32.4781s Epoch : 12 Test ELBO loss = 19.4682. Time taken for epoch = 32.5094s Epoch : 13 Test ELBO loss = 19.3577. Time taken for epoch = 32.5996s Epoch : 14 Test ELBO loss = 19.3247. Time taken for epoch = 32.6447s Epoch : 15 Test ELBO loss = 19.3043. Time taken for epoch = 32.2494s Epoch : 16 Test ELBO loss = 19.2948. Time taken for epoch = 32.5408s Epoch : 17 Test ELBO loss = 19.191. Time taken for epoch = 32.8177s Epoch : 18 Test ELBO loss = 19.1075. Time taken for epoch = 32.5982s Epoch : 19 Test ELBO loss = 19.0606. Time taken for epoch = 33.7771s Epoch : 20 Test ELBO loss = 19.0298. Time taken for epoch = 33.6249s Epoch : 21 Test ELBO loss = 19.0534. Time taken for epoch = 33.4906s Epoch : 22 Test ELBO loss = 18.9859. Time taken for epoch = 33.1101s Epoch : 23 Test ELBO loss = 19.0077. Time taken for epoch = 32.7345s Epoch : 24 Test ELBO loss = 18.9963. Time taken for epoch = 33.0067s Epoch : 25 Test ELBO loss = 18.9189. Time taken for epoch = 32.891s Epoch : 26 Test ELBO loss = 18.8925. Time taken for epoch = 33.0905s Epoch : 27 Test ELBO loss = 18.9182. Time taken for epoch = 32.6203s Epoch : 28 Test ELBO loss = 18.8664. Time taken for epoch = 32.4095s Epoch : 29 Test ELBO loss = 18.8512. Time taken for epoch = 32.4317s Epoch : 30 Test ELBO loss = 18.7983. Time taken for epoch = 32.4s Epoch : 31 Test ELBO loss = 18.7971. Time taken for epoch = 32.4902s Epoch : 32 Test ELBO loss = 18.7888. Time taken for epoch = 32.2591s Epoch : 33 Test ELBO loss = 18.7811. Time taken for epoch = 32.4291s Epoch : 34 Test ELBO loss = 18.7804. Time taken for epoch = 32.5968s Epoch : 35 Test ELBO loss = 18.7839. Time taken for epoch = 32.3787s Epoch : 36 Test ELBO loss = 18.7045. Time taken for epoch = 32.6078s Epoch : 37 Test ELBO loss = 18.7783. Time taken for epoch = 32.6429s Epoch : 38 Test ELBO loss = 18.7068. Time taken for epoch = 32.7032s Epoch : 39 Test ELBO loss = 18.6822. Time taken for epoch = 32.3438s Epoch : 40 Test ELBO loss = 18.7155. Time taken for epoch = 32.6521s Epoch : 41 Test ELBO loss = 18.7161. Time taken for epoch = 32.5532s Epoch : 42 Test ELBO loss = 18.6597. Time taken for epoch = 32.6419s Epoch : 43 Test ELBO loss = 18.6657. Time taken for epoch = 32.4558s Epoch : 44 Test ELBO loss = 18.5996. Time taken for epoch = 32.5503s Epoch : 45 Test ELBO loss = 18.6666. Time taken for epoch = 32.5503s Epoch : 46 Test ELBO loss = 18.6449. Time taken for epoch = 32.2981s Epoch : 47 Test ELBO loss = 18.6107. Time taken for epoch = 32.3152s Epoch : 48 Test ELBO loss = 18.6393. Time taken for epoch = 32.7135s Epoch : 49 Test ELBO loss = 18.6351. Time taken for epoch = 32.3859s Epoch : 50 Test ELBO loss = 18.5955. Time taken for epoch = 32.6549s
Чтобы визуализировать и интерпретировать результаты, используйте функции Визуализации помощника. Эти функции помощника заданы в конце этого примера.
VisualizeReconstruction
функция показывает случайным образом выбранную цифру от каждого класса, сопровождаемого его реконструкцией после прохождения через автоэнкодер.
VisualizeLatentSpace
функционируйте берет среднее значение и кодировку отклонения (каждая размерность 20) сгенерированный после передачи тестовых изображений через сеть энкодера, и выполняет анализ главных компонентов (PCA) матрицы, содержащей кодировку для каждого из изображений. Можно затем визуализировать скрытый пробел, заданный средними значениями и отклонениями в этих двух размерностях, охарактеризованных двумя первыми основными компонентами.
Generate
функция инициализирует новую кодировку, произведенную от нормального распределения, и выводит изображения, сгенерированные, когда эта кодировка проходит через сеть декодера.
visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)
visualizeLatentSpace(XTest, YTest, encoderNet)
generate(decoderNet, latentDim)
Вариационные автоэнкодеры являются только одной из многих доступных моделей, используемых, чтобы выполнить порождающие задачи. Они работают хорошо над наборами данных, где изображения малы и ясно задали функции (такие как MNIST). Для наборов более комплексных данных с увеличенными изображениями порождающие соперничающие сети (GANs) имеют тенденцию выполнять лучше и генерировать изображения с меньшим количеством шума. Для примера, показывающего, как реализовать GANs, чтобы сгенерировать 64 64 изображения RGB, смотрите, Обучают Порождающую соперничающую сеть (GAN).
LeCun, Y., К. Кортес и К. Дж. К. Берджес. "База данных MNIST Рукописных Цифр". http://yann.lecun.com/exdb/mnist/.
modelGradients
функционируйте берет энкодер и декодер dlnetwork
объекты и мини-пакет входных данных X
, и возвращает градиенты потери относительно learnable параметров в сетях. Функция выполняет три операции:
Получите кодировку путем вызова sampling
функция на мини-пакете изображений, который проходит через сеть энкодера.
Получите потерю путем передачи кодировки через сеть декодера и вызова ELBOloss
функция.
Вычислите градиенты потери относительно learnable параматерей обеих сетей путем вызова dlgradient
функция.
function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x) [z, zMean, zLogvar] = sampling(encoderNet, x); xPred = sigmoid(forward(decoderNet, z)); loss = ELBOloss(x, xPred, zMean, zLogvar); [genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ... encoderNet.Learnables); end
sampling
функция получает кодировку из входных изображений. Первоначально, это передает мини-пакет изображений через сеть энкодера и разделяет выход размера (2*latentDim)*miniBatchSize
в матрицу средних значений и матрицу отклонений, каждый размер latentDim*batchSize
. Затем это использует эти матрицы, чтобы реализовать прием перепараметризации и вычислить кодирование. Наконец, это преобразует это кодирование в dlarray
объект в формате SSCB.
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x); d = size(compressed,1)/2; zMean = compressed(1:d,:); zLogvar = compressed(1+d:end,:); sz = size(zMean); epsilon = randn(sz); sigma = exp(.5 * zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z, 'SSCB'); end
ELBOloss
функционируйте берет кодировку средних значений и дисперсий, возвращенных sampling
функция, и использует их, чтобы вычислить потерю ELBO.
function elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2; reconstructionLoss = sum(squares, [1,2,3]); KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL); end
VisualizeReconstruction
функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, передает их через VAE и строит реконструкцию бок о бок с исходным входом. Обратите внимание на то, что построить информацию, содержавшую в dlarray
объект, необходимо извлечь его сначала использование extractdata
and
gather
функции.
function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet) f = figure; figure(f) title("Example ground truth image vs. reconstructed image") for i = 1:2 for c=0:9 idx = iRandomIdxOfClass(YTest,c); X = XTest(:,:,:,idx); [z, ~, ~] = sampling(encoderNet, X); XPred = sigmoid(forward(decoderNet, z)); X = gather(extractdata(X)); XPred = gather(extractdata(XPred)); comparison = [X, ones(size(X,1),1), XPred]; subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]), end end end function idx = iRandomIdxOfClass(T,c) idx = T == categorical(c); idx = find(idx); idx = idx(randi(numel(idx),1)); end
VisualizeLatentSpace
функция визуализирует скрытый пробел, заданный средним значением и матрицами отклонения, которые формируют выход сети энкодера, и определяет местоположение кластеров, сформированных скрытыми представлениями пробела каждой цифры.
Функция запускается путем извлечения среднего значения и матриц отклонения от dlarray
объекты. Поскольку перемещение матрицы с размерностями канала/пакета (C и B) не возможно, вызовы функции stripdims
прежде, чем транспонировать матрицы. Затем это выполняет анализ главных компонентов (PCA) обеих матриц. Чтобы визуализировать скрытый пробел в двух измерениях, функция сохраняет первые два основных компонента и строит их друг против друга. Наконец, функция окрашивает классы цифры так, чтобы можно было наблюдать кластеры.
function visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest); zMean = stripdims(zMean)'; zMean = gather(extractdata(zMean)); zLogvar = stripdims(zLogvar)'; zLogvar = gather(extractdata(zLogvar)); [~,scoreMean] = pca(zMean); [~,scoreLogvar] = pca(zLogvar); c = parula(10); f1 = figure; figure(f1) title("Latent space") ah = subplot(1,2,1); scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; axis equal xlabel("Z_m_u(2)") ylabel("Z_m_u(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); ah = subplot(1,2,2); scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; xlabel("Z_v_a_r(2)") ylabel("Z_v_a_r(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); axis equal end
Generate
функционируйте тестирует порождающие возможности VAE. Это инициализирует dlarray
объект, содержащий 25 случайным образом сгенерированных кодировок, передает их через сеть декодера и строит выходные параметры.
function generate(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB'); generatedImage = sigmoid(predict(decoderNet, randomNoise)); generatedImage = extractdata(generatedImage); f3 = figure; figure(f3) imshow(imtile(generatedImage, "ThumbnailSize", [100,100])) title("Generated samples of digits") drawnow end
MNIST обрабатывающие функции извлекают данные из загруженных файлов IDX в массивы MATLAB.
processMNISTimages
функция выполняет эти операции:
Проверяйте, может ли файл быть открыт правильно.
Получите магическое число путем чтения первых четырех байтов. Магическое число 2051 для данных изображения, и 2049 для данных о метке.
Считайте следующие 3 набора 4 байтов, которые возвращают количество изображений, количество строк и количество столбцов.
Считайте данные изображения.
Измените размерность массива, и подкачивает первые две размерности вследствие того, что данные считывались в упорядоченном по столбцам формате.
Гарантируйте, что пиксельные значения находятся в области значений [0,1] путем деления их всех на 255, и преобразует трехмерный массив в 4-D dlarray
объект.
Закройте файл.
function X = processMNISTimages(filename) [fileID,errmsg] = fopen(filename,'r','b'); if fileID < 0 error(errmsg); end magicNum = fread(fileID,1,'int32',0,'b'); if magicNum == 2051 fprintf('\nRead MNIST image data...\n') end numImages = fread(fileID,1,'int32',0,'b'); fprintf('Number of images in the dataset: %6d ...\n',numImages); numRows = fread(fileID,1,'int32',0,'b'); numCols = fread(fileID,1,'int32',0,'b'); X = fread(fileID,inf,'unsigned char'); X = reshape(X,numCols,numRows,numImages); X = permute(X,[2 1 3]); X = X./255; X = reshape(X, [28,28,1,size(X,3)]); X = dlarray(X, 'SSCB'); fclose(fileID); end
processMNISTlabels
функция действует так же. После открытия файла и чтения магического числа, это читает метки и возвращает категориальный массив, содержащий их значения.
function Y = processMNISTlabels(filename) [fileID,errmsg] = fopen(filename,'r','b'); if fileID < 0 error(errmsg); end magicNum = fread(fileID,1,'int32',0,'b'); if magicNum == 2049 fprintf('\nRead MNIST label data...\n') end numItems = fread(fileID,1,'int32',0,'b'); fprintf('Number of labels in the dataset: %6d ...\n',numItems); Y = fread(fileID,inf,'unsigned char'); Y = categorical(Y); fclose(fileID); end
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlnetwork
| layerGraph
| sigmoid