Дереверберация речи с использованием Нейронных сетей для глубокого обучения

Этот пример показывает, как обучить U-Net полностью сверточную сеть (FCN) [1] для удаления речевых сигналов.

Введение

Реверберация происходит, когда речевой сигнал отражается от объектов в пространстве, заставляя множественные отражения накапливаться и в конечном счете приводит к ухудшению качества речи. Дереверберация является процессом уменьшения эффектов реверберации в речевом сигнале.

Дереверберат речевого сигнала с использованием предварительно обученной сети

Перед углублением в процесс обучения используйте предварительно обученную сеть для удаления речевого сигнала.

Загрузите предварительно обученную сеть. Эта сеть была обучена на 56-динамических версиях обучающих наборов данных. Пример проходит обучение на 28-динамической версии.

url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip';
downloadFolder = tempdir;
networkDataFolder = fullfile(downloadFolder,'derevernet');

if ~exist(networkDataFolder,'dir')
    disp('Downloading pretrained network ...')
    unzip(url,downloadFolder)
end
load(fullfile(networkDataFolder,'dereverbNet.mat'))

Прослушайте чистый речевой сигнал, дискретизированный с частотой дискретизации 16 кГц.

[cleanAudio,fs] = audioread('clean_speech_signal.wav');

sound(cleanAudio,fs)

Акустический путь может быть смоделирован с помощью комнатной импульсной характеристики. Можно смоделировать реверберацию путем свертки анехойного сигнала с комнатной импульсной характеристикой.

Загрузка и построение графика импульсной характеристики помещения.

[rirAudio,fsR] = audioread('room_impulse_response.wav');

tAxis = (1/fsR)*(0:numel(rirAudio)-1);

figure
plot(tAxis,rirAudio)
xlabel('Time (s)')
ylabel('Amplitude')
grid on

Сверните чистую речь с комнатной импульсной характеристикой, чтобы получить ревербрированную речь. Выравнивание длин и амплитуд ревербированных и чистых речевых сигналов.

revAudio = conv(cleanAudio,rirAudio);

revAudio = revAudio(1:numel(cleanAudio));
revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));

Слушайте речевой сигнал.

sound(revAudio,fs)

Вход предварительно обученной сети является логарифмическим коротковременным преобразованием Фурье (STFT) реверберативного звука. Сеть предсказывает логарифмическую величину STFT деревертированного входа. Чтобы оценить исходный аудиосигнал временной области, вы выполняете обратный STFT и принимаете фазу реверберативного аудиосигнала.

Используйте следующие параметры для вычисления STFT.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;

Использование stft для вычисления односторонней логарифмической величины STFT. Используйте одинарную точность при вычислении функций, чтобы лучше использовать память и ускорить обучение. Несмотря на то, что односторонний STFT дает 257 интервалов частоты, учитывайте только 256 интервалов и игнорируйте самый высокий интервал частоты.

revAudio = single(revAudio);    
audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                'FFTLength',params.FFTLength,'FrequencyRange','onesided'); 
Eps = realmin('single');
reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);

Извлеките фазу STFT.

phaseOriginal = angle(audioSTFT(1:end-1,:));

Каждый вход будет иметь размерности 256 на 256 (частотные интервалы по временным шагам). Разделите логарифмическую величину STFT на сегменты с 256 временными шагами.

params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
    [params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);

Масштабируйте сегментированные функции в области значений [-1,1]. Сохраните минимальное и максимальное значения, используемые для масштабирования для восстановления деревертированного сигнала.

minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments));
maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments));

featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ...
    reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);

Измените форму функций так, чтобы фрагменты находились вдоль четвертой размерности.

featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);

Спрогнозируйте спектры сигнала с помощью предварительно обученной сети.

predictedSTFT4D = predict(dereverbNet,featNorm);

Измените форму на 3 измерения и масштабируйте предсказанные STFT до исходной области значений с помощью сохраненных пар минимум-максимум.

predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))';
featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ...
    predictedSTFT,minVals,maxVals,'UniformOutput',false);

Сторнируйте логарифмическое масштабирование.

predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);

Конкатенация предсказанные сегменты STFT 256 величин 256, чтобы получить величину спектрограмму исходной длины.

predictedSTFTAll = predictedSTFT(1:chunks - 1);
predictedSTFTAll = cat(2,predictedSTFTAll{:});
predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};

Перед взятием обратного STFT добавьте нули к предсказанному логарифмическому спектру и фазе вместо самой высокой частоты интервала которая была исключена при подготовке входа признаков.

nCount = size(predictedSTFTAll,3);
predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount));
phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));

Используйте обратную функцию STFT для восстановления деревертированного речевого сигнала временной области, используя предсказанную логарифмическую величину STFT и фазу реверантного речевого сигнала.

oneSidedSTFT = predictedSTFTAll.*exp(1j*phase);
dereverbedAudio = istft(oneSidedSTFT, ...
    'Window',params.Window,'OverlapLength',params.OverlapLength, ...
    'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ...
    'FrequencyRange','onesided');

dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio]));
dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];

Прослушайте воспроизведенный аудиосигнал.

sound(dereverbedAudio,fs)

Постройте график чистых, реверберативных и дериверированных речевых сигналов.

t = (1/fs)*(0:numel(cleanAudio)-1);

figure

subplot(3,1,1)
plot(t,cleanAudio)
xlabel('Time (s)')
grid on
subtitle('Clean Speech Signal')

subplot(3,1,2)
plot(t,revAudio)
xlabel('Time (s)')
grid on
subtitle('Revereberated Speech Signal')

subplot(3,1,3)
plot(t,dereverbedAudio)
xlabel('Time (s)')
grid on
subtitle('Derevereberated Speech Signal')

Визуализируйте спектрограммы чистых, реверберативных и деревертированных речевых сигналов.

figure('Position',[100,100,800,800])

subplot(3,1,1)
spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');
subtitle('Clean')

subplot(3,1,2)
spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Reverberated')
 
subplot(3,1,3)
spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Predicted (Dereverberated)')

Загрузите набор данных

Этот пример использует речевую базу данных Reverberant [2] и соответствующую базу данных чистой речи [3], чтобы обучить сеть.

Загрузите чистый набор речевых данных.

url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip';
url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip';
downloadFolder = tempdir;
cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791');

if ~exist(cleanDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url1,cleanDataFolder)
    unzip(url2,cleanDataFolder)
end

Загрузите ревербированный набор данных.

url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip';
url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip';
downloadFolder = tempdir;
reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031');

if ~exist(reverbDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url3,reverbDataFolder)
    unzip(url4,reverbDataFolder)
end

Предварительная обработка данных и редукция данных

После загрузки данных предварительно обработайте загруженные данные и извлеките функции перед обучением модели DNN:

  1. Синтетически сгенерируйте ревербрантные данные с помощью reverberator объект

  2. Разделите каждый речевой сигнал на небольшие сегменты длительностью 2.072 с

  3. Сбросьте сегменты, которые содержат значительные тихие области

  4. Извлеките STFT с логарифмической амплитудой как предиктор и целевые функции

  5. Масштабирование и изменение формы функций

Во-первых, создайте две audioDatastore объекты, которые указывают на чистые и реверберативные наборы данных.

adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true);
adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);

Синтетическая генерация речевых данных

Величина реверберации в исходных данных относительно небольшая. Вы увеличите речевые данные со значительными эффектами реверберации, используя reverberator объект.

Создайте audioDatastore который указывает на чистый набор данных, выделенный для синтетической генерации реверберантных данных.

adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files));
adsCleanTrain = subset(adsCleanTrain,1:10e3);
adsReverbTrain = subset(adsReverbTrain,1:10e3);

Дискретизация от 48 кГц до 16 кГц.

adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3));
adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3));
adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));

Объедините два хранилища аудиоданных, поддерживая соответствие между чистой и ревербрантной выборками речи.

adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);

Функция applyReverb создает reverberator объект, обновляет параметры предварительной задержки, коэффициента затухания и влажно-сухого смешивания, как задано, и затем применяет реверберацию. Использование audioDataAugmenter создание синтетически сгенерированных реверберативных данных.

augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ...
    'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0);
algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ...
    applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate);

addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ...
    'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ...
    'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]})

augmenter.ReverbProbability = 1;
disp(augmenter)
  audioDataAugmenter with properties:

               AugmentationMode: 'independent'
    AugmentationParameterSource: 'random'
               NumAugmentations: 1
               ApplyTimeStretch: 0
                ApplyPitchShift: 0
             ApplyVolumeControl: 0
                  ApplyAddNoise: 0
                 ApplyTimeShift: 0
                    ApplyReverb: 1
                  PreDelayRange: [0.1500 0.2500]
               DecayFactorRange: [0.2000 0.5000]
                 WetDryMixRange: [0.3000 0.4500]
              SamplingRateRange: [16000 16000]

Создайте новую audioDatastore соответствует синтетически сгенерированным ревербрантным данным по вызову transform для применения увеличения данных.

adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));

Объедините два хранилища аудиоданных.

adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);

Далее, основываясь на размерностях входных функций в сеть, сегментируйте аудио в фрагменты длительностью 2.072 с с перекрытием 50%.

Иметь слишком много молчаливых сегментов может негативно повлиять на обучение модели DNN. Удалите сегменты, которые в основном являются молчаливыми (более 50% от длительности) и исключить из обучения модели. Не удаляйте полностью тишину, потому что модель не будет устойчива к тихим областям, и небольшие эффекты реверберации могут быть определены как тишина. detectSpeech может идентифицировать начальную и конечную точки тихих областей. После этих двух шагов процесс редукции данных может быть выполнен, как объяснено в первом разделе. helperFeatureExtract реализует эти шаги.

Задайте параметры редукции данных. Путем настройки reduceDataSet если true, вы выбираете небольшой подмножество наборов данных, чтобы выполнить последующие шаги.

reduceDataSet = true;
params.fs = 16000;
парамы. WindowdowLength = 512;
params.Window = hamming (params.WindowDowLength,"periodic");
params.OverlapLength = round (0,75 * params.WindowDowLength);
params.FFTLength = params. WindowDowLength;
samplesPerMs = params.fs/1000;
params.samplesPerImage = (24 + 256 * 8) * samplesPerMs;
params.shiftImage = params.samplesPerImage/2;
парамы. NumSegments = 256;
парамы. NumFeatures = 256
params = struct with fields:
    WindowdowLength: 512
             Window: [512×1 double]
      OverlapLength: 384
          FFTLength: 512
        NumSegments: 256
        NumFeatures: 256
                 fs: 16000
    samplesPerImage: 33152
         shiftImage: 16576

Чтобы ускорить обработку, распределите задачу предварительной обработки и редукции данных между несколькими работниками с помощью parfor.

Определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(adsCombinedTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Для каждого раздела считывайте из datastore, предварительно обрабатывайте аудиосигнал, а затем извлекайте функции.

if reduceDataSet
    adsCombinedTrain = shuffle(adsCombinedTrain); %#ok
    adsCombinedTrain = subset(adsCombinedTrain,1:200);
    
    adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain);
    adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200);
end

allCleanFeatures = cell(1,numPar);
allReverbFeatures = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedTrain,numPar,iPartition);
    combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition);
        
    cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    partitionSize = cPartitionSize + cSyntheticPartitionSize;
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    
    for idx = 1:partitionSize
        if idx <= cPartitionSize
            audios = read(combinedPartition);
        else
            audios = read(combinedSyntheticPartition);
        end
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params);
        cleanFeaturesPartition{idx} = featuresClean;
        reverbFeaturesPartition{idx} = featuresReverb;
    end
    allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
end

allCleanFeatures = cat(2,allCleanFeatures{:});
allReverbFeatures = cat(2,allReverbFeatures{:});

Нормализуйте извлеченные признаки в область значений [-1,1], а затем измените форму, как объяснено в первом разделе, с помощью функции featureNormalizeAndReshape.

trainClean = featureNormalizeAndReshape(allCleanFeatures);
trainReverb = featureNormalizeAndReshape(allReverbFeatures);

Теперь, когда вы извлекли функции STFT логарифмической величины из обучающих наборов данных, выполните ту же процедуру, чтобы извлечь функции из наборов данных валидации. В целях реконструкции сохраните фазу реверсивных выборок речи набора данных валидации. В сложение сохраните аудио данных как для чистых, так и для ревербрантных речевых выборок в наборе валидации, чтобы использовать в процессе оценки (следующий раздел).

adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true);
adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);

Дискретизация от 48 кГц до 16 кГц.

adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3));
adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3));

adsCombinedVal = combine(adsCleanVal,adsReverbVal); 
if reduceDataSet
    adsCombinedVal = shuffle(adsCombinedVal);%#ok
    adsCombinedVal = subset(adsCombinedVal,1:50);
end

allValCleanFeatures = cell(1,numPar);
allValReverbFeatures = cell(1,numPar);
allValReverbPhase = cell(1,numPar);
allValCleanAudios = cell(1,numPar);
allValReverbAudios = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedVal,numPar,iPartition);
    
    partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    reverbPhasePartition = cell(1,partitionSize); 
    cleanAudiosPartition = cell(1,partitionSize); 
    reverbAudiosPartition = cell(1,partitionSize);

    for idx = 1:partitionSize
        audios = read(combinedPartition);
        
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        
        [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params);
        
        cleanFeaturesPartition{idx} = a;
        reverbFeaturesPartition{idx} = b;  
        reverbPhasePartition{idx} = c;
        cleanAudiosPartition{idx} = d;
        reverbAudiosPartition{idx} = e;
    end
    allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
    allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:});
    allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:});
    allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:});
end

allValCleanFeatures = cat(2,allValCleanFeatures{:});
allValReverbFeatures = cat(2,allValReverbFeatures{:});
allValReverbPhase = cat(2,allValReverbPhase{:});
allValCleanAudios = cat(2,allValCleanAudios{:});
allValReverbAudios = cat(2,allValReverbAudios{:});

valClean = featureNormalizeAndReshape(allValCleanFeatures);

Сохраните минимальные и максимальные значения каждой функции набора реверберантной валидации. Эти значения будут использоваться в процессе реконструкции.

[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);

Определите архитектуру нейронной сети

Полностью сверточная сетевая архитектура U-Net была адаптирована для этой задачи речевой дереверберации, как предложено в [1]. «U-Net» является сетью энкодер-декодер с пропущенными соединениями. В модели U-Net каждый слой понижает свой вход (шаг 2), пока не будет достигнут слой узкого места (путь кодирования). В последующих слоях вход усиливается каждым слоем до тех пор, пока выход не вернется к исходной форме (пути декодирования). Чтобы минимизировать потерю низкоуровневой информации в процессе понижающей дискретизации, соединения выполняются между зеркальными слоями путем непосредственной конкатенации выходов соответствующих слоев (пропуск соединений).

Определите сетевую архитектуру и верните график слоев с соединениями.

params.WindowdowLength = 512;
params.FFTLength = params.WindowdowLength;
params.NumFeatures = params.FFTLength/2;
params.NumSegments = 256;
    
filterH = 6;
filterW = 6;
numChannels = 1;
nFilters = [64,128,256,512,512,512,512,512];

inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ...
    'Normalization','none','Name','input');
layers = inputLayer;

% U-Net squeezing path
layers = [layers;
    convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1));
    leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))];
        
for i = 2:8
    layers =  [layers;
        convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i));
        batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok
    if i ~= 8
        layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok
    else
        layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok
    end
end

% U-Net expanding path
for i = 7:-1:0
    nChannels = numChannels;
    if i > 0
        nChannels = nFilters(i);
    end
    layers = [layers;
        transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok
    if i > 0
        layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok
    end
    if i > 4
        layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok
    end
    if i > 0
        layers = [layers;
            reluLayer('Name',"de-relu"+string(i));
            concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok
    else
        layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok
    end
end

layers = [layers;regressionLayer('Name','output')];

unetLayerGraph = layerGraph(layers); 

% Define skip-connections
for i = 1:7
    unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2');
end

Использование analyzeNetwork для просмотра архитектуры модели. Это хороший способ визуализировать связи между слоями.

analyzeNetwork(unetLayerGraph); 

Обучите сеть

Вы будете использовать среднюю квадратичную невязку (MSE) между спектрами логарифмической амплитуды деревертированной речевой выборки (выход модели) и соответствующей чистой речевой выборкой (цель) в качестве функции потерь. Используйте adam оптимизатор и мини-пакет размером 128 для обучения. Разрешить модели обучаться в течение максимум 50 эпох. Если потеря валидации не улучшается в течение 5 последовательных эпох, завершите процесс обучения. Уменьшите скорость обучения в 10 раз каждые 15 эпох.

Определите опции обучения следующим образом. Измените окружение выполнения и возможность выполнения фоновой диспетчеризации в зависимости от доступности оборудования и наличия доступа к Parallel Computing Toolbox™.

initialLearnRate = 8e-4;
miniBatchSize = 64;

options = trainingOptions("adam", ...
        "MaxEpochs", 50, ...
        "InitialLearnRate",initialLearnRate, ...
        "MiniBatchSize",miniBatchSize, ...
        "Shuffle","every-epoch", ...
        "Plots","training-progress", ...
        "Verbose",false, ...
        "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ...
        "ValidationPatience",5, ...
        "LearnRateSchedule","piecewise", ...
        "LearnRateDropFactor",0.1, ... 
        "LearnRateDropPeriod",15, ...
        "ExecutionEnvironment","gpu", ...
        "DispatchInBackground",true, ...
        "ValidationData",{valReverb,valClean});

Обучите сеть.

dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);

Оценка эффективности сети

Предсказание и реконструкция

Спрогнозируйте логарифмические спектры набора валидации.

predictedSTFT4D = predict(dereverbNet,valReverb);

Используйте функцию helperReconstructPredictedAudios, чтобы восстановить предсказанную речь. Эта функция выполняет действия, описанные в первом разделе.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;

dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);

Визуализируйте логарифмическую величину STFT чистых, реверберативных и соответствующих деревертированных речевых сигналов.

figure('Position',[100,100,1024,1200])

subplot(3,1,1)
imagesc(squeeze(allValCleanFeatures{1}))    
set(gca,'Ydir','normal')
subtitle('Clean')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,2)
imagesc(squeeze(allValReverbFeatures{1}))
set(gca,'Ydir','normal')
subtitle('Reverberated')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,3)
imagesc(squeeze(predictedSTFT4D(:,:,:,1)))
set(gca,'Ydir','normal')
subtitle('Predicted (Dereverberated)')
xlabel('Time')
ylabel('Frequency')
caxis([-1,1])
colorbar

Метрики оценки

Для оценки эффективности сети используется подмножество объективных измерений, используемых в [1]. Эти метрики вычисляются на сигналах временной области.

  • Cepstrum distance (CD) - Обеспечивает оценку логарифмического спектрального расстояния между двумя спектрами (предсказанным и чистым). Меньшие значения указывают на лучшее качество.

  • Логарифмический коэффициент правдоподобия (LLR) - линейное прогнозирующее кодирование (LPC) на основе целевого измерения. Меньшие значения указывают на лучшее качество.

Вычислите эти измерения для реверантной речи и деревертированных речевых сигналов.

[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs);
[summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs);
disp(summaryMeasuresReconstructed)
       avgCdMean: 3.8386
     avgCdMedian: 3.3671
      avgLlrMean: 0.9152
    avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
       avgCdMean: 4.2591
     avgCdMedian: 3.6336
      avgLlrMean: 0.9726
    avgLlrMedian: 0.8714

Гистограммы иллюстрируют распределение среднего CD, среднего SRMR и среднего LLR реверберативных и деревертированных данных.

figure('position',[50,50,1100,1300])

subplot(2,1,1)
histogram(allMeasuresReverb.cdMean,10)
hold on
histogram(allMeasuresReconstructed.cdMean, 10)
subtitle('Mean Cepstral Distance Distribution')
ylabel('count')
xlabel('mean CD')
legend('Reverberant (Original)','Dereverberated (Predicted)')

subplot(2,1,2)
histogram(allMeasuresReverb.llrMean,10)
hold on
histogram(allMeasuresReconstructed.llrMean,10)
subtitle('Mean Log Likelihood Ratio Distribution')
ylabel('Count')
xlabel('Mean LLR')
legend('Reverberant (Original)','Dereverberated (Predicted)')

Ссылки

[1] Ernst, O., Chazan, S.E., Gannot, S., & Goldberger, J. (2018). Речевая дереверберация с использованием полностью сверточных сетей. 2018 26-я Европейская конференция по обработке сигналов (EUSIPCO), 390-394.

[2] https://datashare.is.ed.ac.uk/handle/10283/2031

[3] https://datashare.is.ed.ac.uk/handle/10283/2791

[4] https://github.com/MuSAELab/SRMRToolbox

Вспомогательные функции

Применить реверберацию

function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs)
% This function generates reverberant speech data using the reverberator
% object
%
% inputs: 
% y                                - clean speech sample
% preDelay, decayFactor, wetDryMix - reverberation parameters
% fs                               - sampling rate of y
%
% outputs:
% yOut - corresponding reveberated speech sample

    revObj = reverberator('SampleRate',fs, ...
        'DecayFactor',decayFactor, ...
        'WetDryMix',wetDryMix, ...
        'PreDelay',preDelay);
    yOut = revObj(y);
    yOut = yOut(1:length(y),1);
end

Извлечение пакета функций

function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ...
    = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params)
% This function performs the preprocessing and features extraction task on 
% the audio files used for dereverberation model training and testing.
%
% inputs:
% cleanAudio  - the clean audio file (reference)
% reverbAudio - corresponding reverberant speech file
% isVal       - Boolean flag indicating if it is the validation set
% params      - a structure containing feature extraction parameters
%
% outputs:
% featuresClean  - log-magnitude STFT features of clean audio
% featuresReverb - log-magnitude STFT features of reverberant audio
% phaseReverb    - phase of STFT of reverberant audio
% cleanAudios    - 2.072s-segments of clean audio file used for feature extraction
% reverbAudios   - 2.072s-segments of corresponding reverberant audio

    assert(length(cleanAudio) == length(reverbAudio));
    nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage);

    featuresClean = {};
    featuresReverb = {};
    phaseReverb = {};
    cleanAudios = {};
    reverbAudios = {};
    nGood = 0;
    nonSilentRegions = detectSpeech(reverbAudio, params.fs);
    nonSilentRegionIdx = 1;
    totalRegions = size(nonSilentRegions, 1);
    for cid = 1:nSegments
        start = (cid - 1)*params.shiftImage + 1;
        en = start + params.samplesPerImage - 1;

        nonSilentSamples = 0;
        while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start
            nonSilentRegionIdx = nonSilentRegionIdx + 1;
        end
        
        nonSilentStart = nonSilentRegionIdx;
        while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en
            nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1;
            nonSilentSamples = nonSilentSamples + nonSilentDuration; 
            nonSilentStart = nonSilentStart + 1;
        end
        
        nonSilentPerc = nonSilentSamples * 100 / (en - start + 1);
        silent = nonSilentPerc < 50;
        
        reverbAudioSegment = reverbAudio(start:en);
        if ~silent
            nGood = nGood + 1;
            cleanAudioSegment = cleanAudio(start:en);
            assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking')
            
            % Clean Audio
            [featsUnit, ~] = featureExtract(cleanAudioSegment, params);
            featuresClean{nGood} = featsUnit; %#ok

            % Reverb Audio
            [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params);
            featuresReverb{nGood} = featsUnit; %#ok
            if isVal
                phaseReverb{nGood} = phaseUnit; %#ok
                reverbAudios{nGood} = reverbAudioSegment;%#ok
                cleanAudios{nGood} = cleanAudioSegment;%#ok
            end
        end
    end
end

Извлечение функций

function [features, phase, lastFBin] = featureExtract(audio, params)
% Function to extract features for a speech file
    audio = single(audio);
    
    audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                    'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided');
    
    phase = single(angle(audioSTFT(1:end-1,:)));     
    features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); 
    lastFBin = audioSTFT(end,:);

end

Нормализация и изменение формы функций

function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats)
% function to normalize features - range [-1, 1] and reshape to 4
% dimensions
%
% inputs:
% feats - 3-dimensional array of extracted features
%
% outputs:
% featNorm - normalized and reshaped features
% minMaxPairs - array of original min and max pairs used for normalization

    nSamples = length(feats);
    minMaxPairs = zeros(nSamples,2,'single');
    featNorm = zeros([size(feats{1}),nSamples],'single');
    parfor i = 1:nSamples
        feat = feats{i};
        maxFeat = max(feat,[],'all');
        minFeat = min(feat,[],'all');
        featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1;
        minMaxPairs(i,:) = [minFeat,maxFeat];
    end
    featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3));
end

Восстановление предсказанного аудио

function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params)
% This function will reconstruct the 2.072s long audios predicted by the 
% model using the predicted log-magnitude spectrogram and the phase of the 
% reverberant audio file
%
% inputs:
% predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features
% minMaxPairs     - Original minimum/maximum value pairs used in normalization
% reverbPhase     - Array of phases of STFT of reverberant audio files
% reverbAudios    - 2.072s-segments of corresponding reverberant audios
% params          - Structure containing feature extraction parameters

    predictedSTFT = squeeze(predictedSTFT4D);
    denormalizedFeatures = zeros(size(predictedSTFT),'single');
    for i = 1:size(predictedSTFT,3)
        feat = predictedSTFT(:,:,i);
        maxFeat = minMaxPairs(i,2);
        minFeat = minMaxPairs(i,1);
        denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat;        
    end
    
    predictedSTFT = exp(denormalizedFeatures);
    
    nCount = size(predictedSTFT,3);
    dereverbedAudioAll = cell(1,nCount);

    nSeg = params.NumSegments;
    win = params.Window;
    ovrlp = params.OverlapLength;
    FFTLength = params.FFTLength;
    parfor ii = 1:nCount
        % Append zeros to the highest frequency bin
        stftUnit = predictedSTFT(:,:,ii);
        stftUnit = cat(1,stftUnit, zeros(1,nSeg)); 
        phase = reverbPhase{ii};
        phase = cat(1,phase,zeros(1,nSeg));

        oneSidedSTFT = stftUnit.*exp(1j*phase);
        dereverbedAudio= istft(oneSidedSTFT, ...
            'Window', win,'OverlapLength', ovrlp, ...
                                    'FFTLength',FFTLength,'ConjugateSymmetric',true,...
                                    'FrequencyRange','onesided');

        dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii})));
    end
end

Вычисление целевых измерений

function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs)
% This function computes the objective measures on time-domain signals.
%
% inputs:
% reconstructedAudios - An array of audio files to evaluate.
% cleanAudios - An array of reference audio files
% fs - Sampling rate of audio files
%
% outputs:
% summaryMeasures - Global means of CD, LLR individual mean and median values
% allMeasures - Individual mean and median values

    nAudios = length(reconstructedAudios);
    cdMean = zeros(nAudios,1);
    cdMedian = zeros(nAudios,1);
    llrMean = zeros(nAudios,1);
    llrMedian = zeros(nAudios,1);

    parfor k = 1 : nAudios
      y = reconstructedAudios{k};
      x = cleanAudios{k};

      y = y./max(abs(y));
      x = x./max(abs(x));

      [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs);
      [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs);
    end
    
    summaryMeasures.avgCdMean = mean(cdMean);
    summaryMeasures.avgCdMedian = mean(cdMedian);
    summaryMeasures.avgLlrMean = mean(llrMean);
    summaryMeasures.avgLlrMedian = mean(llrMedian);   
    
    allMeasures.cdMean = cdMean;
    allMeasures.llrMean = llrMean;
end

Кепстральное расстояние

function [meanVal, medianVal] = cepstralDistance(x,y,fs)
    x = x / sqrt(sum(x.^2));
    y = y / sqrt(sum(y.^2));

    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex).*win;
    yFrames = y(winIndex).*win;

    xCeps = cepstralReal(xFrames,width);
    yCeps = cepstralReal(yFrames,width);

    dist = (xCeps - yCeps).^2;
    cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:));
    cepsD = max(min(cepsD,10),0);

    meanVal = mean(cepsD);
    medianVal = median(cepsD);
end

Реальный Cepstrum

function realC = cepstralReal(x, width)
    width2p = 2 ^ nextpow2(width);
    powX = abs(fft(x, width2p));

    lowCutoff = max(powX(:)) * 10^-5;
    powX  = max(powX, lowCutoff);

    realC = real(ifft(log(powX)));
    order = 24;
    realC = realC(1 : order + 1, :);
    realC = realC - mean(realC, 2);
end

Коэффициент логарифмической правдоподобности КНД

function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs)
    order = 12;
    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames  = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex) .* win;
    yFrames = y(winIndex) .* win;

    lpcX = realLpc(xFrames, width, order);
    [lpcY,realY] = realLpc(yFrames, width, order);

    llr = zeros(nFrames, 1);
    for n = 1 : nFrames
      R = toeplitz(realY(1:order+1,n));
      num = lpcX(:,n)'*R*lpcX(:,n);
      den = lpcY(:,n)'*R*lpcY(:,n);  
      llr(n) = log(num/den);
    end

    llr = sort(llr);
    llr = llr(1:ceil(nFrames*0.95));
    llr = max(min(llr,2),0);

    meanLlr = mean(llr);
    medianLlr = median(llr);
end

Действительные линейные коэффициенты точности

function [lpcCoeffs, realX] = realLpc(xFrames, width, order)
    width2p = 2 ^ nextpow2(width);
    X = fft(xFrames, width2p);

    Rx = ifft(abs(X).^2);
    Rx = Rx./width; 
    realX = real(Rx);

    lpcX = levinson(realX, order);
    lpcCoeffs = real(lpcX');
end