В этом примере показано, как обучить и использовать порождающую соперничающую сеть (GAN), чтобы сгенерировать звуки.
В порождающих соперничающих сетях генератор и различитель конкурируют друг против друга, чтобы улучшить качество генерации.
GANs вызвали значительный интерес в области речевой обработки и аудио. Приложения включают синтез текста к речи, речевое преобразование и речевое улучшение.
Этот пример обучает GAN безнадзорному синтезу аудио форм волны. GAN в этом примере генерирует звуки барабанного боя. Тот же подход может сопровождаться, чтобы сгенерировать другие типы звука, включая речь.
Прежде чем вы обучите GAN с нуля, вы будете использовать предварительно обученный генератор GAN, чтобы синтезировать барабанные ритмы.
Загрузите предварительно обученный генератор.
matFileName = 'drumGeneratorWeights.mat'; if ~exist(matFileName,'file') websave(matFileName,'https://www.mathworks.com/supportfiles/audio/GanAudioSynthesis/drumGeneratorWeights.mat'); end
Функциональный synthesizeDrumBeat
вызывает предварительно обученную сеть, чтобы синтезировать барабанный бой, произведенный на уровне 16 кГц. synthesizeDrumBeat
функция включена в конце этого примера.
Синтезируйте барабанный бой и слушайте его.
drum = synthesizeDrumBeat; fs = 16e3; sound(drum,fs)
Постройте синтезируемый барабанный бой.
t = (0:length(drum)-1)/fs; plot(t,drum) grid on xlabel('Time (s)') title('Synthesized Drum Beat')
Можно использовать синтезатор барабанного боя с другими звуковыми эффектами, чтобы создать более сложные приложения. Например, можно применить реверберацию к синтезируемым барабанным ритмам.
Создайте reverberator
возразите и откройте его тюнер параметра пользовательский интерфейс. Этот пользовательский интерфейс позволяет вам настроить reverberator
параметры как симуляция запускаются.
reverb = reverberator('SampleRate',fs);
parameterTuner(reverb);
Создайте dsp.TimeScope
объект визуализировать барабанные ритмы.
ts = dsp.TimeScope('SampleRate',fs, ... 'TimeSpanOverrunAction','Scroll', ... 'TimeSpan',10, ... 'BufferLength',10*256*64, ... 'ShowGrid',true, ... 'YLimits',[-1 1]);
В цикле синтезируйте барабанные ритмы и примените реверберацию. Используйте тюнер параметра пользовательский интерфейс, чтобы настроить реверберацию. Если вы хотите запустить симуляцию в течение более длительного времени, увеличьте значение loopCount
параметр.
loopCount = 20; for ii = 1:loopCount drum = synthesizeDrumBeat; drum = reverb(drum); ts(drum(:,1)); soundsc(drum,fs) pause(0.5) end
Теперь, когда вы видели предварительно обученный генератор барабанного боя в действии, можно исследовать учебный процесс подробно.
ГАНЬ является типом нейронной сети для глубокого обучения, которая генерирует данные с характеристиками, похожими на обучающие данные.
ГАНЬ состоит из двух сетей, которые обучаются вместе, генератор и различитель:
Генератор - Учитывая векторные или случайные значения, как введено, эта сеть генерирует данные с той же структурой как обучающие данные. Это - задание генератора, чтобы одурачить различитель.
Различитель - Данный пакеты данных, содержащих наблюдения и от обучающих данных и от сгенерированных данных, эта сеть пытается классифицировать наблюдения как действительные или сгенерированные.
Чтобы максимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как действительные. Чтобы максимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.
В этом примере вы обучаете генератор создавать поддельные представления кратковременного преобразования Фурье (STFT) частоты времени барабанных ритмов. Вы обучаете различитель идентифицировать действительный STFTs. Вы создаете действительный STFTs путем вычисления STFT коротких записей действительных барабанных ритмов.
Обучите GAN использование набора данных Drum Sound Effects [1]. Загрузите и извлеките набор данных.
url = 'http://deepyeti.ucsd.edu/cdonahue/wavegan/data/drums.tar.gz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'drums_dataset.tgz'); drumsFolder = fullfile(downloadFolder,'drums'); if ~exist(drumsFolder,'dir') disp('Downloading Drum Sound Effects Dataset (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте audioDatastore
возразите что точки против набора данных барабанов.
ads = audioDatastore(drumsFolder,'IncludeSubfolders',true);
Задайте сеть, которая генерирует STFTs от 1 1 100 массивами случайных значений. Создайте сеть, которая увеличивает масштаб 1 1 100 массивами к 128 128 1 массивом с помощью полносвязного слоя, сопровождаемого изменять слоем и серией транспонированных слоев свертки со слоями ReLU.
Этот рисунок показывает размерности сигнала, когда это перемещается через генератор. Архитектура генератора задана в Таблице 4 [1].
Сеть генератора задана в modelGenerator
, который включен в конце этого примера.
Задайте сеть, которая классифицирует действительный и сгенерированный 128 128 STFTs.
Создайте сеть, которая берет 128 128 изображения и выводит скалярный счет предсказания с помощью серии слоев свертки с текучими слоями ReLU, сопровождаемыми полносвязным слоем.
Этот рисунок показывает размерности сигнала, когда это перемещается через различитель. Архитектура различителя задана в Таблице 5 [1].
Сеть различителя задана в modelDiscriminator
, который включен в конце этого примера.
Сгенерируйте данные STFT из сигналов барабанного боя в datastore.
Задайте параметры STFT.
fftLength = 256;
win = hann(fftLength,'periodic');
overlapLength = 128;
Чтобы ускорить обработку, распределите извлечение признаков на нескольких рабочих, использующих parfor
.
Во-первых, определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.
if ~isempty(ver('parallel')) pool = gcp; numPar = numpartitions(ads,pool); else numPar = 1; end
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
Для каждого раздела читайте из datastore и вычислите STFT.
parfor ii = 1:numPar subds = partition(ads,numPar,ii); STrain = zeros(fftLength/2+1,128,1,numel(subds.Files)); for idx = 1:numel(subds.Files) x = read(subds); if length(x) > fftLength*64 % Lengthen the signal if it is too short x = x(1:fftLength*64); end % Convert from double-precision to single-precision x = single(x); % Scale the signal x = x ./ max(abs(x)); % Zero-pad to ensure stft returns 128 windows. x = [x ; zeros(overlapLength,1,'like',x)]; S0 = stft(x,'Window',win,'OverlapLength',overlapLength,'Centered',false); % Convert from two-sided to one-sided. S = S0(1:129,:); S = abs(S); STrain(:,:,:,idx) = S; end STrainC{ii} = STrain; end
Преобразуйте выход в четырехмерный массив с STFTs по четвертому измерению.
STrain = cat(4,STrainC{:});
Преобразуйте данные в логарифмическую шкалу, чтобы лучше выровняться с человеческим восприятием.
STrain = log(STrain + 1e-6);
Нормируйте обучающие данные, чтобы иметь нулевое среднее значение и модульное стандартное отклонение.
Вычислите среднее и стандартное отклонение STFT каждого интервала частоты.
SMean = mean(STrain,[2 3 4]); SStd = std(STrain,1,[2 3 4]);
Нормируйте каждый интервал частоты.
STrain = (STrain-SMean)./SStd;
Вычисленные STFTs имеют неограниченные значения. После подхода в [1], сделайте данные ограниченными путем усечения спектров к 3 стандартным отклонениям и перемасштабирования к [-1 1].
STrain = STrain/3; Y = reshape(STrain,numel(STrain),1); Y(Y<-1) = -1; Y(Y>1) = 1; STrain = reshape(Y,size(STrain));
Отбросьте последний интервал частоты, чтобы обеспечить количество интервалов STFT к степени двойки (который работает хорошо со сверточными слоями).
STrain = STrain(1:end-1,:,:,:);
Переставьте размерности при подготовке к питанию различителя.
STrain = permute(STrain,[2 1 3 4]);
Обучайтесь с мини-пакетным размером 64 в течение 1 000 эпох.
maxEpochs = 1000; miniBatchSize = 64;
Вычислите количество итераций, требуемых использовать данные.
numIterationsPerEpoch = floor(size(STrain,4)/miniBatchSize);
Задайте опции для оптимизации Адама. Установите изучить уровень генератора и различителя к 0.0002
. Для обеих сетей используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.
learnRateGenerator = 0.0002; learnRateDiscriminator = 0.0002; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™ и CUDA-поддерживающего NVIDIA, графический процессор с вычисляет возможность 3,0 или выше.
executionEnvironment = "auto";
Инициализируйте веса различителя и генератор. initializeGeneratorWeights
и initializeDiscriminatorWeights
функции возвращают случайные веса, полученные с помощью универсальной инициализации Glorot. Функции включены в конце этого примера.
generatorParameters = initializeGeneratorWeights; discriminatorParameters = initializeDiscriminatorWeights;
Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации.
В течение каждой эпохи переставьте обучающие данные и цикл по мини-пакетам данных.
Для каждого мини-пакета:
Сгенерируйте dlarray
объект, содержащий массив случайных значений для сети генератора.
Для обучения графического процессора преобразуйте данные в gpuArray
Объект (Parallel Computing Toolbox).
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) и функции помощника, modelDiscriminatorGradients
и modelGeneratorGradients
.
Обновите сетевые параметры с помощью adamupdate
(Deep Learning Toolbox) функция.
Инициализируйте параметры для Адама.
trailingAvgGenerator = []; trailingAvgSqGenerator = []; trailingAvgDiscriminator = []; trailingAvgSqDiscriminator = [];
Можно установить saveCheckpoints
к true
чтобы сохранить обновленные веса и состояния к MAT регистрируют каждые десять эпох. Можно затем использовать этот файл MAT, чтобы возобновить обучение, если это прервано. В целях этого примера, набор saveCheckpoints
к false
.
saveCheckpoints = false;
Задайте длину входа генератора.
numLatentInputs = 100;
Обучите GAN. Это может занять несколько часов, чтобы запуститься.
iteration = 0; for epoch = 1:maxEpochs % Shuffle the data. idx = randperm(size(STrain,4)); STrain = STrain(:,:,:,idx); % Loop over mini-batches. for index = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data. dlX = STrain(:,:,:,(index-1)*miniBatchSize+1:index*miniBatchSize); dlX = dlarray(dlX,'SSCB'); % Generate latent inputs for the generator network. Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ; dlZ = dlarray(Z); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZ = gpuArray(dlZ); dlX = gpuArray(dlX); end % Evaluate the discriminator gradients using dlfeval and the % |modelDiscriminatorGradients| helper function. gradientsDiscriminator = ... dlfeval(@modelDiscriminatorGradients,discriminatorParameters,generatorParameters,dlX,dlZ); % Update the discriminator network parameters. [discriminatorParameters,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ... adamupdate(discriminatorParameters,gradientsDiscriminator, ... trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ... learnRateDiscriminator,gradientDecayFactor,squaredGradientDecayFactor); % Generate latent inputs for the generator network. Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ; dlZ = dlarray(Z); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlZ = gpuArray(dlZ); end % Evaluate the generator gradients using dlfeval and the % |modelGeneratorGradients| helper function. gradientsGenerator = ... dlfeval(@modelGeneratorGradients,discriminatorParameters,generatorParameters,dlZ); % Update the generator network parameters. [generatorParameters,trailingAvgGenerator,trailingAvgSqGenerator] = ... adamupdate(generatorParameters,gradientsGenerator, ... trailingAvgGenerator,trailingAvgSqGenerator,iteration, ... learnRateGenerator,gradientDecayFactor,squaredGradientDecayFactor); end % Every 10 iterations, save a training snapshot to a MAT file. if saveCheckpoints && mod(epoch,10)==0 fprintf('Epoch %d out of %d complete\n',epoch,maxEpochs); % Save checkpoint in case training is interrupted. save('audiogancheckpoint.mat',... 'generatorParameters','discriminatorParameters',... 'trailingAvgDiscriminator','trailingAvgSqDiscriminator',... 'trailingAvgGenerator','trailingAvgSqGenerator','iteration'); end end
Теперь, когда вы обучили сеть, можно исследовать процесс синтеза более подробно.
Обученный генератор барабанного боя синтезирует матрицы кратковременного преобразования Фурье (STFT) от входных массивов случайных значений. Обратная операция STFT (ISTFT) преобразует частоту времени STFT в синтезируемый звуковой сигнал временного интервала.
Загрузите веса предварительно обученного генератора. Эти веса были получены путем выполнения обучения, подсвеченного в предыдущем разделе в течение 1 000 эпох.
load(matFileName,'generatorParameters','SMean','SStd');
Генератор берет 1 1 100 векторами из случайных значений как вход. Сгенерируйте демонстрационный входной вектор.
numLatentInputs = 100;
dlZ = dlarray(2 * ( rand(1,1,numLatentInputs,1,'single') - 0.5 ));
Передайте случайный вектор генератору, чтобы создать изображение STFT. generatorParameters
структура, содержащая веса предварительно обученного генератора.
dlXGenerated = modelGenerator(dlZ,generatorParameters);
Преобразуйте dlarray
STFT к матрице с одинарной точностью.
S = dlXGenerated.extractdata;
Транспонируйте STFT, чтобы выровнять его размерности с istft
функция.
S = S.';
STFT 128 128 матрица, где первая размерность представляет 128 интервалов частоты, линейно расположенных с интервалами от 0 до 8 кГц. Генератор был обучен сгенерировать односторонний STFT от длины БПФ 256 с последним не использованным интервалом. Повторно введите тот интервал путем вставки строки нулей в STFT.
S = [S ; zeros(1,128)];
Вернитесь нормализация и масштабирующиеся шаги, используемые, когда вы сгенерировали STFTs для обучения.
S = S * 3; S = (S.*SStd) + SMean;
Преобразуйте STFT от логарифмической области до линейной области.
S = exp(S);
Преобразуйте STFT от одностороннего до двухстороннего.
S = [S; S(end-1:-1:2,:)];
Заполните нулями, чтобы удалить краевые эффекты окна.
S = [zeros(256,100) S zeros(256,100)];
Матрица STFT не содержит информации о фазе. Используйте быструю версию алгоритма Гриффина-Лима с 20 итерациями, чтобы оценить фазу сигнала и произвести аудиосэмплы.
myAudio = stftmag2sig(S,256, ... 'FrequencyRange','twosided', ... 'Window',hann(256,'periodic'), ... 'OverlapLength',128, ... 'MaxIterations',20, ... 'Method','fgla'); myAudio = myAudio./max(abs(myAudio),[],'all'); myAudio = myAudio(128*100:end-128*100);
Слушайте синтезируемый барабанный бой.
sound(myAudio,fs)
Постройте синтезируемый барабанный бой.
t = (0:length(myAudio)-1)/fs; plot(t,myAudio) grid on xlabel('Time (s)') title('Synthesized GAN Sound')
Постройте STFT синтезируемого барабанного боя.
figure stft(myAudio,fs,'Window',hann(256,'periodic'),'OverlapLength',128);
modelGenerator
функция увеличивает масштаб 1 1 100 массивами (dlX) к 128 128 1 массивом (dlY). parameters
структура, содержащая веса слоев генератора. Архитектура генератора задана в Таблице 4 [1].
function dlY = modelGenerator(dlX,parameters) dlY = fullyconnect(dlX,parameters.FC.Weights,parameters.FC.Bias,'Dataformat','SSCB'); dlY = reshape(dlY,[1024 4 4 size(dlY,2)]); dlY = permute(dlY,[3 2 1 4]); dlY = relu(dlY); dlY = dltranspconv(dlY,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB'); dlY = relu(dlY); dlY = dltranspconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB'); dlY = relu(dlY); dlY = dltranspconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB'); dlY = relu(dlY); dlY = dltranspconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB'); dlY = relu(dlY); dlY = dltranspconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB'); dlY = tanh(dlY); end
modelDiscriminator
функционируйте берет 128 128 изображения и выводит скалярный счет предсказания. Архитектура различителя задана в Таблице 5 [1].
function dlY = modelDiscriminator(dlX,parameters) dlY = dlconv(dlX,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Padding','same'); dlY = leakyrelu(dlY,0.2); dlY = dlconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Padding','same'); dlY = leakyrelu(dlY,0.2); dlY = dlconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Padding','same'); dlY = leakyrelu(dlY,0.2); dlY = dlconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Padding','same'); dlY = leakyrelu(dlY,0.2); dlY = dlconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Padding','same'); dlY = leakyrelu(dlY,0.2); dlY = stripdims(dlY); dlY = permute(dlY,[3 2 1 4]); dlY = reshape(dlY,4*4*64*16,numel(dlY)/(4*4*64*16)); weights = parameters.FC.Weights; bias = parameters.FC.Bias; dlY = fullyconnect(dlY,weights,bias,'Dataformat','CB'); end
modelDiscriminatorGradients
функции берут в качестве входа генератор и параметры различителя generatorParameters
и discriminatorParameters
, мини-пакет входных данных dlX
, и массив случайных значений dlZ
, и возвращает градиенты потери различителя относительно настраиваемых параметров в сетях.
function gradientsDiscriminator = modelDiscriminatorGradients(discriminatorParameters , generatorParameters, dlX, dlZ) % Calculate the predictions for real data with the discriminator network. dlYPred = modelDiscriminator(dlX,discriminatorParameters); % Calculate the predictions for generated data with the discriminator network. dlXGenerated = modelGenerator(dlZ,generatorParameters); dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters); % Calculate the GAN loss lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated); % For each network, calculate the gradients with respect to the loss. gradientsDiscriminator = dlgradient(lossDiscriminator,discriminatorParameters); end
modelGeneratorGradients
функционируйте берет в качестве входа различитель и настраиваемые параметры генератора и массив случайных значений dlZ
, и возвращает градиенты потери генератора относительно настраиваемых параметров в сетях.
function gradientsGenerator = modelGeneratorGradients(discriminatorParameters, generatorParameters , dlZ) % Calculate the predictions for generated data with the discriminator network. dlXGenerated = modelGenerator(dlZ,generatorParameters); dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters); % Calculate the GAN loss lossGenerator = ganGeneratorLoss(dlYPredGenerated); % For each network, calculate the gradients with respect to the loss. gradientsGenerator = dlgradient(lossGenerator, generatorParameters); end
Цель различителя не состоит в том, чтобы дурачить генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте функцию потерь различителя. Функция потерь для генератора следует за подходом DCGAN, подсвеченным в [1].
function lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated) fake = dlarray(zeros(1,size(dlYPred,2))); real = dlarray(ones(1,size(dlYPred,2))); D_loss = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,fake)); D_loss = D_loss + mean(sigmoid_cross_entropy_with_logits(dlYPred,real)); lossDiscriminator = D_loss / 2; end
Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте функцию потерь генератора. Функция потерь для генератора следует за подходом глубоко сверточной порождающей adverarial сети (DCGAN), подсвеченным в [1].
function lossGenerator = ganGeneratorLoss(dlYPredGenerated) real = dlarray(ones(1,size(dlYPredGenerated,2))); lossGenerator = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,real)); end
initializeDiscriminatorWeights
инициализирует веса различителя с помощью алгоритма Glorot.
function discriminatorParameters = initializeDiscriminatorWeights filterSize = [5 5]; dim = 64; % Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]); bias = zeros(1,1,dim,'single'); discriminatorParameters.Conv1.Weights = dlarray(weights); discriminatorParameters.Conv1.Bias = dlarray(bias); % Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]); bias = zeros(1,1,2*dim,'single'); discriminatorParameters.Conv2.Weights = dlarray(weights); discriminatorParameters.Conv2.Bias = dlarray(bias); % Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]); bias = zeros(1,1,4*dim,'single'); discriminatorParameters.Conv3.Weights = dlarray(weights); discriminatorParameters.Conv3.Bias = dlarray(bias); % Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]); bias = zeros(1,1,8*dim,'single'); discriminatorParameters.Conv4.Weights = dlarray(weights); discriminatorParameters.Conv4.Bias = dlarray(bias); % Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]); bias = zeros(1,1,16*dim,'single'); discriminatorParameters.Conv5.Weights = dlarray(weights); discriminatorParameters.Conv5.Bias = dlarray(bias); % fully connected weights = iGlorotInitialize([1,4 * 4 * dim * 16]); bias = zeros(1,1,'single'); discriminatorParameters.FC.Weights = dlarray(weights); discriminatorParameters.FC.Bias = dlarray(bias); end
initializeGeneratorWeights
инициализирует веса генератора с помощью алгоритма Glorot.
function generatorParameters = initializeGeneratorWeights dim = 64; % Dense 1 weights = iGlorotInitialize([dim*256,100]); bias = zeros(dim*256,1,'single'); generatorParameters.FC.Weights = dlarray(weights); generatorParameters.FC.Bias = dlarray(bias); filterSize = [5 5]; % Trans Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]); bias = zeros(1,1,dim*8,'single'); generatorParameters.Conv1.Weights = dlarray(weights); generatorParameters.Conv1.Bias = dlarray(bias); % Trans Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]); bias = zeros(1,1,dim*4,'single'); generatorParameters.Conv2.Weights = dlarray(weights); generatorParameters.Conv2.Bias = dlarray(bias); % Trans Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]); bias = zeros(1,1,dim*2,'single'); generatorParameters.Conv3.Weights = dlarray(weights); generatorParameters.Conv3.Bias = dlarray(bias); % Trans Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]); bias = zeros(1,1,dim,'single'); generatorParameters.Conv4.Weights = dlarray(weights); generatorParameters.Conv4.Bias = dlarray(bias); % Trans Conv2D weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]); bias = zeros(1,1,1,'single'); generatorParameters.Conv5.Weights = dlarray(weights); generatorParameters.Conv5.Bias = dlarray(bias); end
synthesizeDrumBeat
использует предварительно обученную сеть, чтобы синтезировать барабанные ритмы.
function y = synthesizeDrumBeat persistent pGeneratorParameters pMean pSTD if isempty(pGeneratorParameters) % If the MAT file does not exist, download it filename = 'drumGeneratorWeights.mat'; load(filename,'SMean','SStd','generatorParameters'); pMean = SMean; pSTD = SStd; pGeneratorParameters = generatorParameters; end % Generate random vector dlZ = dlarray(2 * ( rand(1,1,100,1,'single') - 0.5 )); % Generate spectrograms dlXGenerated = modelGenerator(dlZ,pGeneratorParameters); % Convert from dlarray to single S = dlXGenerated.extractdata; S = S.'; % Zero-pad to remove edge effects S = [S ; zeros(1,128)]; % Reverse steps from training S = S * 3; S = (S.*pSTD) + pMean; S = exp(S); % Make it two-sided S = [S ; S(end-1:-1:2,:)]; % Pad with zeros at end and start S = [zeros(256,100) S zeros(256,100)]; % Reconstruct the signal using a fast Griffin-Lim algorithm. myAudio = stftmag2sig(gather(S),256, ... 'FrequencyRange','twosided', ... 'Window',hann(256,'periodic'), ... 'OverlapLength',128, ... 'MaxIterations',20, ... 'Method','fgla'); myAudio = myAudio./max(abs(myAudio),[],'all'); y = myAudio(128*100:end-128*100); end
function out = sigmoid_cross_entropy_with_logits(x,z) out = max(x, 0) - x .* z + log(1 + exp(-abs(x))); end function w = iGlorotInitialize(sz) if numel(sz) == 2 numInputs = sz(2); numOutputs = sz(1); else numInputs = prod(sz(1:3)); numOutputs = prod(sz([1 2 4])); end multiplier = sqrt(2 / (numInputs + numOutputs)); w = multiplier * sqrt(3) * (2 * rand(sz,'single') - 1); end
[1] Donahue, C., Дж. Маколи и М. Пюккетт. 2019. "Соперничающий аудио синтез". ICLR.