exponenta event banner

Обучающая генеративная состязательная сеть (GAN) для синтеза звука

В этом примере показано, как обучать и использовать генеративную состязательную сеть (GAN) для генерации звуков.

Введение

В генеративных состязательных сетях генератор и дискриминатор конкурируют друг с другом для улучшения качества генерации.

GAN вызвали значительный интерес в области обработки аудио и речи. Приложения включают синтез текста в речь, преобразование речи и улучшение речи.

Этот пример обучает GAN для неконтролируемого синтеза звуковых сигналов. В этом примере GAN генерирует звуки drumbeat. Такой же подход может применяться для генерации других типов звука, включая речь.

Синтезируйте аудио с помощью предварительно обученного GAN

Прежде чем тренировать GAN с нуля, вы будете использовать предварительно обученный генератор GAN для синтеза барабанных ударов.

Загрузите предварительно обученный генератор.

matFileName = 'drumGeneratorWeights.mat';
if ~exist(matFileName,'file')
    websave(matFileName,'https://www.mathworks.com/supportfiles/audio/GanAudioSynthesis/drumGeneratorWeights.mat');
end

Функция synthesizeDrumBeat вызывает предварительно обученную сеть для синтеза drumbeat, дискретизированного на частоте 16 кГц. synthesizeDrumBeat включена в конце этого примера.

Синтезируйте барабан и слушайте его.

drum = synthesizeDrumBeat;

fs = 16e3;
sound(drum,fs)

Постройте график синтезированного барабана.

t = (0:length(drum)-1)/fs;
plot(t,drum)
grid on
xlabel('Time (s)')
title('Synthesized Drum Beat')

Для создания более сложных приложений можно использовать синтезатор барабанов с другими звуковыми эффектами. Например, можно применить реверберацию к синтезированным ударам барабана.

Создать reverberator (Audio Toolbox) и откройте пользовательский интерфейс тюнера параметров. Этот пользовательский интерфейс позволяет настраивать reverberator параметры по мере выполнения моделирования.

reverb = reverberator('SampleRate',fs);
parameterTuner(reverb);

Создайте объект области времени для визуализации ударов барабана.

ts = timescope('SampleRate',fs, ...
    'TimeSpanSource','Property', ...
    'TimeSpanOverrunAction','Scroll', ...
    'TimeSpan',10, ...
    'BufferLength',10*256*64, ...
    'ShowGrid',true, ...
    'YLimits',[-1 1]);

В петле синтезировать барабанные удары и применить реверберацию. Для настройки реверберации используйте пользовательский интерфейс тюнера параметров. При необходимости выполнения моделирования в течение более длительного времени увеличьте значение loopCount параметр.

loopCount = 20;
for ii = 1:loopCount
    drum = synthesizeDrumBeat;
    drum = reverb(drum);
    ts(drum(:,1));
    soundsc(drum,fs)
    pause(0.5)
end

Обучение GAN

Теперь, когда вы видели предварительно обученный генератор барабанов в действии, вы можете исследовать тренировочный процесс подробно.

GAN - это тип сети глубокого обучения, которая генерирует данные с характеристиками, аналогичными данным обучения.

GAN состоит из двух сетей, которые соединяются друг с другом, генератора и дискриминатора:

  • Генератор - учитывая вектор или случайные значения в качестве входных данных, эта сеть генерирует данные с той же структурой, что и обучающие данные. Это работа генератора обмануть дискриминатора.

  • Дискриминатор - данная сеть пытается классифицировать наблюдения как реальные или сгенерированные.

Чтобы максимизировать производительность генератора, максимизируйте потери дискриминатора при получении сгенерированных данных. То есть целью генератора является генерирование данных, которые дискриминатор классифицирует как реальные. Чтобы максимизировать производительность дискриминатора, минимизируйте потери дискриминатора при предоставлении партий как реальных, так и сгенерированных данных. В идеале эти стратегии приводят к созданию генератора, который генерирует убедительно реалистичные данные, и дискриминатора, который усвоил сильные представления признаков, характерные для обучающих данных.

В этом примере генератор обучается создавать фальшивые представления краткочастотного преобразования Фурье (STFT) барабанных ударов. Вы обучаете дискриминатора идентифицировать реальные STFT. Реальные STFT создаются путем вычисления STFT коротких записей реальных барабанных ударов.

Загрузка данных обучения

Обучение GAN с использованием набора данных Drum Sound Effects [1]. Загрузите и извлеките набор данных.

url = 'http://deepyeti.ucsd.edu/cdonahue/wavegan/data/drums.tar.gz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'drums_dataset.tgz');

drumsFolder = fullfile(downloadFolder,'drums');
if ~exist(drumsFolder,'dir')
    disp('Downloading Drum Sound Effects Dataset (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создание audioDatastore Объект (Audio Toolbox), указывающий на набор данных барабанов.

ads = audioDatastore(drumsFolder,'IncludeSubfolders',true);

Определение сети генератора

Определите сеть, которая производит STFTs от множеств случайных ценностей 1 на 1 на 100. Создайте сеть что upscales множества 1 на 1 на 100 ко множествам 128 на 128 на 1, используя полностью связанный слой, сопровождаемый изменять слоем и серией перемещенных слоев скручивания со слоями ReLU.

На этом рисунке показаны размеры сигнала при его прохождении через генератор. Архитектура генератора определена в таблице 4 из [1].

Сеть генератора определена в modelGenerator, который включен в конце этого примера.

Определение сети дискриминаторов

Определите сеть, которая классифицирует реальные и сгенерированные 128 на 128 STFT.

Создайте сеть, которая принимает изображения 128 на 128 и выводит скалярную оценку прогнозирования, используя ряд слоев свертки с протекающими слоями ReLU, за которыми следует полностью связанный уровень.

На этом рисунке показаны размеры сигнала при его прохождении через дискриминатор. Архитектура дискриминатора определена в таблице 5 из [1].

Сеть дискриминатора определена в modelDiscriminator, который включен в конце этого примера.

Создание реальных данных обучения Drumbeat

Создание данных STFT из сигналов drumbeat в хранилище данных.

Определите параметры STFT.

fftLength = 256;
win = hann(fftLength,'periodic');
overlapLength = 128;

Чтобы ускорить обработку, распределите извлечение элементов между несколькими работниками с помощью parfor.

Сначала определите количество разделов для набора данных. Если у вас нет Toolbox™ Parallel Computing, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(ads,pool);
else
    numPar = 1;
end

Для каждого раздела считывайте из хранилища данных и вычисляйте STFT.

parfor ii = 1:numPar

    subds = partition(ads,numPar,ii);
    STrain = zeros(fftLength/2+1,128,1,numel(subds.Files));

    for idx = 1:numel(subds.Files)

        x = read(subds);

        if length(x) > fftLength*64
            % Lengthen the signal if it is too short
            x = x(1:fftLength*64);
        end

        % Convert from double-precision to single-precision
        x = single(x);

        % Scale the signal
        x = x ./ max(abs(x));

        % Zero-pad to ensure stft returns 128 windows.
        x = [x ; zeros(overlapLength,1,'like',x)];

        S0 = stft(x,'Window',win,'OverlapLength',overlapLength,'Centered',false);

        % Convert from two-sided to one-sided.
        S = S0(1:129,:);
        S = abs(S);
        STrain(:,:,:,idx) = S;
    end
    STrainC{ii} = STrain;
end

Преобразуйте выходные данные в четырехмерный массив с STFT вдоль четвертого размера.

STrain = cat(4,STrainC{:});

Преобразуйте данные в логарифмическую шкалу, чтобы лучше соответствовать человеческому восприятию.

STrain = log(STrain + 1e-6);

Нормализуйте учебные данные, чтобы они имели нулевое среднее значение и стандартное отклонение единицы измерения.

Вычислите среднее значение STFT и стандартное отклонение каждого частотного элемента.

SMean = mean(STrain,[2 3 4]);
SStd = std(STrain,1,[2 3 4]);

Нормализуйте каждый частотный блок.

STrain = (STrain-SMean)./SStd;

Вычисленные STFT имеют неограниченные значения. Следуя подходу, приведенному в [1], ограничьте данные путем отсечения спектров до 3 стандартных отклонений и масштабирования до [-1 1].

STrain = STrain/3;
Y = reshape(STrain,numel(STrain),1);
Y(Y<-1) = -1;
Y(Y>1) = 1;
STrain = reshape(Y,size(STrain));

Отбросьте последний частотный блок, чтобы принудительно увеличить число ячеек STFT до двух (что хорошо работает со сверточными слоями).

STrain = STrain(1:end-1,:,:,:);

Переставьте размеры при подготовке к подаче на дискриминатор.

STrain = permute(STrain,[2 1 3 4]);

Укажите параметры обучения

Поезд с размером мини-партии 64 на 1000 эпох.

maxEpochs = 1000;
miniBatchSize = 64;

Вычислите число итераций, необходимых для использования данных.

numIterationsPerEpoch = floor(size(STrain,4)/miniBatchSize);

Укажите параметры оптимизации Adam. Установка скорости обучения генератора и дискриминатора в 0.0002. Для обеих сетей используйте коэффициент градиентного затухания 0,5 и коэффициент градиентного затухания 0,999 в квадрате.

learnRateGenerator = 0.0002;
learnRateDiscriminator = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обучение на GPU, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений.

executionEnvironment = "auto";

Инициализация весов генератора и дискриминатора. initializeGeneratorWeights и initializeDiscriminatorWeights функции возвращают случайные веса, полученные с помощью равномерной инициализации Глорота. Функции включены в конце этого примера.

generatorParameters = initializeGeneratorWeights;
discriminatorParameters = initializeDiscriminatorWeights;

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. Закольцовывать обучающие данные и обновлять сетевые параметры на каждой итерации.

Для каждой эпохи перетасуйте обучающие данные и закольцовывайте мини-пакеты данных.

Для каждой мини-партии:

  • Создать dlarray объект, содержащий массив случайных значений для генераторной сети.

  • Для обучения GPU преобразуйте данные в gpuArray (Панель параллельных вычислений).

  • Оценка градиентов модели с помощью dlfeval и вспомогательные функции, modelDiscriminatorGradients и modelGeneratorGradients.

  • Обновление параметров сети с помощью adamupdate функция.

Инициализируйте параметры для Adam.

trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminator = [];
trailingAvgSqDiscriminator = [];

Можно задать saveCheckpoints кому true для сохранения обновленных весов и состояний в MAT-файле каждые десять эпох. Затем этот MAT-файл можно использовать для возобновления обучения, если оно прервано. Для целей этого примера установите saveCheckpoints кому false.

saveCheckpoints = false;

Укажите длину входного сигнала генератора.

numLatentInputs = 100;

Тренируйте ГАН. Запуск может занять несколько часов.

iteration = 0;

for epoch = 1:maxEpochs

    % Shuffle the data.
    idx = randperm(size(STrain,4));
    STrain = STrain(:,:,:,idx);

    % Loop over mini-batches.
    for index = 1:numIterationsPerEpoch

        iteration = iteration + 1;

        % Read mini-batch of data.
        dlX = STrain(:,:,:,(index-1)*miniBatchSize+1:index*miniBatchSize);
        dlX = dlarray(dlX,'SSCB');

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
            dlX = gpuArray(dlX);
        end

        % Evaluate the discriminator gradients using dlfeval and the
        % |modelDiscriminatorGradients| helper function.
        gradientsDiscriminator = ...
            dlfeval(@modelDiscriminatorGradients,discriminatorParameters,generatorParameters,dlX,dlZ);

        % Update the discriminator network parameters.
        [discriminatorParameters,trailingAvgDiscriminator,trailingAvgSqDiscriminator] = ...
            adamupdate(discriminatorParameters,gradientsDiscriminator, ...
            trailingAvgDiscriminator,trailingAvgSqDiscriminator,iteration, ...
            learnRateDiscriminator,gradientDecayFactor,squaredGradientDecayFactor);

        % Generate latent inputs for the generator network.
        Z = 2 * ( rand(1,1,numLatentInputs,miniBatchSize,'single') - 0.5 ) ;
        dlZ = dlarray(Z);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlZ = gpuArray(dlZ);
        end

        % Evaluate the generator gradients using dlfeval and the
        % |modelGeneratorGradients| helper function.
        gradientsGenerator  = ...
            dlfeval(@modelGeneratorGradients,discriminatorParameters,generatorParameters,dlZ);

        % Update the generator network parameters.
        [generatorParameters,trailingAvgGenerator,trailingAvgSqGenerator] = ...
            adamupdate(generatorParameters,gradientsGenerator, ...
            trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
            learnRateGenerator,gradientDecayFactor,squaredGradientDecayFactor);
    end

    % Every 10 iterations, save a training snapshot to a MAT file.
    if saveCheckpoints && mod(epoch,10)==0
        fprintf('Epoch %d out of %d complete\n',epoch,maxEpochs);
        % Save checkpoint in case training is interrupted.
        save('audiogancheckpoint.mat',...
            'generatorParameters','discriminatorParameters',...
            'trailingAvgDiscriminator','trailingAvgSqDiscriminator',...
            'trailingAvgGenerator','trailingAvgSqGenerator','iteration');
    end
end

Синтезировать звуки

Теперь, когда вы обучили сеть, вы можете исследовать процесс синтеза более подробно.

Обученный генератор барабанов синтезирует матрицы кратковременного преобразования Фурье (STFT) из входных массивов случайных значений. Операция обратного STFT (ISTFT) преобразует временную частоту STFT в синтезированный аудиосигнал временной области.

Загрузите веса предварительно обученного генератора. Эти веса были получены путем выполнения тренировки, выделенной в предыдущем разделе для 1000 эпох.

load(matFileName,'generatorParameters','SMean','SStd');

Генератор берет векторы случайных ценностей 1 на 1 на 100 как вход. Создайте образец входного вектора.

numLatentInputs = 100;
dlZ = dlarray(2 * ( rand(1,1,numLatentInputs,1,'single') - 0.5 ));

Передайте случайный вектор генератору для создания STFT-изображения. generatorParameters - структура, содержащая веса предварительно обученного генератора.

dlXGenerated = modelGenerator(dlZ,generatorParameters);

Преобразование STFT dlarray в матрицу с одной точностью.

S = dlXGenerated.extractdata;

Транспонировать STFT для выравнивания его размеров с istft функция.

S = S.';

STFT представляет собой матрицу 128 на 128, где первый размер представляет 128 частотные ячейки, линейно разнесенные от 0 до 8 кГц. Генератор был обучен генерировать односторонний STFT из длины FFT 256, при этом последний бункер опущен. Снова введите эту ячейку, вставив строку нулей в STFT.

S = [S ; zeros(1,128)];

Отмените шаги нормализации и масштабирования, используемые при создании STFT для обучения.

S = S * 3;
S = (S.*SStd) + SMean;

Преобразование STFT из области регистрации в линейную область.

S = exp(S);

Преобразование STFT из одностороннего в двусторонний.

S = [S; S(end-1:-1:2,:)];

Панель с нулями для удаления краевых эффектов окна.

S = [zeros(256,100) S zeros(256,100)];

Матрица STFT не содержит никакой информации о фазе. Используйте быструю версию алгоритма Гриффина-Лима с 20 итерациями для оценки фазы сигнала и создания аудиоотсчетов.

myAudio = stftmag2sig(S,256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
myAudio = myAudio(128*100:end-128*100);

Слушайте синтезированный барабан.

sound(myAudio,fs)

Постройте график синтезированного барабана.

t = (0:length(myAudio)-1)/fs;
plot(t,myAudio)
grid on
xlabel('Time (s)')
title('Synthesized GAN Sound')

Постройте график STFT синтезированного барабана.

figure
stft(myAudio,fs,'Window',hann(256,'periodic'),'OverlapLength',128);

Функция генератора модели

modelGenerator функционируйте upscales множества 1 на 1 на 100 (dlX) ко множествам 128 на 128 на 1 (dlY). parameters - структура, удерживающая веса слоев генератора. Архитектура генератора определена в таблице 4 из [1].

function dlY = modelGenerator(dlX,parameters)

dlY = fullyconnect(dlX,parameters.FC.Weights,parameters.FC.Bias,'Dataformat','SSCB');

dlY = reshape(dlY,[1024 4 4 size(dlY,2)]);
dlY = permute(dlY,[3 2 1 4]);
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = relu(dlY);

dlY = dltranspconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Cropping','same','DataFormat','SSCB');
dlY = tanh(dlY);
end

Функция дискриминатора модели

modelDiscriminator функция принимает изображения 128 на 128 и выводит оценку скалярного предсказания. Архитектура дискриминатора определена в таблице 5 из [1].

function dlY = modelDiscriminator(dlX,parameters)

dlY = dlconv(dlX,parameters.Conv1.Weights,parameters.Conv1.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv2.Weights,parameters.Conv2.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv3.Weights,parameters.Conv3.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv4.Weights,parameters.Conv4.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = dlconv(dlY,parameters.Conv5.Weights,parameters.Conv5.Bias,'Stride' ,2 ,'Padding','same');
dlY = leakyrelu(dlY,0.2);

dlY = stripdims(dlY);
dlY = permute(dlY,[3 2 1 4]);
dlY = reshape(dlY,4*4*64*16,numel(dlY)/(4*4*64*16));

weights = parameters.FC.Weights;
bias = parameters.FC.Bias;
dlY = fullyconnect(dlY,weights,bias,'Dataformat','CB');

end

Функция градиентов дискриминатора модели

modelDiscriminatorGradients функции принимают за вход параметры генератора и дискриминатора generatorParameters и discriminatorParameters, мини-пакет входных данных dlXи массив случайных значений dlZи возвращает градиенты потерь дискриминатора относительно обучаемых параметров в сетях.

function gradientsDiscriminator = modelDiscriminatorGradients(discriminatorParameters , generatorParameters, dlX, dlZ)

% Calculate the predictions for real data with the discriminator network.
dlYPred = modelDiscriminator(dlX,discriminatorParameters);

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated     = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsDiscriminator = dlgradient(lossDiscriminator,discriminatorParameters);

end

Функция градиентов генератора модели

modelGeneratorGradients функция принимает в качестве входных данных распознаватель и генератор обучаемых параметров и массив случайных значений dlZи возвращает градиенты потерь генератора относительно обучаемых параметров в сетях.

function gradientsGenerator = modelGeneratorGradients(discriminatorParameters, generatorParameters , dlZ)

% Calculate the predictions for generated data with the discriminator network.
dlXGenerated = modelGenerator(dlZ,generatorParameters);
dlYPredGenerated = modelDiscriminator(dlarray(dlXGenerated,'SSCB'),discriminatorParameters);

% Calculate the GAN loss
lossGenerator = ganGeneratorLoss(dlYPredGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsGenerator = dlgradient(lossGenerator, generatorParameters);

end

Функция потери дискриминатора

Цель дискриминатора - не обманываться генератором. Чтобы максимизировать вероятность того, что дискриминатор успешно различает реальные и сгенерированные изображения, минимизируйте функцию потери дискриминатора. Функция потерь для генератора соответствует подходу DCGAN, выделенному в [1].

function  lossDiscriminator = ganDiscriminatorLoss(dlYPred,dlYPredGenerated)

fake = dlarray(zeros(1,size(dlYPred,2)));
real = dlarray(ones(1,size(dlYPred,2)));

D_loss = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,fake));
D_loss = D_loss + mean(sigmoid_cross_entropy_with_logits(dlYPred,real));
lossDiscriminator  = D_loss / 2;
end

Функция потери генератора

Целью генератора является генерирование данных, которые дискриминатор классифицирует как «реальные». Чтобы максимизировать вероятность того, что изображения из генератора классифицируются дискриминатором как реальные, минимизируйте функцию потерь генератора. Функция потерь для генератора соответствует подходу глубокой сверточной генеративной состязательной сети (DCGAN), выделенному в [1].

function lossGenerator = ganGeneratorLoss(dlYPredGenerated)
real = dlarray(ones(1,size(dlYPredGenerated,2)));
lossGenerator = mean(sigmoid_cross_entropy_with_logits(dlYPredGenerated,real));
end

Инициализатор весов дискриминатора

initializeDiscriminatorWeights инициализирует веса дискриминатора с помощью алгоритма Глорота.

function discriminatorParameters = initializeDiscriminatorWeights

filterSize = [5 5];
dim = 64;

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,dim,'single');
discriminatorParameters.Conv1.Weights = dlarray(weights);
discriminatorParameters.Conv1.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,2*dim,'single');
discriminatorParameters.Conv2.Weights = dlarray(weights);
discriminatorParameters.Conv2.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,4*dim,'single');
discriminatorParameters.Conv3.Weights = dlarray(weights);
discriminatorParameters.Conv3.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,8*dim,'single');
discriminatorParameters.Conv4.Weights = dlarray(weights);
discriminatorParameters.Conv4.Bias = dlarray(bias);

% Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,16*dim,'single');
discriminatorParameters.Conv5.Weights = dlarray(weights);
discriminatorParameters.Conv5.Bias = dlarray(bias);

% fully connected
weights = iGlorotInitialize([1,4 * 4 * dim * 16]);
bias = zeros(1,1,'single');
discriminatorParameters.FC.Weights = dlarray(weights);
discriminatorParameters.FC.Bias = dlarray(bias);
end

Инициализатор весов генератора

initializeGeneratorWeights инициализирует веса генератора с помощью алгоритма Глорота.

function generatorParameters = initializeGeneratorWeights

dim = 64;

% Dense 1
weights = iGlorotInitialize([dim*256,100]);
bias = zeros(dim*256,1,'single');
generatorParameters.FC.Weights = dlarray(weights);
generatorParameters.FC.Bias = dlarray(bias);

filterSize = [5 5];

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 8*dim 16*dim]);
bias = zeros(1,1,dim*8,'single');
generatorParameters.Conv1.Weights = dlarray(weights);
generatorParameters.Conv1.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 4*dim 8*dim]);
bias = zeros(1,1,dim*4,'single');
generatorParameters.Conv2.Weights = dlarray(weights);
generatorParameters.Conv2.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 2*dim 4*dim]);
bias = zeros(1,1,dim*2,'single');
generatorParameters.Conv3.Weights = dlarray(weights);
generatorParameters.Conv3.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) dim 2*dim]);
bias = zeros(1,1,dim,'single');
generatorParameters.Conv4.Weights = dlarray(weights);
generatorParameters.Conv4.Bias = dlarray(bias);

% Trans Conv2D
weights = iGlorotInitialize([filterSize(1) filterSize(2) 1 dim]);
bias = zeros(1,1,1,'single');
generatorParameters.Conv5.Weights = dlarray(weights);
generatorParameters.Conv5.Bias = dlarray(bias);
end

Синтезировать барабаны

synthesizeDrumBeat использует предварительно обученную сеть для синтеза барабанных ударов.

function y = synthesizeDrumBeat

persistent pGeneratorParameters pMean pSTD
if isempty(pGeneratorParameters)
    % If the MAT file does not exist, download it
    filename = 'drumGeneratorWeights.mat';
    load(filename,'SMean','SStd','generatorParameters');
    pMean = SMean;
    pSTD  = SStd;
    pGeneratorParameters = generatorParameters;
end

% Generate random vector
dlZ = dlarray(2 * ( rand(1,1,100,1,'single') - 0.5 ));

% Generate spectrograms
dlXGenerated = modelGenerator(dlZ,pGeneratorParameters);

% Convert from dlarray to single
S = dlXGenerated.extractdata;

S = S.';
% Zero-pad to remove edge effects
S = [S ; zeros(1,128)];

% Reverse steps from training
S = S * 3;
S = (S.*pSTD) + pMean;
S = exp(S);

% Make it two-sided
S = [S ; S(end-1:-1:2,:)];
% Pad with zeros at end and start
S = [zeros(256,100) S zeros(256,100)];

% Reconstruct the signal using a fast Griffin-Lim algorithm.
myAudio = stftmag2sig(gather(S),256, ...
    'FrequencyRange','twosided', ...
    'Window',hann(256,'periodic'), ...
    'OverlapLength',128, ...
    'MaxIterations',20, ...
    'Method','fgla');
myAudio = myAudio./max(abs(myAudio),[],'all');
y = myAudio(128*100:end-128*100);
end

Служебные функции

function out = sigmoid_cross_entropy_with_logits(x,z)
out = max(x, 0) - x .* z + log(1 + exp(-abs(x)));
end

function w = iGlorotInitialize(sz)
if numel(sz) == 2
    numInputs = sz(2);
    numOutputs = sz(1);
else
    numInputs = prod(sz(1:3));
    numOutputs = prod(sz([1 2 4]));
end
multiplier = sqrt(2 / (numInputs + numOutputs));
w = multiplier * sqrt(3) * (2 * rand(sz,'single') - 1);
end

Ссылка

[1] Донахью, К., Дж. Маколи и М. Пакетт. 2019. «Состязательный аудиосинтез». ICLR.