От начала до конца глубокое речевое разделение

Этот пример демонстрирует сквозную нейронную сеть для глубокого обучения для не зависящего от диктора речевого разделения.

Введение

Речевое разделение является оспариванием и критической речью, обрабатывающей задачу. Много речевых разделительных методов на основе глубокого обучения были недавно предложены, большинство которых использует преобразования частоты времени смеси аудио временного интервала (См., что Исходное Разделение Приема Использует Нейронные сети для глубокого обучения (Audio Toolbox) для реализации такой системы глубокого обучения).

Решения на основе методов частоты времени страдают от двух основных недостатков:

  • Преобразование представлений частоты времени назад временному интервалу требует оценки фазы, которая вводит ошибки и приводит к несовершенной реконструкции.

  • Относительно длинные окна требуются, чтобы давать к представлениям частоты высокого разрешения, который приводит к высокой вычислительной сложности и недопустимой задержке для сценариев в реальном времени.

В этом примере вы исследуете речевую разделительную сеть глубокого обучения (на основе [1]), который действует непосредственно на звуковой сигнал и обходит проблемы, являющиеся результатом преобразований частоты времени.

Отдельная Речь с помощью Предварительно обученной сети

Загрузите предварительно обученную сеть

Перед обучением нейронная сеть для глубокого обучения с нуля, вы будете использовать предварительно обученную версию сети, чтобы разделить два динамика от сигнала смеси в качестве примера.

Во-первых, загрузите звуковые файлы в качестве примера и предварительно обученная сеть.

url = 'http://ssd.mathworks.com/supportfiles/audio/speechSeparation.zip';
downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'speechSeparation');

if ~exist(netFolder,'dir')
    disp('Downloading pretrained network and audio files ...')
    unzip(url,downloadNetFolder)
end

Подготовьте тестовый сигнал

Загрузите два звуковых сигнала, соответствующие двум различным динамикам. Оба сигнала производятся на уровне 8 кГц.

Fs = 8000;
s1 = audioread(fullfile(netFolder,'speaker1.wav'));
s2 = audioread(fullfile(netFolder,'speaker2.wav'));

Нормируйте сигналы.

s1 = s1/max(abs(s1));
s2 = s2/max(abs(s2));

Слушайте несколько секунд каждого сигнала.

T = 5;
sound(s1(1:T*Fs))
pause(T)
sound(s2(1:T*Fs))
pause(T)

Объедините два сигнала в сигнал смеси.

mix = s1+s2;
mix = mix/max(abs(mix));

Слушайте первые несколько секунд сигнала смеси.

sound(mix(1:T*Fs))
pause(T)

Отдельные динамики

Загрузите параметры предварительно обученной речевой разделительной сети.

load(fullfile(netFolder,'paramsBest.mat'),'learnables','states')

Разделите эти два динамика в сигналах смеси путем вызова separateSpeakers функция.

[z1,z2] = separateSpeakers(mix,learnables,states,false);

Слушайте первые несколько секунд первого предполагаемого речевого сигнала.

sound(z1(1:T*Fs))
pause(T)

Слушайте второй предполагаемый сигнал.

sound(z2(1:T*Fs))
pause(T)

Чтобы проиллюстрировать эффект речевого разделения, постройте предполагаемые и исходные разделенные сигналы наряду с сигналом смеси.

s1 = s1(1:length(z1));
s2 = s2(1:length(z2));
mix = mix(1:length(s1));

t  = (0:length(s1)-1)/Fs;

figure;
subplot(311)
plot(t,s1)
hold on
plot(t,z1)
grid on
legend('Speaker 1 - Actual','Speaker 1 - Estimated')
subplot(312)
plot(t,s2)
hold on
plot(t,z2)
grid on
legend('Speaker 2 - Actual','Speaker 2 - Estimated')
subplot(313)
plot(t,mix)
grid on
legend('Mixture')
xlabel('Time (s)')

Сравните с нейронной сетью для глубокого обучения преобразования частоты времени

Затем вы сравниваете эффективность сети к сети, разработанной в Исходном Разделении Приема Используя Нейронные сети для глубокого обучения (Audio Toolbox) пример. Эта речевая разделительная сеть основана на традиционных представлениях частоты времени аудио смеси (использующий кратковременное преобразование Фурье, STFT, и обратное кратковременное преобразование Фурье, ISTFT).

Загрузите предварительно обученную сеть.

url = 'http://ssd.mathworks.com/supportfiles/audio/CocktailPartySourceSeparation.zip';

downloadNetFolder = tempdir;
cocktailNetFolder = fullfile(downloadNetFolder,'CocktailPartySourceSeparation');

if ~exist(cocktailNetFolder,'dir')
    disp('Downloading pretrained network and audio files (5 files - 24.5 MB) ...')
    unzip(url,downloadNetFolder)
end

Функциональный separateSpeakersTimeFrequency инкапсулирует шаги, требуемые разделить речь с помощью этого network. Функция выполняет следующие шаги:

  • Вычислите величину STFT входной смеси временного интервала.

  • Вычислите мягкую маску частоты времени путем передачи STFT сети.

  • Вычислите STFT разделенных сигналов путем умножения смеси STFT на маску.

  • Восстановите разделенные сигналы временного интервала с помощью ISTFT. Фаза смеси STFT используется.

Отошлите к Исходному Разделению Приема Используя Нейронные сети для глубокого обучения (Audio Toolbox) пример для получения дополнительной информации об этой сети.

Разделите эти два динамика.

[y1,y2] = separateSpeakersTimeFrequency(mix,cocktailNetFolder);

Слушайте первый разделенный сигнал.

sound(y1(1:Fs*T))
pause(T)

Слушайте второй разделенный сигнал.

sound(y2(1:Fs*T))
pause(T)

Оцените Производительность сети с помощью SI-SNR

Вы сравните эти две сети с помощью инвариантного к масштабу отношения источника к шуму (SI-SNR) объективная мера [1].

Вычислите SISNR для первого динамика со сквозной сетью.

Во-первых, нормируйте фактические и предполагаемые сигналы.

s10 = s1 - mean(s1);
z10 = z1 - mean(z1);

Вычислите компонент "сигнала" ОСШ.

t = sum(s10.*z10) .* z10 ./ (sum(z10.^2)+eps);

Вычислите "шумовой" компонент ОСШ.

n = s1 - t;

Теперь вычислите SI-SNR (в дБ).

v1 = 20*log((sqrt(sum(t.^2))+eps)./sqrt((sum(n.^2))+eps))/log(10);
fprintf('End-to-end network - Speaker 1 SISNR: %f dB\n',v1)
End-to-end network - Speaker 1 SISNR: 14.316869 dB

Шаги расчета SI-SNR инкапсулируются в функциональном SISNR. Используйте функцию, чтобы вычислить SI-SNR второго динамика со сквозной сетью.

v2 = SISNR(z2,s2);
fprintf('End-to-end network - Speaker 2 SISNR: %f dB\n',v2)
End-to-end network - Speaker 2 SISNR: 13.706421 dB

Затем вычислите SI-SNR для каждого динамика для основанной на STFT сети.

w1 = SISNR(y1,s1(1:length(y1)));
w2 = SISNR(y2,s2(1:length(y2)));
fprintf('STFT network - Speaker 1 SISNR: %f dB\n',w1)
STFT network - Speaker 1 SISNR: 7.003789 dB
fprintf('STFT network - Speaker 2 SISNR: %f dB\n',w2)
STFT network - Speaker 2 SISNR: 7.382209 dB

Обучение речевой разделительной сети

Исследуйте сетевую архитектуру

Сеть основана [1] и состоит из трех этапов: Кодирование, оценка маски или разделение и декодирование.

  • Энкодер преобразовывает входные сигналы смеси временного интервала в промежуточное представление с помощью сверточных слоев.

  • Средство оценки маски вычисляет одну маску на динамик. Промежуточное представление каждого динамика получено путем умножения выхода энкодера его соответствующей маской. Средство оценки маски состоит из 32 блоков сверточных и слоев нормализации со связями пропуска между блоками.

  • Декодер преобразовывает промежуточные представления разделенным речевым сигналам временного интервала с помощью, транспонировал сверточные слои.

Операция сети инкапсулируется в separateSpeakers.

Опционально уменьшайте размер набора данных

Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset ко лжи. Чтобы запустить этот пример быстро, установите reduceDataset к истине. Это запустит остальную часть примера только на горстке файлов.

reduceDataset = true;

Загрузите обучающий набор данных

Вы используете подмножество Набора данных LibriSpeech [2], чтобы обучить сеть. Набор данных LibriSpeech является большим корпусом английской речи чтения, произведенной на уровне 16 кГц. Данные выведены из аудиокниг, считанных из проекта LibriVox.

Загрузите набор данных LibriSpeech. Если reduceDataset верно, этот steo пропущен.

downloadDatasetFolder = tempdir;
datasetFolder = fullfile(downloadDatasetFolder,"LibriSpeech","train-clean-360");
if ~reduceDataset    
    filename = "train-clean-360.tar.gz";
    url = "http://www.openSLR.org/resources/12/" + filename;
    if ~isfolder(datasetFolder)
        gunzip(url,downloadDatasetFolder);
        unzippedFile = fullfile(downloadDatasetFolder,filename);
        untar(unzippedFile{1}(1:end-3),downloadDatasetFolder);
    end
end

Предварительно обработайте набор данных

Набор данных LibriSpeech состоит из большого количества звуковых файлов с одним динамиком. Это не содержит сигналы смеси, где 2 или больше человека говорят одновременно.

Вы обработаете исходный набор данных, чтобы создать новый набор данных, который подходит для того, чтобы обучить речевую разделительную сеть.

Шаги для создания обучающего набора данных инкапсулируются в createTrainingDataset. Функция создает сигналы смеси, состоявшие из произнесения двух случайных динамиков. Функция возвращает три аудио хранилища данных:

  • mixDatastore точки к файлам смеси (где два докладчика говорят одновременно).

  • speaker1Datastore точки к файлам, содержащим изолированную речь первого динамика в смеси.

  • speaker2Datastore точки к файлам, содержащим изолированную речь второго динамика в смеси.

miniBatchSize = 4;
[mixDatastore,speaker1Datastore,speaker2Datastore] = createTrainingDataset(netFolder,datasetFolder,downloadDatasetFolder,reduceDataset,miniBatchSize);

Объедините хранилища данных. Это гарантирует, что файлы остаются в правильном порядке, когда вы переставляете их в начале каждой новой эпохи в учебном цикле.

ds = combine(mixDatastore,speaker1Datastore,speaker2Datastore);

Создайте мини-пакетную очередь из datastore.

mqueue = minibatchqueue(ds,'MiniBatchSize',miniBatchSize,'OutputEnvironment','cpu','OutputAsDlarray',false);

Задайте опции обучения

Задайте параметры обучения.

Обучайтесь в течение 10 эпох.

if reduceDataset
    numEpochs = 1;
else
    numEpochs = 10; %#ok
end

Задайте опции для оптимизации Адама. Установите начальную скорость обучения на 1e-3. Используйте фактор затухания градиента 0,9 и фактор затухания градиента в квадрате 0,999.

learnRate = 1e-3;
averageGrad = [];
averageSqGrad = [];

gradDecay = 0.9;
sqGradDecay = 0.999;

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™.

executionEnvironment = "auto"; % Change to "gpu" to train on a GPU.

duration = 4 * 8000;

Настройте данные о валидации

Вы будете использовать тестовый сигнал, который вы ранее использовали, чтобы протестировать предварительно обученную сеть, чтобы вычислять валидацию SI-SNR периодически во время обучения.

Если графический процессор доступен, переместите сигнал валидации в графический процессор.

mix = dlarray(mix,'SCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    mix = gpuArray(mix);
end

Задайте количество итераций между валидацией расчеты SI-SNR.

numIterPerValidation = 50;

Задайте вектор, чтобы содержать валидацию SI-SNR от каждой итерации.

valSNR = [];

Задайте переменную, чтобы содержать лучшую валидацию SI-SNR.

bestSNR = -Inf;

Задайте переменную, чтобы содержать эпоху, в которую произошел лучший счет валидации.

bestEpoch = 1;

Инициализируйте сеть

Инициализируйте сетевые параметры. learnables структура, содержащая настраиваемые параметры от слоев сети. states структура, содержащая состояния от слоев нормализации.

[learnables,states] = initializeNetworkParams;

Обучите сеть

Выполните учебный цикл. Это может занять много часов, чтобы запуститься.

Обратите внимание на то, что нет никакого априорного способа сопоставить предполагаемые выходные сигналы динамика с ожидаемыми сигналами динамика. Это разрешено при помощи обучения инварианта сочетания уровня Произнесения (uPIT) [1]. Потеря основана на вычислении SI-SNR. uPIT минимизирует потерю по всем сочетаниям между выходными параметрами и целями. Это задано в функциональном uPIT.

Валидация SI-SNR периодически вычисляется. Если SI-SNR является оптимальным значением до настоящего времени, сетевые параметры сохранены в params.mat.

iteration = 0;

% Loop over epochs.
for jj =1:numEpochs

    % Shuffle the data
    shuffle(mqueue);

    while hasdata(mqueue)

        % Compute validation loss/SNR periodically
        if mod(iteration,numIterPerValidation)==0
            
            [z1,z2] = separateSpeakers(mix, learnables,states,false);
            
            l = uPIT(z1,s1,z2,s2);
            valSNR(end+1) = l; %#ok

            if l > bestSNR
                bestSNR = l;
                bestEpoch = jj;
                filename = 'params.mat';
                save(filename,'learnables','states');
            end
        end

        iteration = iteration + 1;

        % Get a new batch of training data
        [x1Batch,x2Batch,mixBatch] = next(mqueue);
        x1Batch = reshape(x1Batch,[duration 1 miniBatchSize]);
        x2Batch = reshape(x2Batch,[duration 1 miniBatchSize]);
        mixBatch = reshape(mixBatch,[duration 1 miniBatchSize]);

        x1Batch = dlarray(x1Batch,'SCB');
        x2Batch = dlarray(x2Batch,'SCB');
        mixBatch = dlarray(mixBatch,'SCB');

        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            x1Batch = gpuArray(x1Batch);
            x2Batch = gpuArray(x2Batch);
            mixBatch = gpuArray(mixBatch);
        end

        % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
        [gradients,states] = dlfeval( @modelGradients,mixBatch,x1Batch,x2Batch,learnables,states,miniBatchSize);

        % Update the network parameters using the ADAM optimizer.
        [learnables,averageGrad,averageSqGrad] = adamupdate(learnables,gradients,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
        
    end

    % Reduce the learning rate if the validation accuracy did not improve
    % during the epoch
    if bestEpoch ~= jj
        learnRate = learnRate/2;
    end
end

Постройте значения ОСШ валидации.

if ~reduceDataset
    valIterNum = 0:length(valSNR)-1;
    figure
    semilogx(numIterPerValidation*(valIterNum-1),valSNR,'b*-')
    grid on
    xlabel('Iteration #')
    ylabel('Validation SINR (dB)')
    valFig.Visible = 'on';
end

Ссылки

[1] И Ло, Nima Mesgarani, "Conv-tasnet: Превосходное идеальное маскирование величины частоты времени для речевого разделения", 2019 транзакций IEEE/ACM на аудио, речи, и обработке языка, издании 29, выпуске 8, стр 1256-1266.

[2] В. Панаетов, Г. Чен, Д. Пови и С. Худэнпур, "Librispeech: корпус ASR на основе аудиокниг общественного достояния", 2 015 Международных конференций IEEE по вопросам Акустики, Речи и Обработки сигналов (ICASSP), Брисбена, QLD, 2015, стр 5206-5210, doi: 10.1109/ICASSP.2015.7178964

Вспомогательные Функции

function [mixDatastore,speaker1Datastore,speaker2Datastore] = createTrainingDataset(netFolder,datasetFolder,downloadDatasetFolder,reduceDataset,miniBatchSize)
% createTrainingDataset Create training dataset

newDatasetPath = fullfile(downloadDatasetFolder,'speech-sep-dataset');

%Create the new dataset folders if they do not exist already.
processDataset = ~isfolder(newDatasetPath);
if processDataset
    mkdir(newDatasetPath);
    mkdir([newDatasetPath '/sp1']);
    mkdir([newDatasetPath '/sp2']);
    mkdir([newDatasetPath '/mix']);
end

%Create an audioDatastore that points to the LibriSpeech dataset.
if reduceDataset
    ads = audioDatastore([repmat({fullfile(netFolder,'speaker1.wav')},1,4),...
                          repmat({fullfile(netFolder,'speaker2.wav')},1,4)]);
else
    ads = audioDatastore(datasetFolder,'IncludeSubfolders',true);
end

% The LibriSpeech dataset is comprised of signals from different speakers.
% The unique speaker ID is encoded in the audio file names.

% Extract the speaker IDs from the file names.
if reduceDataset
    ads.Labels = categorical([repmat({'1'},1,4),repmat({'2'},1,4)]);
else
    ads.Labels = categorical(extractBetween(ads.Files,fullfile(datasetFolder,filesep),filesep));
end

% You will create mixture signals comprised of utterances of two random speakers.  
% Randomize the IDs of all the speakers.
names = unique(ads.Labels);
names = names(randperm(length(names)));

% In this example, you create training signals based on 400 speakers. You
% generate mixture signals based on combining utterances from 200 pairs of
% speakers. 

% Define the two groups of speakers.
numPairs = min(200,floor(numel(names)/2)); 
n1 = names(1:numPairs);
n2 = names(numPairs+1:2*numPairs);

% Create the new dataset. For each pair of speakers: 
% * Use subset to create two audio datastores, each containing files
%   corresponding to their respective speaker.
% * Adjust the datastores so that they have the same number of files.
% * Combine the two datastores using combine. 
% * Use writeall to preprocess the files of the combined datastore and write
%   the new resulting signals to disk.

% The preprocessing steps performed to create the signals before writing
% them to disk are encapsulated in the function createTrainingFiles. For
% each pair of signals:
% * You downsample the signals from 16 kHz to 8 kHz. 
% * You randomly select 4 seconds from each downsampled signal. 
% * You create the mixture by adding the 2 signal chunks.
% * You adjust the signal power to achieve a randomly selected
%   signal-to-noise value in the range [-5,5] dB.
% * You write the 3 signals (corresponding to the first speaker, the second
%   speaker, and the mixture, respectively) to disk.
parfor index=1:length(n1)
    spkInd1 = n1(index);
    spkInd2 = n2(index);
    spk1ds = subset(ads,ads.Labels==spkInd1);
    spk2ds = subset(ads,ads.Labels==spkInd2);
    L = min(length(spk1ds.Files),length(spk2ds.Files));
    L = floor(L/miniBatchSize) * miniBatchSize;
    spk1ds = subset(spk1ds,1:L);
    spk2ds = subset(spk2ds,1:L);
    pairds = combine(spk1ds,spk2ds);
    writeall(pairds,newDatasetPath,'FolderLayout','flatten','WriteFcn',@(data,writeInfo,outputFmt)createTrainingFiles(data,writeInfo,outputFmt,reduceDataset));
end

% Create audio datastores pointing to the files corresponding to the individual speakers and the mixtures.
mixDatastore = audioDatastore(fullfile(newDatasetPath,'mix'));
speaker1Datastore = audioDatastore(fullfile(newDatasetPath,'sp1'));
speaker2Datastore = audioDatastore(fullfile(newDatasetPath,'sp2'));
end

function mix = createTrainingFiles(data,writeInfo,~,varargin)
% createTrainingFiles - Preprocess the training signals and write them to disk

reduceDataset = varargin{1};

duration = 4*8000;

x1 = data{1};
x2 = data{2};

% Resample from 16 kHz to 8 kHz
if ~reduceDataset
    x1 = resample(x1,1,2);
    x2 = resample(x2,1,2);
end

% Read a chunk from the first speaker signal
if length(x1)<=duration
    x1 = [x1;zeros(duration-length(x1),1)];
else
    startInd = randi([1 length(x1)-duration],1);
    endInd = startInd + duration - 1;
    x1 = x1(startInd:endInd);
end

% Read a chunk from the second speaker signal
if length(x2)<=duration
    x2 = [x2;zeros(duration-length(x2),1)];
else
    startInd = randi([1 length(x2)-duration],1);
    endInd = startInd + duration - 1;
    x2 = x2(startInd:endInd);
end

x1 = x1./max(abs(x1));
x2 = x2./max(abs(x2));

% SNR [-5 5] dB
s = snr(x1,x2);
targetSNR = 10 * (rand - 0.5);
x1b = 10^((targetSNR-s)/20) * x1;
mix = x1b + x2;
mix = mix./max(abs(mix));

if reduceDataset
    [~,n] = fileparts(tempname);
    name = sprintf('%s.wav',n);
else
    [~,s1] = fileparts(writeInfo.ReadInfo{1}.FileName);
    [~,s2] = fileparts(writeInfo.ReadInfo{2}.FileName);
    name = sprintf('%s-%s.wav',s1,s2);
end

audiowrite(sprintf('%s',fullfile(writeInfo.Location,'sp1',name)),x1,8000);
audiowrite(sprintf('%s',fullfile(writeInfo.Location,'sp2',name)),x2,8000);
audiowrite(sprintf('%s',fullfile(writeInfo.Location,'mix',name)),mix,8000);

end

function [grad, states] = modelGradients(mix,x1,x2,learnables,states,miniBatchSize)
% modelGradients Compute the model gradients

[y1,y2,states] = separateSpeakers(mix,learnables,states,true);

m = uPIT(x1,y1,x2,y2);
l = sum(m);
loss = -l./miniBatchSize;

grad = dlgradient(loss,learnables);

end

function m = uPIT(x1,y1,x2,y2)
% uPIT - Compute utterance-level permutation invariant training
v1 = SISNR(y1,x1);
v2 = SISNR(y2,x2);
m1 = mean([v1;v2]);

v1 = SISNR(y2,x1);
v2 = SISNR(y1,x2);
m2 = mean([v1;v2]);

m = max(m1,m2);
end

function z = SISNR(x,y)
% SISNR - Compute SI-SNR
x = x - mean(x);
y = y - mean(y);

t = sum(x.*y) .* y ./ (sum(y.^2)+eps);
n = x - t;

z = 20*log((sqrt(sum(t.^2))+eps)./sqrt((sum(n.^2))+eps))/log(10);

end

function [learnables,states] = initializeNetworkParams
% initializeNetworkParams - Initialize the learnables and states of the
% network
learnables.Conv1W = initializeGlorot(20,1,256);
learnables.Conv1B = dlarray(zeros(256,1,'single'));

learnables.ln_weight = dlarray(ones(1,256,'single'));
learnables.ln_bias = dlarray(zeros(1,256,'single'));

learnables.Conv2W = initializeGlorot(1,256,256);
learnables.Conv2B = dlarray(zeros(256,1,'single'));

for index=1:32
    blk = [];
    blk.Conv1W = initializeGlorot(1,256,512);
    blk.Conv1B = dlarray(zeros(512,1,'single'));
    blk.Prelu1 = dlarray(single(0.25));
    blk.BN1Offset = dlarray(zeros(512,1,'single'));
    blk.BN1Scale = dlarray(ones(512,1,'single'));
    blk.Conv2W = initializeGlorot(3,1,512);
    blk.Conv2W =  reshape(blk.Conv2W,[3 1 1 512]);
    blk.Conv2B = dlarray(zeros(512,1,'single'));
    blk.Prelu2 = dlarray(single(0.25));
    blk.BN2Offset= dlarray(zeros(512,1,'single'));
    blk.BN2Scale= dlarray(ones(512,1,'single'));
    blk.Conv3W = initializeGlorot(1,512,256);
    blk.Conv3B = dlarray(ones(256,1,'single'));

    learnables.Blocks(index) = blk;

    s = [];
    s.BN1Mean= dlarray(zeros(512,1,'single'));
    s.BN1Var= dlarray(ones(512,1,'single'));
    s.BN2Mean = dlarray(zeros(512,1,'single'));
    s.BN2Var = dlarray(ones(512,1,'single'));

    states(index) = s; %#ok
end

learnables.Conv3W = initializeGlorot(1,256,512);
learnables.Conv3B = dlarray(zeros(512,1,'single'));

learnables.TransConv1W = initializeGlorot(20,1,256);
learnables.TransConv1B = dlarray(zeros(1,1, 'single'));

end

function weights = initializeGlorot(filterSize,numChannels,numFilters)
% initializeGlorot - Perform Glorot initialization
sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

Z = 2*rand(sz,'single') - 1;
bound = sqrt(6 / (numIn + numOut));

weights = bound * Z;
weights = dlarray(weights);

end

function [output1, output2, states] = separateSpeakers(input, learnables, states, training)
% separateSpeakers - Separate two speaker signals from a mixture input
if ~isa(input,'dlarray')
    input = dlarray(input,'SCB');
end

weights = learnables.Conv1W;
bias = learnables.Conv1B;
x = dlconv(input, weights,bias, 'Stride', 10);

x = relu(x);
x0 = x;

x = x-mean(x, 2);
x = x./sqrt(mean(x.^2, 2) + 1e-5);
x = x.*learnables.ln_weight + learnables.ln_bias;

weights = learnables.Conv2W;
bias = learnables.Conv2B;
encoderOut = dlconv(x, weights, bias);

for index = 1:32
    [encoderOut,s] = convBlock(encoderOut, index-1,learnables.Blocks(index),states(index),training);
    states(index) = s;
end

weights = learnables.Conv3W;
bias = learnables.Conv3B;
masks = dlconv(encoderOut, weights, bias);
masks = relu(masks);

mask1 = masks(:,1:256,:);
mask2 = masks(:,257:512,:);

out1 = x0 .* mask1;
out2 = x0 .* mask2;

weights = learnables.TransConv1W;
bias = learnables.TransConv1B;
output2 = dltranspconv(out1, weights, bias, 'Stride', 10);
output1 = dltranspconv(out2, weights, bias, 'Stride', 10);

if ~training
    output1 = gather(extractdata(output1));
    output2 = gather(extractdata(output2));

    output1 = output1./max(abs(output1));
    output2 = output2./max(abs(output2));
end

end

function [output,state] = convBlock(input, count,learnables,state,training)

% Conv:
weights = learnables.Conv1W;
bias = learnables.Conv1B;
conv1Out = dlconv(input, weights, bias);

% PRelu:
conv1Out = relu(conv1Out) - learnables.Prelu1.*relu(-conv1Out);

% BatchNormalization:
offset = learnables.BN1Offset;
scale = learnables.BN1Scale;
datasetMean = state.BN1Mean;
datasetVariance = state.BN1Var;
if training
    [batchOut, dsmean, dsvar] = batchnorm(conv1Out, offset, scale, datasetMean, datasetVariance);
    state.BN1Mean = dsmean;
    state.BN1Var = dsvar;
else
    batchOut = batchnorm(conv1Out, offset, scale, datasetMean, datasetVariance);
end

% Conv:
weights = learnables.Conv2W;
bias = learnables.Conv2B;
padding = [1 1] * 2^(mod(count,8));
dilationFactor = 2^(mod(count,8));
convOut = dlconv(batchOut, weights, bias,'DilationFactor', dilationFactor, 'Padding', padding);

% PRelu:
convOut = relu(convOut) - learnables.Prelu2.*relu(-convOut);

% BatchNormalization:
offset = learnables.BN2Offset;
scale = learnables.BN2Scale;
datasetMean = state.BN2Mean;
datasetVariance = state.BN2Var;
if training
    [batchOut, dsmean, dsvar] = batchnorm(convOut, offset, scale, datasetMean, datasetVariance);
    state.BN2Mean = dsmean;
    state.BN2Var = dsvar;
else
    batchOut = batchnorm(convOut, offset, scale, datasetMean, datasetVariance);
end

% Conv:
weights = learnables.Conv3W;
bias = learnables.Conv3B;
output = dlconv(batchOut, weights, bias);

% Skip connection
output = output + input;

end

function [speaker1,speaker2] = separateSpeakersTimeFrequency(mix,pathToNet)
% separateSpeakersTimeFrequency - STFT-based speaker separation function
WindowLength  = 128;
FFTLength     = 128;
OverlapLength = 128-1;
win           = hann(WindowLength,"periodic");

% Downsample to 4 kHz
mix = resample(mix,1,2);

P0 = stft(mix, 'Window', win, 'OverlapLength', OverlapLength,...
    'FFTLength', FFTLength, 'FrequencyRange', 'onesided');
P = log(abs(P0) + eps);
MP = mean(P(:));
SP = std(P(:));
P = (P-MP)/SP;

seqLen = 20;
PSeq  = zeros(1 + FFTLength/2,seqLen,1,0);
seqOverlap = seqLen;

loc = 1;
while loc < size(P,2)-seqLen
    PSeq(:,:,:,end+1) = P(:,loc:loc+seqLen-1); %#ok
    loc = loc + seqOverlap;
end

PSeq  = reshape(PSeq, [1 1 (1 + FFTLength/2) * seqLen size(PSeq,4)]);

s = load(fullfile(pathToNet,"CocktailPartyNet.mat"));
CocktailPartyNet = s.CocktailPartyNet;
estimatedMasks = predict(CocktailPartyNet,PSeq);

estimatedMasks = estimatedMasks.';
estimatedMasks = reshape(estimatedMasks,1 + FFTLength/2,numel(estimatedMasks)/(1 + FFTLength/2));

mask1   = estimatedMasks; 
mask2 = 1 - mask1;

P0 = P0(:,1:size(mask1,2));

P_speaker1 = P0 .* mask1;

speaker1 = istft(P_speaker1, 'Window', win, 'OverlapLength', OverlapLength,...
    'FFTLength', FFTLength, 'ConjugateSymmetric', true,...
    'FrequencyRange', 'onesided');
speaker1 = speaker1 / max(abs(speaker1));

P_speaker2 = P0 .* mask2;

speaker2 = istft(P_speaker2, 'Window', win, 'OverlapLength', OverlapLength,...
    'FFTLength',FFTLength, 'ConjugateSymmetric',true,...
    'FrequencyRange', 'onesided');
speaker2 = speaker2 / max(speaker2);

speaker1 = resample(double(speaker1),2,1);
speaker2 = resample(double(speaker2),2,1);
end