Обучите генеративную состязательную сеть (GAN) для синтеза звука

В этом примере показано, как обучить и использовать генеративную состязательную сеть (GAN) для генерации звуков.

Введение

В генеративных состязательных сетях генератор и дискриминатор конкурируют друг с другом, улучшая качество генерации.

GAN вызвали значительный интерес в области обработки звука и речи. Приложения включают синтез текста в речь, преобразование голоса и улучшение речи.

Этот пример обучает GAN для неконтролируемого синтеза аудиоволн. GAN в этом примере генерирует звуки друмбеата. Такого же подхода можно придерживаться, чтобы генерировать другие типы звука, включая речь.

Синтезируйте аудио с предварительно обученной GAN

Прежде чем вы обучите GAN с нуля, вы будете использовать предварительно обученный генератор GAN для синтеза ударов барабана.

Загрузите предварительно обученный генератор.

matFileName = 'drumGeneratorWeights.mat';
if ~exist(matFileName,'file')
    websave(matFileName,'https://www.mathworks.com/supportfiles/audio/GanAudioSynthesis/drumGeneratorWeights.mat');
end

Функция synthesizeDrumBeat вызывает предварительно обученную сеть, чтобы синтезировать drumbeat, выбранный с частотой 16 кГц. The synthesizeDrumBeat функция включена в конец этого примера.

Синтезируйте друмбеат и слушайте его.

drum = synthesizeDrumBeat;

fs = 16e3;
sound(drum,fs)

Постройте график синтезированного барабана.

t = (0:length(drum)-1)/fs;
plot(t,drum)
grid on
xlabel('Time (s)')
title('Synthesized Drum Beat')

Можно использовать синтезатор drumbeat с другими аудио эффектов, чтобы создать более сложные приложения. Для примера можно применить реверберацию к синтезированным ударам барабана.

Создайте reverberator и откройте его параметрический тюнер UI. Этот пользовательский интерфейс позволяет вам настраивать reverberator параметры во время выполнения симуляции.

reverb = reverberator('SampleRate',fs);
parameterTuner(reverb);

Создайте объект возможностей для визуализации ударов барабана.

ts = timescope('SampleRate',fs, ...
    'TimeSpanSource','Property', ...
    'TimeSpanOverrunAction','Scroll', ...
    'TimeSpan',10, ...
    'BufferLength',10*256*64, ...
    'ShowGrid',true, ...
    'YLimits',[-1 1]);

В цикле синтезируйте удары барабана и примените реверберацию. Используйте пользовательский интерфейс параметра tuner, чтобы настроить реверберацию. Если вы хотите запустить симуляцию в течение более длительного времени, увеличьте значение loopCount параметр.

loopCount = 20;
for ii = 1:loopCount
    drum = synthesizeDrumBeat;
    drum = reverb(drum);
    ts(drum(:,1));
    soundsc(drum,fs)
    pause(0.5)
end

Обучите GAN

Теперь, когда вы видели предварительно обученный генератор барабана в действии, можно детально изучить процесс обучения.

GAN является типом нейронной сети для глубокого обучения, которая генерирует данные с характеристиками, подобными обучающим данным.

GAN состоит из двух сетей, которые обучают вместе, генератора и дискриминатора:

  • Генератор - учитывая вектор или случайные значения в качестве входных, эта сеть генерирует данные с той же структурой, что и обучающие данные. Задача генератора - обмануть дискриминатора.

  • Дискриминатор - Учитывая пакеты данных, содержащие наблюдения как из обучающих данных, так и из сгенерированных данных, эта сеть пытается классифицировать наблюдения как реальные или сгенерированные.

Чтобы максимизировать эффективность генератора, максимизируйте потерю дискриминатора, когда даны сгенерированные данные. То есть цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как действительные. Чтобы максимизировать эффективность дискриминатора, минимизируйте потерю дискриминатора, когда заданы пакеты как реальных, так и сгенерированных данных. В идеале эти стратегии приводят к генератору, который генерирует убедительно реалистичные данные, и дискриминатору, который научился сильным представлениям функций, которые характерны для обучающих данных.

В этом примере вы обучаете генератор, чтобы создать поддельные представления преобразования Фурье с временной частотой (STFT) ударов барабана. Вы обучаете дискриминатора идентифицировать реальные STFT. Вы создаете настоящие STFT путем вычисления STFT коротких записей реальных ударов барабана.

Загрузка обучающих данных

Обучите GAN с помощью набора данных Drum Sound Effects [1]. Загрузите и извлечите набор данных.

url = 'http://deepyeti.ucsd.edu/cdonahue/wavegan/data/drums.tar.gz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'drums_dataset.tgz');

drumsFolder = fullfile(downloadFolder,'drums');
if ~exist(drumsFolder,'dir')
    disp('Downloading Drum Sound Effects Dataset (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте audioDatastore объект, который указывает на набор данных барабанов.

ads = audioDatastore(drumsFolder,'IncludeSubfolders',true);

Определите сеть генератора

Определите сеть, которая производит STFTs от массивов случайных значений 1 на 1 на 100. Создайте сеть что upscales массивы 1 на 1 на 100 к массивам 128 на 128 на 1, используя полносвязный слой, сопровождаемый изменять слоем и серией перемещенных слоев скручивания со слоями ReLU.

Этот рисунок показывает размерности сигнала, когда он перемещается через генератор. Архитектура генератора определена в таблице 4 [1].

Сеть генератора определяется как modelGenerator, который включен в конец этого примера.

Определите сеть дискриминатора

Определите сеть, которая классифицирует реальные и сгенерированные 128 на 128 STFT.

Создайте сеть, которая принимает 128 на 128 изображений и выводит скалярный счет предсказания с помощью ряда слоев свертки с утечками слоев ReLU, за которыми следует полносвязный слой.

Этот рисунок показывает размерности сигнала, когда он перемещается через дискриминатор. Архитектура дискриминатора определена в таблице 5 [1].

Сеть дискриминатора определяется как modelDiscriminator, который включен в конец этого примера.

Сгенерируйте реальные обучающие данные Drumbeat

Сгенерируйте данные STFT из сигналов drumbeat в datastore.

Определите параметры STFT.

fftLength = 256;
win = hann(fftLength,'periodic');
overlapLength = 128;

Чтобы ускорить обработку, распределите редукцию данных между несколькими работниками с помощью parfor.

Во-первых, определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(ads,pool);
else
    numPar = 1;
end

Для каждого раздела считывайте из datastore и вычисляйте STFT.

parfor ii = 1:numPar

    subds = partition(ads,numPar,ii);
    STrain = zeros(fftLength/2+1,128,1,numel(subds.Files));

    for idx = 1:numel(subds.Files)

        x = read(subds);

        if length(x) > fftLength*64
            % Lengthen the signal if it is too short
            x = x(1:fftLength*64);
        end

        % Convert from double-precision to single-precision
        x = single(x);

        % Scale the signal
        x = x ./ max(abs(x));

        % Zero-pad to ensure stft returns 128 windows.
        x = [x ; zeros(overlapLength,1,'like',x)];

        S0 = stft(x,'Window',win,'OverlapLength',overlapLength,'Centered',false);

        % Convert from two-sided to one-sided.
        S = S0(1:129,:);
        S = abs(S);
        STrain(:,:,:,idx) = S;
    end
    STrainC{ii} = STrain;
end

Преобразуйте выход в четырехмерный массив с STFT по четвертой размерности.

STrain = cat(4,STrainC{:});

Преобразуйте данные в шкалу журналы, чтобы лучше соответствовать восприятию человека.

STrain = log(STrain + 1e-6);

Нормализуйте обучающие данные, чтобы иметь нулевое среднее и единичное стандартное отклонение.

Вычислите среднее значение STFT и стандартное отклонение каждого интервала частоты.

SMean = mean(STrain,[2 3 4]);
SStd = std(STrain,1,[2 3 4]);

Нормализуйте каждый интервал частоты.

STrain = (STrain-SMean)./SStd;

Вычисленные STFT имеют неограниченные значения. Следуя подходу в [1], ограничьте данные обрезкой спектров до 3 стандартных отклонений и повторным преобразованием до [-1 1].

STrain = STrain/3;
Y = reshape(STrain,numel(STrain),1);
Y(Y<-1) = -1;
Y(Y>1) = 1;
STrain = reshape(Y,size(STrain));

Сбросьте последнюю ячейку частоты, чтобы принудить количество интервалов STFT к степени двойки (что хорошо работает со сверточными слоями).

STrain = STrain(1:end-1,:,:,:);

Транспозиция размерностей при подготовке к подаче на дискриминатор.

STrain = permute(STrain,[2 1 3 4]);

Настройка опций обучения

Train с мини-партией размером 64 на 1000 эпох.

maxEpochs = 1000;
miniBatchSize = 64;

Вычислите количество итераций, необходимых для использования данных.

numIterationsPerEpoch = floor(size(STrain,4)/miniBatchSize);

Задайте опции для оптимизации Adam. Установите скорость обучения генератора и дискриминатора равной 0.0002. Для обеих сетей используйте коэффициент градиентного распада 0,5 и квадратный коэффициент градиентного распада 0,999.

learnRateGenerator = 0.0002;
learnRateDiscriminator = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обучите на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™.

executionEnvironment = "auto";

Инициализируйте веса генератора и дискриминатора. The initializeGeneratorWeights и initializeDiscriminatorWeights функции возвращают случайные веса, полученные с помощью равномерной инициализации Глорота. Функции включены в конец этого примера.

generatorParameters = initializeGeneratorWeights;
discriminatorParameters = initializeDiscriminatorWeights;

Обучите модель

Обучите модель с помощью пользовательского цикла обучения. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации.

Для каждой эпохи перетасуйте обучающие данные и закольцовывайте мини-пакеты данных.

Для каждого мини-пакета:

  • Сгенерируйте dlarray объект, содержащий массив случайных значений для сети генератора.

  • Для обучения графический процессор преобразуйте данные в gpuArray (Parallel Computing Toolbox) объект.

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) и вспомогательные функции, modelDiscriminatorGradients и modelGeneratorGradients.

  • Обновляйте параметры сети с помощью adamupdate (Deep Learning Toolbox) функция.

Инициализируйте параметры для Adam.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Можно задать saveCheckpoints на true сохранение обновленных весов и состояний в файл MAT каждые десять эпох. Затем можно использовать этот файл MAT для возобновления обучения, если оно прервано. Для целей этого примера задайте saveCheckpoints на false.

saveCheckpoints = false;

Задайте длину входа генератора.

numLatentInputs = 100;

Обучите GAN. Запуск может занять несколько часов.

iteration = 0;

for epoch = 1:maxEpochs

    % Shuffle the data.
    idx = randperm(size(STrain,4));
    STrain = STrain(:,:,:,idx);

    % Loop over mini-batches.
    for index = 1:numIterationsPerEpoch

        iteration = iteration + 1;

        % Read mini-batch of data.
        dlX = STrain(:,:,:,(index-1)*miniBatchSize+1:index*miniBatchSize);
        dlX = dlarray(dlX,'SSCB');

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
            dlX = gpuArray(dlX);
        end

        % Evaluate the discriminator gradients using dlfeval and the
        % |modelDiscriminatorGradients| helper function.
        gradientsDiscriminator = ...
            dlfeval(@modelDiscriminatorGradients,discriminatorParameters,generatorParameters,dlX,dlZ);

        % Update the discriminator network parameters.
        [discriminatorParameters,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(discriminatorParameters,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            learnRateDiscriminator,gradientDecayFactor,squaredGradientDecayFactor);

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end

        % Evaluate the generator gradients using dlfeval and the
        % |modelGeneratorGradients| helper function.
        gradientsGenerator  = ...
            dlfeval(@modelGeneratorGradients,discriminatorParameters,generatorParameters,dlZ);

        % Update the generator network parameters.
        [generatorParameters,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(generatorParameters,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            learnRateGenerator,gradientDecayFactor,squaredGradientDecayFactor);
    end

    % Every 10 iterations, save a training snapshot to a MAT file.
    if saveCheckpoints && mod(epoch,10)==0
        fprintf('Epoch %d out of %d complete\n',epoch,maxEpochs);
        % Save checkpoint in case training is interrupted.
        save('audiogancheckpoint.mat',...
            'generatorParameters','discriminatorParameters',...
            'trailingAvgDiscriminator','trailingAvgSqDiscriminator',...
            'trailingAvgGenerator','trailingAvgSqGenerator','iteration');
    end
end

Синтезируйте звуки

Теперь, когда вы обучили сеть, можно исследовать процесс синтеза более подробно.

Обученный генератор барабана синтезирует кратковременные матрицы преобразования Фурье (STFT) из входных массивов случайных значений. Обратная операция STFT (ISTFT) преобразует частотно-временную STFT в синтезированный аудиосигнал временной области.

Загрузите веса предварительно обученного генератора. Эти веса были получены путем выполнения обучения, выделенного в предыдущем разделе, для 1000 эпох.

load(matFileName,'generatorParameters','SMean','SStd');

Генератор берет векторы случайных значений 1 на 1 на 100 как вход. Сгенерируйте вектор выборки входа.

numLatentInputs = 100;
dlZ = dlarray(2 * ( rand(1,1,numLatentInputs,1,'single') - 0.5 ));

Передайте случайный вектор в генератор, чтобы создать изображение STFT. generatorParameters - структура, содержащая веса предварительно обученного генератора.

dlXGenerated = modelGenerator(dlZ,generatorParameters);

Преобразуйте dlarray STFT в матрицу с одной точностью.

S = dlXGenerated.extractdata;

Транспонируйте STFT, чтобы выровнять его размерности по istft функция.

S = S.';

STFT является матрицей 128 на 128, где первая размерность представляет 128 интервалов частоты, линейно разнесенных от 0 до 8 кГц. Генератор был обучен генерировать односторонний STFT из длины БПФ 256 с опущенным последним интервалом. Повторно введите этот интервал путем вставки строки нулей в STFT.

S = [S ; zeros(1,128)];

Верните шаги нормализации и масштабирования, используемые при создании STFT для обучения.

S = S * 3;
S = (S.*SStd) + SMean;

Преобразуйте STFT из области журнала в линейную область.

S = exp(S);

Преобразуйте STFT из одностороннего в двусторонний.

S = [S; S(end-1:-1:2,:)];

Дополните нулями, чтобы удалить эффекты кромки окна.

S = [zeros(256,100) S zeros(256,100)];

Матрица STFT не содержит никакой информации о фазе. Используйте быструю версию алгоритма Гриффина-Лима с 20 итерациями, чтобы оценить фазу сигнала и произвести аудио выборок.

myAudio = stftmag2sig(S,256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
myAudio = myAudio(128*100:end-128*100);

Послушайте синтезированный барабан.

sound(myAudio,fs)

Постройте график синтезированного барабана.

t = (0:length(myAudio)-1)/fs;
plot(t,myAudio)
grid on
xlabel('Time (s)')
title('Synthesized GAN Sound')

Постройте график STFT синтезированной барабанной трубы.

figure
stft(myAudio,fs,'Window',hann(256,'periodic'),'OverlapLength',128);

Модель функции генератора

The modelGenerator функционируйте upscales массивы 1 на 1 на 100 (dlX) к массивам 128 на 128 на 1 (dlY). parameters - структура, содержащая веса слоев генератора. Архитектура генератора определена в таблице 4 [1].

function dlY = modelGenerator(dlX,parameters)

dlY = fullyconnect(dlX,parameters.FC.Weights,parameters.FC.Bias,'Dataformat','SSCB');

dlY = reshape(dlY,[1024 4 4 size(dlY,2)]);
dlY = permute(dlY,[3 2 1 4]);
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = tanh(dlY);
end

Моделируйте функцию дискриминатора

The modelDiscriminator функция принимает 128 на 128 изображений и выводит скалярный счет предсказания. Архитектура дискриминатора определена в таблице 5 [1].

function dlY = modelDiscriminator(dlX,parameters)

dlY = dlconv(dlX,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = stripdims(dlY);
dlY = permute(dlY,[3 2 1 4]);
dlY = reshape(dlY,4*4*64*16,numel(dlY)/(4*4*64*16));

weights = parameters.FC.Weights;
bias = parameters.FC.Bias;
dlY = fullyconnect(dlY,weights,bias,'Dataformat','CB');

end

Моделируйте функцию градиентов дискриминатора

The modelDiscriminatorGradients функции берут за вход параметры генератора и дискриминатора generatorParameters и discriminatorParametersмини-пакет входных данных dlXи массив случайных значений dlZ, и возвращает градиенты потерь дискриминатора относительно настраиваемых параметров в сетях.

function gradientsDiscriminator = modelDiscriminatorGradients(discriminatorParameters , generatorParameters, dlX, dlZ)

% Calculate the predictions for real data with the discriminator network.
dlYPred = modelDiscriminator(dlX,discriminatorParameters);

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated     = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsDiscriminator = dlgradient(lossDiscriminator,discriminatorParameters);

end

Функция градиентов генератора модели

The modelGeneratorGradients функция принимает как вход дискриминатор и генератор настраиваемых параметров и массив случайных значений dlZи возвращает градиенты потерь генератора относительно настраиваемых параметров в сетях.

function gradientsGenerator = modelGeneratorGradients(discriminatorParameters, generatorParameters , dlZ)

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossGenerator = ganGeneratorLoss(dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, generatorParameters);

end

Функция потерь дискриминатора

Цель дискриминатора - не быть обманутым генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает реальные и сгенерированные изображения, минимизируйте функцию потерь дискриминатора. Функция потерь для генератора следует подходу DCGAN, выделенному в [1].

function  lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated)

fake = dlarray(zeros(1,size(dlYPred,2)));
real = dlarray(ones(1,size(dlYPred,2)));

D_loss = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,fake));
D_loss = D_loss + mean(sigmoid_cross_entropy_with_logits(dlYPred,real));
lossDiscriminator  = D_loss / 2;
end

Функция потерь генератора

Цель генератора состоит в том, чтобы сгенерировать данные, которые дискриминатор классифицирует как «действительные». Чтобы максимизировать вероятность того, что изображения с генератора классифицируются дискриминатором как вещественные, минимизируйте функцию потерь генератора. Функция потерь для генератора следует подходу глубокой сверточной генеративной рекламной сети (DCGAN), выделенному в [1].

function lossGenerator = ganGeneratorLoss(dlYPredGenerated)
real = dlarray(ones(1,size(dlYPredGenerated,2)));
lossGenerator = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,real));
end

Инициализатор весов дискриминатора

initializeDiscriminatorWeights инициализирует веса дискриминаторов с помощью алгоритма Glorot.

function discriminatorParameters = initializeDiscriminatorWeights

filterSize = [5 5];
dim = 64;

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,dim,'single');
discriminatorParameters.Conv1.Weights = dlarray(weights);
discriminatorParameters.Conv1.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,2*dim,'single');
discriminatorParameters.Conv2.Weights = dlarray(weights);
discriminatorParameters.Conv2.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,4*dim,'single');
discriminatorParameters.Conv3.Weights = dlarray(weights);
discriminatorParameters.Conv3.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,8*dim,'single');
discriminatorParameters.Conv4.Weights = dlarray(weights);
discriminatorParameters.Conv4.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,16*dim,'single');
discriminatorParameters.Conv5.Weights = dlarray(weights);
discriminatorParameters.Conv5.Bias = dlarray(bias);

% fully connected
weights = iGlorotInitialize([1,4 * 4 * dim * 16]);
bias = zeros(1,1,'single');
discriminatorParameters.FC.Weights = dlarray(weights);
discriminatorParameters.FC.Bias = dlarray(bias);
end

Инициализатор весов генератора

initializeGeneratorWeights инициализирует веса генератора с помощью алгоритма Glorot.

function generatorParameters = initializeGeneratorWeights

dim = 64;

% Dense 1
weights = iGlorotInitialize([dim*256,100]);
bias = zeros(dim*256,1,'single');
generatorParameters.FC.Weights = dlarray(weights);
generatorParameters.FC.Bias = dlarray(bias);

filterSize = [5 5];

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,dim*8,'single');
generatorParameters.Conv1.Weights = dlarray(weights);
generatorParameters.Conv1.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,dim*4,'single');
generatorParameters.Conv2.Weights = dlarray(weights);
generatorParameters.Conv2.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,dim*2,'single');
generatorParameters.Conv3.Weights = dlarray(weights);
generatorParameters.Conv3.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,dim,'single');
generatorParameters.Conv4.Weights = dlarray(weights);
generatorParameters.Conv4.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,1,'single');
generatorParameters.Conv5.Weights = dlarray(weights);
generatorParameters.Conv5.Bias = dlarray(bias);
end

Синтезируйте Drumbeat

synthesizeDrumBeat использует предварительно обученную сеть для синтеза ударов барабана.

function y = synthesizeDrumBeat

persistent pGeneratorParameters pMean pSTD
if isempty(pGeneratorParameters)
    % If the MAT file does not exist, download it
    filename = 'drumGeneratorWeights.mat';
    load(filename,'SMean','SStd','generatorParameters');
    pMean = SMean;
    pSTD  = SStd;
    pGeneratorParameters = generatorParameters;
end

% Generate random vector
dlZ = dlarray(2 * ( rand(1,1,100,1,'single') - 0.5 ));

% Generate spectrograms
dlXGenerated = modelGenerator(dlZ,pGeneratorParameters);

% Convert from dlarray to single
S = dlXGenerated.extractdata;

S = S.';
% Zero-pad to remove edge effects
S = [S ; zeros(1,128)];

% Reverse steps from training
S = S * 3;
S = (S.*pSTD) + pMean;
S = exp(S);

% Make it two-sided
S = [S ; S(end-1:-1:2,:)];
% Pad with zeros at end and start
S = [zeros(256,100) S zeros(256,100)];

% Reconstruct the signal using a fast Griffin-Lim algorithm.
myAudio = stftmag2sig(gather(S),256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
y = myAudio(128*100:end-128*100);
end

Служебные функции

function out = sigmoid_cross_entropy_with_logits(x,z)
out = max(x, 0) - x .* z + log(1 + exp(-abs(x)));
end

function w = iGlorotInitialize(sz)
if numel(sz) == 2
    numInputs = sz(2);
    numOutputs = sz(1);
else
    numInputs = prod(sz(1:3));
    numOutputs = prod(sz([1 2 4]));
end
multiplier = sqrt(2 / (numInputs + numOutputs));
w = multiplier * sqrt(3) * (2 * rand(sz,'single') - 1);
end

Ссылка

[1] Donahue, C., J. McAuley, and M. Puckette. 2019. «Синтез состязательного звука». ICLR.