В этом примере показано, как создать вариационный автокодер (VAE) в MATLAB для создания цифровых изображений. VAE генерирует рукописные цифры в стиле набора данных MNIST.
VAE отличаются от обычных автокодеров тем, что не используют процесс кодирования-декодирования для восстановления входного сигнала. Вместо этого они накладывают распределение вероятности на латентное пространство и изучают распределение так, чтобы распределение выходных сигналов декодера соответствовало распределению наблюдаемых данных. Затем они проводят выборку из этого распределения для создания новых данных.
В этом примере создается сеть VAE, она обучается на наборе данных MNIST и создаются новые изображения, которые очень похожи на изображения в наборе данных.
Загрузите файлы MNIST из http://yann.lecun.com/exdb/mnist/ и загрузите набор данных MNIST в рабочую область [1]. Позвоните в processImagesMNIST и processLabelsMNIST вспомогательные функции, присоединенные к этому примеру, для загрузки данных из файлов в массивы MATLAB.
Поскольку VAE сравнивает восстановленные цифры с входами, а не с категориальными метками, нет необходимости использовать обучающие метки в наборе данных MNIST.
trainImagesFile = 'train-images-idx3-ubyte.gz'; testImagesFile = 't10k-images-idx3-ubyte.gz'; testLabelsFile = 't10k-labels-idx1-ubyte.gz'; XTrain = processImagesMNIST(trainImagesFile);
Read MNIST image data... Number of images in the dataset: 60000 ...
numTrainImages = size(XTrain,4); XTest = processImagesMNIST(testImagesFile);
Read MNIST image data... Number of images in the dataset: 10000 ...
YTest = processLabelsMNIST(testLabelsFile);
Read MNIST label data... Number of labels in the dataset: 10000 ...
Автокодеры имеют две части: кодер и декодер. Кодер принимает входной сигнал изображения и выводит сжатое представление (кодирование), которое является вектором размера latentDim, равное 20 в этом примере. Декодер принимает сжатое представление, декодирует его и воссоздает исходное изображение.
Чтобы сделать вычисления более стабильными в числовом отношении, увеличьте диапазон возможных значений с [0,1] до [-inf, 0], заставив сеть учиться на логарифме отклонений. Определение двух векторов размера latent_dim: один - для среднего, а другой - для логарифма ). Затем используйте эти два вектора, чтобы создать распределение для выборки из .
Используйте 2-е скручивания, сопровождаемые полностью связанным слоем, чтобы субдискретизировать от 28 28 1 изображением MNIST к кодированию в скрытом космосе. Затем используйте перемещенные 2-е скручивания, чтобы расширить кодирование 1 на 1 на 20 назад в изображение 28 на 28 на 1.
latentDim = 20;
imageSize = [28 28 1];
encoderLG = layerGraph([
imageInputLayer(imageSize,'Name','input_encoder','Normalization','none')
convolution2dLayer(3, 32, 'Padding','same', 'Stride', 2, 'Name', 'conv1')
reluLayer('Name','relu1')
convolution2dLayer(3, 64, 'Padding','same', 'Stride', 2, 'Name', 'conv2')
reluLayer('Name','relu2')
fullyConnectedLayer(2 * latentDim, 'Name', 'fc_encoder')
]);
decoderLG = layerGraph([
imageInputLayer([1 1 latentDim],'Name','i','Normalization','none')
transposedConv2dLayer(7, 64, 'Cropping', 'same', 'Stride', 7, 'Name', 'transpose1')
reluLayer('Name','relu1')
transposedConv2dLayer(3, 64, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose2')
reluLayer('Name','relu2')
transposedConv2dLayer(3, 32, 'Cropping', 'same', 'Stride', 2, 'Name', 'transpose3')
reluLayer('Name','relu3')
transposedConv2dLayer(3, 1, 'Cropping', 'same', 'Name', 'transpose4')
]);Чтобы обучить обе сети с помощью пользовательского цикла обучения и включить автоматическое дифференцирование, преобразуйте графики слоев в dlnetwork объекты.
encoderNet = dlnetwork(encoderLG); decoderNet = dlnetwork(decoderLG);
Вспомогательная функция ModelGradients принимает кодер и декодер dlnetwork объекты и мини-пакет входных данных Xи возвращает градиенты потерь относительно обучаемых параметров в сетях. Эта вспомогательная функция определена в конце этого примера.
Функция выполняет этот процесс в два этапа: выборка и потеря. Этап дискретизации выполняет выборку среднего вектора и вектора дисперсии для создания окончательного кодирования, которое должно быть передано в сеть декодеров. Однако, поскольку обратное распространение с помощью операции случайной выборки невозможно, необходимо использовать функцию репараметризации. Эта хитрость перемещает операцию случайной выборки во вспомогательную переменную, которая затем сдвигается на среднее, и масштабируется со стандартным отклонением. Идея заключается в том, что выборка из starti2) такая же, как выборка μi+ε⋅σi, где 0,1). На следующем рисунке графически показана эта идея.

Этап потери проходит кодирование, сгенерированное этапом выборки, через сеть декодера и определяет потерю, которая затем используется для вычисления градиентов. Потеря в VAE, также называемая потерей нижней границы доказательства (ELBO), определяется как сумма двух отдельных терминов потерь:
потеря KL.
Потеря восстановления измеряет, насколько близок выход декодера к исходному входу, используя среднеквадратичную ошибку (MSE):
изображение).
Потеря KL, или расхождение Куллбэка-Лейблера, измеряет разницу между двумя распределениями вероятности. Минимизация потерь KL в этом случае означает обеспечение того, чтобы усвоенные средства и отклонения были как можно ближе к средствам целевого (нормального) распределения. Для скрытого размера размера потери KL получают как
мкi2-starti2).
Практический эффект включения термина потерь KL состоит в том, чтобы упаковать кластеры, полученные из-за потерь восстановления, плотно вокруг центра скрытого пространства, образуя непрерывное пространство для выборки.
Обучение на GPU, если он доступен (требуется параллельная вычислительная Toolbox™).
executionEnvironment = "auto";Настройка параметров обучения для сети. При использовании оптимизатора Адама необходимо инициализировать для каждой сети конечные средние скорости градиента и конечные средние скорости градиента-квадратного затухания с пустыми массивами..
numEpochs = 50; miniBatchSize = 512; lr = 1e-3; numIterations = floor(numTrainImages/miniBatchSize); iteration = 0; avgGradientsEncoder = []; avgGradientsSquaredEncoder = []; avgGradientsDecoder = []; avgGradientsSquaredDecoder = [];
Обучение модели с помощью пользовательского цикла обучения.
Для каждой итерации в эпохе:
Получите следующую мини-партию из обучающего набора.
Преобразование мини-пакета в dlarray объект, убедитесь, что указаны метки размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).
Для обучения GPU преобразуйте dlarray в gpuArray объект.
Оцените градиенты модели с помощью dlfeval и modelGradients функции.
Обновите сетевые обучаемые файлы и средние градиенты для обеих сетей, используя adamupdate функция.
В конце каждой эпохи пропустите изображения тестового набора через автокодер и отобразите потери и время обучения для этой эпохи.
for epoch = 1:numEpochs tic; for i = 1:numIterations iteration = iteration + 1; idx = (i-1)*miniBatchSize+1:i*miniBatchSize; XBatch = XTrain(:,:,:,idx); XBatch = dlarray(single(XBatch), 'SSCB'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" XBatch = gpuArray(XBatch); end [infGrad, genGrad] = dlfeval(... @modelGradients, encoderNet, decoderNet, XBatch); [decoderNet.Learnables, avgGradientsDecoder, avgGradientsSquaredDecoder] = ... adamupdate(decoderNet.Learnables, ... genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder, iteration, lr); [encoderNet.Learnables, avgGradientsEncoder, avgGradientsSquaredEncoder] = ... adamupdate(encoderNet.Learnables, ... infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder, iteration, lr); end elapsedTime = toc; [z, zMean, zLogvar] = sampling(encoderNet, XTest); xPred = sigmoid(forward(decoderNet, z)); elbo = ELBOloss(XTest, xPred, zMean, zLogvar); disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+... ". Time taken for epoch = "+ elapsedTime + "s") end
Epoch : 1 Test ELBO loss = 28.0145. Time taken for epoch = 28.0573s Epoch : 2 Test ELBO loss = 24.8995. Time taken for epoch = 8.797s Epoch : 3 Test ELBO loss = 23.2756. Time taken for epoch = 8.8824s Epoch : 4 Test ELBO loss = 21.151. Time taken for epoch = 8.5979s Epoch : 5 Test ELBO loss = 20.5335. Time taken for epoch = 8.8472s Epoch : 6 Test ELBO loss = 20.232. Time taken for epoch = 8.5068s Epoch : 7 Test ELBO loss = 19.9988. Time taken for epoch = 8.4356s Epoch : 8 Test ELBO loss = 19.8955. Time taken for epoch = 8.4015s Epoch : 9 Test ELBO loss = 19.7991. Time taken for epoch = 8.8089s Epoch : 10 Test ELBO loss = 19.6773. Time taken for epoch = 8.4269s Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s
Для визуализации и интерпретации результатов используйте вспомогательные функции визуализации. Эти вспомогательные функции определяются в конце этого примера.
VisualizeReconstruction функция показывает случайно выбранную цифру из каждого класса, сопровождаемую его реконструкцией после прохождения через автокодер.
VisualizeLatentSpace функция принимает средние и дисперсионные кодировки (каждая из размерности 20), сгенерированные после прохождения тестовых изображений через сеть кодеров, и выполняет анализ основных компонентов (PCA) на матрице, содержащей кодировки для каждого из изображений. Затем можно визуализировать скрытое пространство, определенное средствами, и отклонения в двух измерениях, характеризуемых двумя первыми главными компонентами.
Generate функция инициализирует новые кодировки, дискретизированные из нормального распределения, и выводит изображения, сгенерированные, когда эти кодировки проходят через сеть декодеров.
visualizeReconstruction(XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

generate(decoderNet, latentDim)

Вариационные автокодеры являются только одной из многих доступных моделей, используемых для выполнения генеративных задач. Они хорошо работают с наборами данных, в которых изображения малы и имеют четко определенные функции (например, MNIST). Для более сложных наборов данных с большими изображениями генеративные состязательные сети (GAN) имеют тенденцию работать лучше и генерировать изображения с меньшим шумом. Пример реализации GAN для создания образов 64 на 64 RGB см. в разделе Генеративная состязательная сеть (GAN).
LeCun, Y., C. Cortes и C. J. C. Burges. «База данных MNIST рукописных цифр». http://yann.lecun.com/exdb/mnist/.
modelGradients функция принимает кодер и декодер dlnetwork объекты и мини-пакет входных данных Xи возвращает градиенты потерь относительно обучаемых параметров в сетях. Функция выполняет три операции:
Получение кодировок путем вызова sampling функция на мини-пакете изображений, который проходит через сеть кодера.
Получение потери путем передачи кодировок через сеть декодера и вызова ELBOloss функция.
Вычислите градиенты потерь относительно обучаемых параметров обеих сетей, вызвав dlgradient функция.
function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x) [z, zMean, zLogvar] = sampling(encoderNet, x); xPred = sigmoid(forward(decoderNet, z)); loss = ELBOloss(x, xPred, zMean, zLogvar); [genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ... encoderNet.Learnables); end
sampling функция получает кодировки из входных изображений. Первоначально он пропускает мини-пакет изображений через сеть кодера и расщепляет выходной сигнал размера (2*latentDim)*miniBatchSize в матрицу средств и матрицу дисперсий, каждая из которых имеет размер latentDim*batchSize. Затем он использует эти матрицы для реализации трика репараметризации и вычисления кодирования. Наконец, он преобразует эту кодировку в dlarray в формате SSCB.
function [zSampled, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x); d = size(compressed,1)/2; zMean = compressed(1:d,:); zLogvar = compressed(1+d:end,:); sz = size(zMean); epsilon = randn(sz); sigma = exp(.5 * zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z, 'SSCB'); end
ELBOloss функция принимает кодировки средств и дисперсии, возвращенные sampling и использует их для вычисления потерь ELBO.
function elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2; reconstructionLoss = sum(squares, [1,2,3]); KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL); end
VisualizeReconstruction функция случайным образом выбирает два изображения для каждой цифры набора данных MNIST, пропускает их через VAE и строит график реконструкции бок о бок с исходным входом. Обратите внимание, что для построения графика информации, содержащейся внутри dlarray объект, необходимо сначала извлечь с помощью extractdata и gather функции.
function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet) f = figure; figure(f) title("Example ground truth image vs. reconstructed image") for i = 1:2 for c=0:9 idx = iRandomIdxOfClass(YTest,c); X = XTest(:,:,:,idx); [z, ~, ~] = sampling(encoderNet, X); XPred = sigmoid(forward(decoderNet, z)); X = gather(extractdata(X)); XPred = gather(extractdata(XPred)); comparison = [X, ones(size(X,1),1), XPred]; subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]), end end end function idx = iRandomIdxOfClass(T,c) idx = T == categorical(c); idx = find(idx); idx = idx(randi(numel(idx),1)); end
VisualizeLatentSpace функция визуализирует латентное пространство, определенное матрицами среднего и дисперсии, которые формируют выходной сигнал сети кодера, и локализует кластеры, сформированные представлениями латентного пространства каждой цифры.
Функция начинается с извлечения среднего значения и матриц дисперсии из dlarray объекты. Поскольку транспонирование матрицы с измерениями канала/партии (C и B) невозможно, функция вызывает stripdims перед переносом матриц. Затем он выполняет анализ главных компонентов (PCA) для обеих матриц. Для визуализации скрытого пространства в двух измерениях функция сохраняет первые два главных компонента и строит их график относительно друг друга. Наконец, функция окрашивает классы цифр так, чтобы можно было наблюдать кластеры.
function visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest); zMean = stripdims(zMean)'; zMean = gather(extractdata(zMean)); zLogvar = stripdims(zLogvar)'; zLogvar = gather(extractdata(zLogvar)); [~,scoreMean] = pca(zMean); [~,scoreLogvar] = pca(zLogvar); c = parula(10); f1 = figure; figure(f1) title("Latent space") ah = subplot(1,2,1); scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; axis equal xlabel("Z_m_u(2)") ylabel("Z_m_u(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); ah = subplot(1,2,2); scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:)); ah.YDir = 'reverse'; xlabel("Z_v_a_r(2)") ylabel("Z_v_a_r(1)") cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9); axis equal end
generate функция проверяет генеративные возможности VAE. Инициализирует dlarray объект, содержащий 25 случайно сгенерированных кодировок, пропускает их через сеть декодеров и строит графики выходных сигналов.
function generate(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB'); generatedImage = sigmoid(predict(decoderNet, randomNoise)); generatedImage = extractdata(generatedImage); f3 = figure; figure(f3) imshow(imtile(generatedImage, "ThumbnailSize", [100,100])) title("Generated samples of digits") drawnow end
adamupdate | dlarray | dlfeval | dlgradient | dlnetwork | layerGraph | sigmoid