Обучите Порождающую соперничающую сеть (GAN) синтезу звука

В этом примере показано, как обучить и использовать порождающую соперничающую сеть (GAN), чтобы сгенерировать звуки.

Введение

В порождающих соперничающих сетях генератор и различитель конкурируют друг против друга, чтобы улучшить качество генерации.

GANs вызвали значительный интерес в области речевой обработки и аудио. Приложения включают синтез текста к речи, речевое преобразование и речевое улучшение.

Этот пример обучает GAN безнадзорному синтезу аудио форм волны. GAN в этом примере генерирует звуки барабанного боя. Тот же подход может сопровождаться, чтобы сгенерировать другие типы звука, включая речь.

Синтезируйте аудио с предварительно обученным GAN

Прежде чем вы обучите GAN с нуля, вы будете использовать предварительно обученный генератор GAN, чтобы синтезировать барабанные ритмы.

Загрузите предварительно обученный генератор.

matFileName = 'drumGeneratorWeights.mat';
if ~exist(matFileName,'file')
    websave(matFileName,'https://www.mathworks.com/supportfiles/audio/GanAudioSynthesis/drumGeneratorWeights.mat');
end

Функциональный synthesizeDrumBeat вызывает предварительно обученную сеть, чтобы синтезировать барабанный бой, произведенный на уровне 16 кГц. synthesizeDrumBeat функция включена в конце этого примера.

Синтезируйте барабанный бой и слушайте его.

drum = synthesizeDrumBeat;

fs = 16e3;
sound(drum,fs)

Постройте синтезируемый барабанный бой.

t = (0:length(drum)-1)/fs;
plot(t,drum)
grid on
xlabel('Time (s)')
title('Synthesized Drum Beat')

Можно использовать синтезатор барабанного боя с другими звуковыми эффектами, чтобы создать более сложные приложения. Например, можно применить реверберацию к синтезируемым барабанным ритмам.

Создайте reverberator Объект (Audio Toolbox) и открытый его тюнер параметра пользовательский интерфейс. Этот пользовательский интерфейс позволяет вам настроить reverberator параметры как симуляция запускаются.

reverb = reverberator('SampleRate',fs);
parameterTuner(reverb);

Создайте объект scope времени визуализировать барабанные ритмы.

ts = timescope('SampleRate',fs, ...
    'TimeSpanSource','Property', ...
    'TimeSpanOverrunAction','Scroll', ...
    'TimeSpan',10, ...
    'BufferLength',10*256*64, ...
    'ShowGrid',true, ...
    'YLimits',[-1 1]);

В цикле синтезируйте барабанные ритмы и примените реверберацию. Используйте тюнер параметра пользовательский интерфейс, чтобы настроить реверберацию. Если вы хотите запустить симуляцию в течение более длительного времени, увеличьте значение loopCount параметр.

loopCount = 20;
for ii = 1:loopCount
    drum = synthesizeDrumBeat;
    drum = reverb(drum);
    ts(drum(:,1));
    soundsc(drum,fs)
    pause(0.5)
end

Обучите GAN

Теперь, когда вы видели предварительно обученный генератор барабанного боя в действии, можно исследовать учебный процесс подробно.

ГАНЬ является типом нейронной сети для глубокого обучения, которая генерирует данные с характеристиками, похожими на обучающие данные.

ГАНЬ состоит из двух сетей, которые обучаются вместе, генератор и различитель:

  • Генератор - Учитывая векторные или случайные значения, как введено, эта сеть генерирует данные с той же структурой как обучающие данные. Это - задание генератора, чтобы одурачить различитель.

  • Различитель - Данный пакеты данных, содержащих наблюдения и от обучающих данных и от сгенерированных данных, эта сеть пытается классифицировать наблюдения как действительные или сгенерированные.

Чтобы максимизировать эффективность генератора, максимизируйте потерю различителя, когда дали сгенерированные данные. Таким образом, цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как действительные. Чтобы максимизировать эффективность различителя, минимизируйте потерю различителя когда данный пакеты и действительных и сгенерированных данных. Идеально, эти стратегии приводят к генератору, который генерирует убедительно реалистические данные и различитель, который изучил представления сильной черты, которые являются характеристическими для обучающих данных.

В этом примере вы обучаете генератор создавать поддельные представления кратковременного преобразования Фурье (STFT) частоты времени барабанных ритмов. Вы обучаете различитель идентифицировать действительный STFTs. Вы создаете действительный STFTs путем вычисления STFT коротких записей действительных барабанных ритмов.

Загрузите обучающие данные

Обучите GAN использование набора данных Drum Sound Effects [1]. Загрузите и извлеките набор данных.

url = 'http://deepyeti.ucsd.edu/cdonahue/wavegan/data/drums.tar.gz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'drums_dataset.tgz');

drumsFolder = fullfile(downloadFolder,'drums');
if ~exist(drumsFolder,'dir')
    disp('Downloading Drum Sound Effects Dataset (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте audioDatastore Объект (Audio Toolbox), который указывает на набор данных барабанов.

ads = audioDatastore(drumsFolder,'IncludeSubfolders',true);

Задайте сеть генератора

Задайте сеть, которая генерирует STFTs от 1 1 100 массивами случайных значений. Создайте сеть, которая увеличивает масштаб 1 1 100 массивами к 128 128 1 массивом с помощью полносвязного слоя, сопровождаемого изменять слоем и серией транспонированных слоев свертки со слоями ReLU.

Этот рисунок показывает размерности сигнала, когда это перемещается через генератор. Архитектура генератора задана в Таблице 4 [1].

Сеть генератора задана в modelGenerator, который включен в конце этого примера.

Задайте сеть различителя

Задайте сеть, которая классифицирует действительный и сгенерированный 128 128 STFTs.

Создайте сеть, которая берет 128 128 изображения и выводит скалярный счет предсказания с помощью серии слоев свертки с текучими слоями ReLU, сопровождаемыми полносвязным слоем.

Этот рисунок показывает размерности сигнала, когда это перемещается через различитель. Архитектура различителя задана в Таблице 5 [1].

Сеть различителя задана в modelDiscriminator, который включен в конце этого примера.

Сгенерируйте действительные обучающие данные барабанного боя

Сгенерируйте данные STFT из сигналов барабанного боя в datastore.

Задайте параметры STFT.

fftLength = 256;
win = hann(fftLength,'periodic');
overlapLength = 128;

Чтобы ускорить обработку, распределите извлечение признаков на нескольких рабочих, использующих parfor.

Во-первых, определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(ads,pool);
else
    numPar = 1;
end

Для каждого раздела читайте из datastore и вычислите STFT.

parfor ii = 1:numPar

    subds = partition(ads,numPar,ii);
    STrain = zeros(fftLength/2+1,128,1,numel(subds.Files));

    for idx = 1:numel(subds.Files)

        x = read(subds);

        if length(x) > fftLength*64
            % Lengthen the signal if it is too short
            x = x(1:fftLength*64);
        end

        % Convert from double-precision to single-precision
        x = single(x);

        % Scale the signal
        x = x ./ max(abs(x));

        % Zero-pad to ensure stft returns 128 windows.
        x = [x ; zeros(overlapLength,1,'like',x)];

        S0 = stft(x,'Window',win,'OverlapLength',overlapLength,'Centered',false);

        % Convert from two-sided to one-sided.
        S = S0(1:129,:);
        S = abs(S);
        STrain(:,:,:,idx) = S;
    end
    STrainC{ii} = STrain;
end

Преобразуйте выход в четырехмерный массив с STFTs по четвертому измерению.

STrain = cat(4,STrainC{:});

Преобразуйте данные в логарифмическую шкалу, чтобы лучше выровняться с человеческим восприятием.

STrain = log(STrain + 1e-6);

Нормируйте обучающие данные, чтобы иметь нулевое среднее значение и модульное стандартное отклонение.

Вычислите среднее и стандартное отклонение STFT каждого интервала частоты.

SMean = mean(STrain,[2 3 4]);
SStd = std(STrain,1,[2 3 4]);

Нормируйте каждый интервал частоты.

STrain = (STrain-SMean)./SStd;

Вычисленные STFTs имеют неограниченные значения. После подхода в [1], сделайте данные ограниченными путем усечения спектров к 3 стандартным отклонениям и перемасштабирования к [-1 1].

STrain = STrain/3;
Y = reshape(STrain,numel(STrain),1);
Y(Y<-1) = -1;
Y(Y>1) = 1;
STrain = reshape(Y,size(STrain));

Отбросьте последний интервал частоты, чтобы обеспечить количество интервалов STFT к степени двойки (который работает хорошо со сверточными слоями).

STrain = STrain(1:end-1,:,:,:);

Переставьте размерности при подготовке к питанию различителя.

STrain = permute(STrain,[2 1 3 4]);

Задайте опции обучения

Обучайтесь с мини-пакетным размером 64 в течение 1 000 эпох.

maxEpochs = 1000;
miniBatchSize = 64;

Вычислите количество итераций, требуемых использовать данные.

numIterationsPerEpoch = floor(size(STrain,4)/miniBatchSize);

Задайте опции для оптимизации Адама. Установите изучить уровень генератора и различителя к 0.0002. Для обеих сетей используйте фактор затухания градиента 0,5 и фактор затухания градиента в квадрате 0,999.

learnRateGenerator = 0.0002;
learnRateDiscriminator = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™.

executionEnvironment = "auto";

Инициализируйте веса различителя и генератор. initializeGeneratorWeights и initializeDiscriminatorWeights функции возвращают случайные веса, полученные с помощью универсальной инициализации Glorot. Функции включены в конце этого примера.

generatorParameters = initializeGeneratorWeights;
discriminatorParameters = initializeDiscriminatorWeights;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла. Цикл по обучающим данным и обновлению сетевые параметры в каждой итерации.

В течение каждой эпохи переставьте обучающие данные и цикл по мини-пакетам данных.

Для каждого мини-пакета:

  • Сгенерируйте dlarray объект, содержащий массив случайных значений для сети генератора.

  • Для обучения графического процессора преобразуйте данные в gpuArray Объект (Parallel Computing Toolbox).

  • Оцените градиенты модели с помощью dlfeval и функции помощника, modelDiscriminatorGradients и modelGeneratorGradients.

  • Обновите сетевые параметры с помощью adamupdate функция.

Инициализируйте параметры для Адама.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Можно установить saveCheckpoints к true чтобы сохранить обновленные веса и состояния к MAT регистрируют каждые десять эпох. Можно затем использовать этот файл MAT, чтобы возобновить обучение, если это прервано. В целях этого примера, набор saveCheckpoints к false.

saveCheckpoints = false;

Задайте длину входа генератора.

numLatentInputs = 100;

Обучите GAN. Это может занять несколько часов, чтобы запуститься.

iteration = 0;

for epoch = 1:maxEpochs

    % Shuffle the data.
    idx = randperm(size(STrain,4));
    STrain = STrain(:,:,:,idx);

    % Loop over mini-batches.
    for index = 1:numIterationsPerEpoch

        iteration = iteration + 1;

        % Read mini-batch of data.
        dlX = STrain(:,:,:,(index-1)*miniBatchSize+1:index*miniBatchSize);
        dlX = dlarray(dlX,'SSCB');

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
            dlX = gpuArray(dlX);
        end

        % Evaluate the discriminator gradients using dlfeval and the
        % |modelDiscriminatorGradients| helper function.
        gradientsDiscriminator = ...
            dlfeval(@modelDiscriminatorGradients,discriminatorParameters,generatorParameters,dlX,dlZ);

        % Update the discriminator network parameters.
        [discriminatorParameters,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(discriminatorParameters,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            learnRateDiscriminator,gradientDecayFactor,squaredGradientDecayFactor);

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end

        % Evaluate the generator gradients using dlfeval and the
        % |modelGeneratorGradients| helper function.
        gradientsGenerator  = ...
            dlfeval(@modelGeneratorGradients,discriminatorParameters,generatorParameters,dlZ);

        % Update the generator network parameters.
        [generatorParameters,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(generatorParameters,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            learnRateGenerator,gradientDecayFactor,squaredGradientDecayFactor);
    end

    % Every 10 iterations, save a training snapshot to a MAT file.
    if saveCheckpoints && mod(epoch,10)==0
        fprintf('Epoch %d out of %d complete\n',epoch,maxEpochs);
        % Save checkpoint in case training is interrupted.
        save('audiogancheckpoint.mat',...
            'generatorParameters','discriminatorParameters',...
            'trailingAvgDiscriminator','trailingAvgSqDiscriminator',...
            'trailingAvgGenerator','trailingAvgSqGenerator','iteration');
    end
end

Синтезируйте звуки

Теперь, когда вы обучили сеть, можно исследовать процесс синтеза более подробно.

Обученный генератор барабанного боя синтезирует матрицы кратковременного преобразования Фурье (STFT) от входных массивов случайных значений. Обратная операция STFT (ISTFT) преобразует частоту времени STFT в синтезируемый звуковой сигнал временного интервала.

Загрузите веса предварительно обученного генератора. Эти веса были получены путем выполнения обучения, подсвеченного в предыдущем разделе в течение 1 000 эпох.

load(matFileName,'generatorParameters','SMean','SStd');

Генератор берет 1 1 100 векторами из случайных значений как вход. Сгенерируйте демонстрационный входной вектор.

numLatentInputs = 100;
dlZ = dlarray(2 * ( rand(1,1,numLatentInputs,1,'single') - 0.5 ));

Передайте случайный вектор генератору, чтобы создать изображение STFT. generatorParameters структура, содержащая веса предварительно обученного генератора.

dlXGenerated = modelGenerator(dlZ,generatorParameters);

Преобразуйте dlarray STFT к матрице с одинарной точностью.

S = dlXGenerated.extractdata;

Транспонируйте STFT, чтобы выровнять его размерности с istft функция.

S = S.';

STFT 128 128 матрица, где первая размерность представляет 128 интервалов частоты, линейно расположенных с интервалами от 0 до 8 кГц. Генератор был обучен сгенерировать односторонний STFT от длины БПФ 256 с последним не использованным интервалом. Повторно введите тот интервал путем вставки строки нулей в STFT.

S = [S ; zeros(1,128)];

Вернитесь нормализация и масштабирующиеся шаги, используемые, когда вы сгенерировали STFTs для обучения.

S = S * 3;
S = (S.*SStd) + SMean;

Преобразуйте STFT от логарифмической области до линейной области.

S = exp(S);

Преобразуйте STFT от одностороннего до двухстороннего.

S = [S; S(end-1:-1:2,:)];

Заполните нулями, чтобы удалить краевые эффекты окна.

S = [zeros(256,100) S zeros(256,100)];

Матрица STFT не содержит информации о фазе. Используйте быструю версию алгоритма Гриффина-Лима с 20 итерациями, чтобы оценить фазу сигнала и произвести аудиосэмплы.

myAudio = stftmag2sig(S,256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
myAudio = myAudio(128*100:end-128*100);

Слушайте синтезируемый барабанный бой.

sound(myAudio,fs)

Постройте синтезируемый барабанный бой.

t = (0:length(myAudio)-1)/fs;
plot(t,myAudio)
grid on
xlabel('Time (s)')
title('Synthesized GAN Sound')

Постройте STFT синтезируемого барабанного боя.

figure
stft(myAudio,fs,'Window',hann(256,'periodic'),'OverlapLength',128);

Функция генератора модели

modelGenerator функция увеличивает масштаб 1 1 100 массивами (dlX) к 128 128 1 массивом (dlY). parameters структура, содержащая веса слоев генератора. Архитектура генератора задана в Таблице 4 [1].

function dlY = modelGenerator(dlX,parameters)

dlY = fullyconnect(dlX,parameters.FC.Weights,parameters.FC.Bias,'Dataformat','SSCB');

dlY = reshape(dlY,[1024 4 4 size(dlY,2)]);
dlY = permute(dlY,[3 2 1 4]);
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = tanh(dlY);
end

Функция различителя модели

modelDiscriminator функционируйте берет 128 128 изображения и выводит скалярный счет предсказания. Архитектура различителя задана в Таблице 5 [1].

function dlY = modelDiscriminator(dlX,parameters)

dlY = dlconv(dlX,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = stripdims(dlY);
dlY = permute(dlY,[3 2 1 4]);
dlY = reshape(dlY,4*4*64*16,numel(dlY)/(4*4*64*16));

weights = parameters.FC.Weights;
bias = parameters.FC.Bias;
dlY = fullyconnect(dlY,weights,bias,'Dataformat','CB');

end

Функция градиентов различителя модели

modelDiscriminatorGradients функции берут в качестве входа генератор и параметры различителя generatorParameters и discriminatorParameters, мини-пакет входных данных dlX, и массив случайных значений dlZ, и возвращает градиенты потери различителя относительно настраиваемых параметров в сетях.

function gradientsDiscriminator = modelDiscriminatorGradients(discriminatorParameters , generatorParameters, dlX, dlZ)

% Calculate the predictions for real data with the discriminator network.
dlYPred = modelDiscriminator(dlX,discriminatorParameters);

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated     = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsDiscriminator = dlgradient(lossDiscriminator,discriminatorParameters);

end

Функция градиентов генератора модели

modelGeneratorGradients функционируйте берет в качестве входа различитель и настраиваемые параметры генератора и массив случайных значений dlZ, и возвращает градиенты потери генератора относительно настраиваемых параметров в сетях.

function gradientsGenerator = modelGeneratorGradients(discriminatorParameters, generatorParameters , dlZ)

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossGenerator = ganGeneratorLoss(dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, generatorParameters);

end

Функция потерь различителя

Цель различителя не состоит в том, чтобы дурачить генератор. Чтобы максимизировать вероятность, что различитель успешно различает между действительными и сгенерированными изображениями, минимизируйте функцию потерь различителя. Функция потерь для генератора следует за подходом DCGAN, подсвеченным в [1].

function  lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated)

fake = dlarray(zeros(1,size(dlYPred,2)));
real = dlarray(ones(1,size(dlYPred,2)));

D_loss = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,fake));
D_loss = D_loss + mean(sigmoid_cross_entropy_with_logits(dlYPred,real));
lossDiscriminator  = D_loss / 2;
end

Функция потерь генератора

Цель генератора состоит в том, чтобы сгенерировать данные, которые различитель классифицирует как "действительные". Чтобы максимизировать вероятность, что изображения от генератора классифицируются как действительные различителем, минимизируйте функцию потерь генератора. Функция потерь для генератора следует за подходом глубоко сверточной порождающей adverarial сети (DCGAN), подсвеченным в [1].

function lossGenerator = ganGeneratorLoss(dlYPredGenerated)
real = dlarray(ones(1,size(dlYPredGenerated,2)));
lossGenerator = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,real));
end

Инициализатор весов различителя

initializeDiscriminatorWeights инициализирует веса различителя с помощью алгоритма Glorot.

function discriminatorParameters = initializeDiscriminatorWeights

filterSize = [5 5];
dim = 64;

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,dim,'single');
discriminatorParameters.Conv1.Weights = dlarray(weights);
discriminatorParameters.Conv1.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,2*dim,'single');
discriminatorParameters.Conv2.Weights = dlarray(weights);
discriminatorParameters.Conv2.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,4*dim,'single');
discriminatorParameters.Conv3.Weights = dlarray(weights);
discriminatorParameters.Conv3.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,8*dim,'single');
discriminatorParameters.Conv4.Weights = dlarray(weights);
discriminatorParameters.Conv4.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,16*dim,'single');
discriminatorParameters.Conv5.Weights = dlarray(weights);
discriminatorParameters.Conv5.Bias = dlarray(bias);

% fully connected
weights = iGlorotInitialize([1,4 * 4 * dim * 16]);
bias = zeros(1,1,'single');
discriminatorParameters.FC.Weights = dlarray(weights);
discriminatorParameters.FC.Bias = dlarray(bias);
end

Инициализатор весов генератора

initializeGeneratorWeights инициализирует веса генератора с помощью алгоритма Glorot.

function generatorParameters = initializeGeneratorWeights

dim = 64;

% Dense 1
weights = iGlorotInitialize([dim*256,100]);
bias = zeros(dim*256,1,'single');
generatorParameters.FC.Weights = dlarray(weights);
generatorParameters.FC.Bias = dlarray(bias);

filterSize = [5 5];

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,dim*8,'single');
generatorParameters.Conv1.Weights = dlarray(weights);
generatorParameters.Conv1.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,dim*4,'single');
generatorParameters.Conv2.Weights = dlarray(weights);
generatorParameters.Conv2.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,dim*2,'single');
generatorParameters.Conv3.Weights = dlarray(weights);
generatorParameters.Conv3.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,dim,'single');
generatorParameters.Conv4.Weights = dlarray(weights);
generatorParameters.Conv4.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,1,'single');
generatorParameters.Conv5.Weights = dlarray(weights);
generatorParameters.Conv5.Bias = dlarray(bias);
end

Синтезируйте барабанный бой

synthesizeDrumBeat использует предварительно обученную сеть, чтобы синтезировать барабанные ритмы.

function y = synthesizeDrumBeat

persistent pGeneratorParameters pMean pSTD
if isempty(pGeneratorParameters)
    % If the MAT file does not exist, download it
    filename = 'drumGeneratorWeights.mat';
    load(filename,'SMean','SStd','generatorParameters');
    pMean = SMean;
    pSTD  = SStd;
    pGeneratorParameters = generatorParameters;
end

% Generate random vector
dlZ = dlarray(2 * ( rand(1,1,100,1,'single') - 0.5 ));

% Generate spectrograms
dlXGenerated = modelGenerator(dlZ,pGeneratorParameters);

% Convert from dlarray to single
S = dlXGenerated.extractdata;

S = S.';
% Zero-pad to remove edge effects
S = [S ; zeros(1,128)];

% Reverse steps from training
S = S * 3;
S = (S.*pSTD) + pMean;
S = exp(S);

% Make it two-sided
S = [S ; S(end-1:-1:2,:)];
% Pad with zeros at end and start
S = [zeros(256,100) S zeros(256,100)];

% Reconstruct the signal using a fast Griffin-Lim algorithm.
myAudio = stftmag2sig(gather(S),256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
y = myAudio(128*100:end-128*100);
end

Служебные функции

function out = sigmoid_cross_entropy_with_logits(x,z)
out = max(x, 0) - x .* z + log(1 + exp(-abs(x)));
end

function w = iGlorotInitialize(sz)
if numel(sz) == 2
    numInputs = sz(2);
    numOutputs = sz(1);
else
    numInputs = prod(sz(1:3));
    numOutputs = prod(sz([1 2 4]));
end
multiplier = sqrt(2 / (numInputs + numOutputs));
w = multiplier * sqrt(3) * (2 * rand(sz,'single') - 1);
end

Ссылка

[1] Donahue, C., Дж. Маколи и М. Пюккетт. 2019. "Соперничающий аудио синтез". ICLR.