В этом примере показано, как обучить полностью сверточную сеть U-Net (FCN) [1] отключать речевые сигналы.
Реверберация происходит, когда речевой сигнал отражается от объектов в пространстве, вызывая нарастание множества отражений и в конечном итоге приводит к ухудшению качества речи. Дереверберация - это процесс уменьшения эффектов реверберации в речевом сигнале.
Перед подробным началом процесса обучения используйте предварительно обученную сеть для удаления речевого сигнала.
Загрузите предварительно обученную сеть. Эта сеть была обучена на 56-спикерских версиях учебных наборов данных. Пример состоит в обучении 28-громкоговорящей версии.
url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip'; downloadFolder = tempdir; networkDataFolder = fullfile(downloadFolder,'derevernet'); if ~exist(networkDataFolder,'dir') disp('Downloading pretrained network ...') unzip(url,downloadFolder) end load(fullfile(networkDataFolder,'dereverbNet.mat'))
Прослушивание чистого речевого сигнала, дискретизированного на частоте 16 кГц.
[cleanAudio,fs] = audioread('clean_speech_signal.wav');
sound(cleanAudio,fs)Акустический путь может быть смоделирован с использованием импульсной характеристики помещения. Можно моделировать реверберацию, свернув анехойный сигнал с импульсной характеристикой помещения.
Загрузите и постройте график импульсной характеристики помещения.
[rirAudio,fsR] = audioread('room_impulse_response.wav'); tAxis = (1/fsR)*(0:numel(rirAudio)-1); figure plot(tAxis,rirAudio) xlabel('Time (s)') ylabel('Amplitude') grid on

Сверните чистую речь с импульсной характеристикой комнаты для получения ревербрированной речи. Выровняйте длины и амплитуды реверберных и чистых речевых сигналов.
revAudio = conv(cleanAudio,rirAudio); revAudio = revAudio(1:numel(cleanAudio)); revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));
Прослушать реверберный речевой сигнал.
sound(revAudio,fs)
Вход в предварительно обученную сеть представляет собой логарифмическое кратковременное преобразование Фурье (STFT) реверсивного звука. Сеть прогнозирует логарифмическую величину STFT деинвертированного входного сигнала. Для оценки исходного звукового сигнала во временной области выполняется обратный STFT и принимается фаза реверберативного звука.
Используйте следующие параметры для вычисления STFT.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;Использовать stft для вычисления односторонней логарифмической величины STFT. Используйте одну точность при вычислении функций, чтобы лучше использовать память и ускорить обучение. Несмотря на то, что односторонний STFT дает 257 частотные ячейки, следует учитывать только 256 ячейки и игнорировать самую высокую частотную ячейку.
revAudio = single(revAudio); audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'FrequencyRange','onesided'); Eps = realmin('single'); reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);
Извлеките фазу STFT.
phaseOriginal = angle(audioSTFT(1:end-1,:));
Каждый вход будет иметь размеры 256 на 256 (частотные ячейки по временным шагам). Разбейте логарифмическую величину STFT на сегменты 256 временных шагов.
params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
[params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);Масштабируйте сегментированные элементы до диапазона [-1,1]. Сохранение минимального и максимального значений, используемых для масштабирования для восстановления сигнала с пониженным уровнем мощности.
minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments)); maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments)); featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ... reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);
Измените форму элементов так, чтобы порции располагались вдоль четвертого размера.
featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);
Предсказать спектры логарифмической величины реверберного речевого сигнала с использованием предварительно обученной сети.
predictedSTFT4D = predict(dereverbNet,featNorm);
Измените форму на 3 размера и масштабируйте прогнозируемые STFT до исходного диапазона с помощью сохраненных пар «минимум-максимум».
predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))'; featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ... predictedSTFT,minVals,maxVals,'UniformOutput',false);
Сторнируйте масштабирование журнала.
predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);Соединяют предсказанные STFT-сегменты 256 на 256 величины для получения спектрограммы величины исходной длины.
predictedSTFTAll = predictedSTFT(1:chunks - 1);
predictedSTFTAll = cat(2,predictedSTFTAll{:});
predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};Перед принятием обратного STFT добавьте нули к прогнозируемому спектру логарифмических величин и фазе вместо бина с самой высокой частотой, который был исключен при подготовке входных признаков.
nCount = size(predictedSTFTAll,3); predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount)); phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));
Используйте обратную функцию STFT для восстановления уменьшенного речевого сигнала во временной области с использованием предсказанной логарифмической величины STFT и фазы реверберного речевого сигнала.
oneSidedSTFT = predictedSTFTAll.*exp(1j*phase); dereverbedAudio = istft(oneSidedSTFT, ... 'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ... 'FrequencyRange','onesided'); dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio])); dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];
Прослушать аудиосигнал с пониженной скоростью.
sound(dereverbedAudio,fs)
Постройте график чистых, реверберативных и очищенных речевых сигналов.
t = (1/fs)*(0:numel(cleanAudio)-1); figure subplot(3,1,1) plot(t,cleanAudio) xlabel('Time (s)') grid on subtitle('Clean Speech Signal') subplot(3,1,2) plot(t,revAudio) xlabel('Time (s)') grid on subtitle('Revereberated Speech Signal') subplot(3,1,3) plot(t,dereverbedAudio) xlabel('Time (s)') grid on subtitle('Derevereberated Speech Signal')

Визуализируйте спектрограммы чистых, реверберативных и подавленных речевых сигналов.
figure('Position',[100,100,800,800]) subplot(3,1,1) spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Clean') subplot(3,1,2) spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Reverberated') subplot(3,1,3) spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Predicted (Dereverberated)')

В этом примере для обучения сети используется база данных реверберной речи [2] и соответствующая база данных чистой речи [3].
Загрузите набор чистых речевых данных.
url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip'; url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip'; downloadFolder = tempdir; cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791'); if ~exist(cleanDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url1,cleanDataFolder) unzip(url2,cleanDataFolder) end
Загрузите реверсированный набор речевых данных.
url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip'; url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip'; downloadFolder = tempdir; reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031'); if ~exist(reverbDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url3,reverbDataFolder) unzip(url4,reverbDataFolder) end
После загрузки данных выполните предварительную обработку загруженных данных и извлеките функции перед обучением модели DNN:
Синтетически генерировать реверберационные данные с помощью reverberator объект
Разбить каждый речевой сигнал на небольшие сегменты длительностью 2,072с
Отбрасывать сегменты, содержащие значительные тихие области
Извлечение STFT логарифмической величины в качестве предиктора и целевых функций
Масштабирование и изменение формы элементов
Сначала создайте два audioDatastore объекты, указывающие на чистые и реверберные наборы речевых данных.
adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true); adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);
Величина реверберации в исходных данных относительно мала. Вы будете увеличивать реверберационные речевые данные со значительными эффектами реверберации с помощью reverberator объект.
Создание audioDatastore это указывает на набор данных чистой речи, выделенный для генерации синтетических ревербераторных данных.
adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files)); adsCleanTrain = subset(adsCleanTrain,1:10e3); adsReverbTrain = subset(adsReverbTrain,1:10e3);
Повторная выборка от 48 кГц до 16 кГц.
adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3)); adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3)); adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));
Объедините два хранилища аудиоданных, сохранив соответствие между чистыми и реверберными образцами речи.
adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);
Функция applyReverb создает reverberator объект обновляет параметры предварительной задержки, коэффициента затухания и влажно-сухой смеси, как указано, а затем применяет реверберацию. Использовать audioDataAugmenter для создания синтетически сгенерированных реверберных данных.
augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ... 'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0); algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ... applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate); addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ... 'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ... 'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]}) augmenter.ReverbProbability = 1; disp(augmenter)
audioDataAugmenter with properties:
AugmentationMode: 'independent'
AugmentationParameterSource: 'random'
NumAugmentations: 1
ApplyTimeStretch: 0
ApplyPitchShift: 0
ApplyVolumeControl: 0
ApplyAddNoise: 0
ApplyTimeShift: 0
ApplyReverb: 1
PreDelayRange: [0.1500 0.2500]
DecayFactorRange: [0.2000 0.5000]
WetDryMixRange: [0.3000 0.4500]
SamplingRateRange: [16000 16000]
Создание нового audioDatastore соответствующие синтетически сгенерированным реверберативным данным путем вызова transform для применения увеличения данных.
adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));Объединение двух хранилищ аудиоданных.
adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);
Далее, исходя из размеров входных функций в сеть, сегментируйте звук на порции длительностью 2,072 с с перекрытием 50%.
Наличие слишком большого количества бесшумных сегментов может отрицательно повлиять на обучение модели DNN. Удалите сегменты, которые в основном являются бесшумными (более 50% продолжительности), и исключите их из обучения модели. Не следует полностью удалять молчание, поскольку модель не будет устойчива к тихим областям, и небольшие эффекты реверберации могут быть идентифицированы как молчание. detectSpeech может идентифицировать начальную и конечную точки областей молчания. После этих двух стадий процесс извлечения признаков может быть выполнен, как описано в первом разделе. helperFeatureExtract реализует эти шаги.
Определите параметры извлечения элемента. По настройке reduceDataSet true, для выполнения последующих шагов выбирается небольшое подмножество наборов данных.
reduceDataSet =true; params.fs = 16000; params.WindowdowLength = 512; params.Window = hamming(params.WindowdowLength,"periodic"); params.OverlapLength = round(0.75*params.WindowdowLength); params.FFTLength = params.WindowdowLength; samplesPerMs = params.fs/1000; params.samplesPerImage = (24+256*8)*samplesPerMs; params.shiftImage = params.samplesPerImage/2; params.NumSegments = 256; params.NumFeatures = 256
params = struct with fields:
WindowdowLength: 512
Window: [512×1 double]
OverlapLength: 384
FFTLength: 512
NumSegments: 256
NumFeatures: 256
fs: 16000
samplesPerImage: 33152
shiftImage: 16576
Чтобы ускорить обработку, распределите задачу предварительной обработки и извлечения компонентов между несколькими работниками с помощью parfor.
Определите количество разделов для набора данных. Если у вас нет Toolbox™ Parallel Computing, используйте один раздел.
if ~isempty(ver('parallel')) pool = gcp; numPar = numpartitions(adsCombinedTrain,pool); else numPar = 1; end
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
Для каждого раздела считывайте из хранилища данных, предварительно обрабатывайте аудиосигнал и извлекайте функции.
if reduceDataSet adsCombinedTrain = shuffle(adsCombinedTrain); %#ok adsCombinedTrain = subset(adsCombinedTrain,1:200); adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain); adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200); end allCleanFeatures = cell(1,numPar); allReverbFeatures = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedTrain,numPar,iPartition); combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition); cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); partitionSize = cPartitionSize + cSyntheticPartitionSize; cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); for idx = 1:partitionSize if idx <= cPartitionSize audios = read(combinedPartition); else audios = read(combinedSyntheticPartition); end cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params); cleanFeaturesPartition{idx} = featuresClean; reverbFeaturesPartition{idx} = featuresReverb; end allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); end allCleanFeatures = cat(2,allCleanFeatures{:}); allReverbFeatures = cat(2,allReverbFeatures{:});
Нормализуйте извлеченные элементы в диапазоне [-1,1], а затем измените форму, как описано в первом разделе, с помощью функции FeatureNormalityAndReshape.
trainClean = featureNormalizeAndReshape(allCleanFeatures); trainReverb = featureNormalizeAndReshape(allReverbFeatures);
Теперь, когда вы извлекли элементы STFT логарифмической величины из учебных наборов данных, выполните ту же процедуру, чтобы извлечь элементы из наборов данных проверки. Для целей реконструкции сохраните фазу реверберативных речевых выборок набора данных проверки. Кроме того, сохранить аудиоданные как для чистых, так и для реверберных речевых выборок в наборе проверки достоверности, который будет использоваться в процессе оценки (следующий раздел).
adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true); adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);
Повторная выборка от 48 кГц до 16 кГц.
adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3)); adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3)); adsCombinedVal = combine(adsCleanVal,adsReverbVal);
if reduceDataSet adsCombinedVal = shuffle(adsCombinedVal);%#ok adsCombinedVal = subset(adsCombinedVal,1:50); end allValCleanFeatures = cell(1,numPar); allValReverbFeatures = cell(1,numPar); allValReverbPhase = cell(1,numPar); allValCleanAudios = cell(1,numPar); allValReverbAudios = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedVal,numPar,iPartition); partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); reverbPhasePartition = cell(1,partitionSize); cleanAudiosPartition = cell(1,partitionSize); reverbAudiosPartition = cell(1,partitionSize); for idx = 1:partitionSize audios = read(combinedPartition); cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params); cleanFeaturesPartition{idx} = a; reverbFeaturesPartition{idx} = b; reverbPhasePartition{idx} = c; cleanAudiosPartition{idx} = d; reverbAudiosPartition{idx} = e; end allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:}); allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:}); allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:}); end allValCleanFeatures = cat(2,allValCleanFeatures{:}); allValReverbFeatures = cat(2,allValReverbFeatures{:}); allValReverbPhase = cat(2,allValReverbPhase{:}); allValCleanAudios = cat(2,allValCleanAudios{:}); allValReverbAudios = cat(2,allValReverbAudios{:}); valClean = featureNormalizeAndReshape(allValCleanFeatures);
Сохраните минимальное и максимальное значения каждого элемента набора проверки ревербератора. Эти значения будут использоваться в процессе реконструкции.
[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);
Полностью сверточная сетевая архитектура, названная U-Net, была адаптирована для этой задачи дееверберации речи, как предложено в [1]. «U-Net» - это сеть кодера-декодера с пропущенными соединениями. В модели U-Net каждый уровень понижает свой входной сигнал (шаг 2) до тех пор, пока не будет достигнут узкий уровень (путь кодирования). В последующих слоях входной сигнал усиливается каждым уровнем до тех пор, пока выходной сигнал не вернется в исходную форму (путь декодирования). Чтобы минимизировать потерю низкоуровневой информации во время процесса понижающей дискретизации, соединения между зеркально отраженными слоями создаются посредством прямого конкатенации выходов соответствующих уровней (соединения пропуска).
Определите архитектуру сети и верните график уровней с подключениями.
params.WindowdowLength = 512;
params.FFTLength = params.WindowdowLength;
params.NumFeatures = params.FFTLength/2;
params.NumSegments = 256;
filterH = 6;
filterW = 6;
numChannels = 1;
nFilters = [64,128,256,512,512,512,512,512];
inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ...
'Normalization','none','Name','input');
layers = inputLayer;
% U-Net squeezing path
layers = [layers;
convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1));
leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))];
for i = 2:8
layers = [layers;
convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i));
batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok
if i ~= 8
layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok
else
layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok
end
end
% U-Net expanding path
for i = 7:-1:0
nChannels = numChannels;
if i > 0
nChannels = nFilters(i);
end
layers = [layers;
transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok
if i > 0
layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok
end
if i > 4
layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok
end
if i > 0
layers = [layers;
reluLayer('Name',"de-relu"+string(i));
concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok
else
layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok
end
end
layers = [layers;regressionLayer('Name','output')];
unetLayerGraph = layerGraph(layers);
% Define skip-connections
for i = 1:7
unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2');
endИспользовать analyzeNetwork для просмотра архитектуры модели. Это хороший способ визуализации соединений между слоями.
analyzeNetwork(unetLayerGraph);
В качестве функции потерь будет использоваться среднеквадратичная ошибка (MSE) между спектрами логарифмических величин выборок речи (выходных данных модели) и соответствующей выборок чистой речи (цели). Используйте adam оптимизатор и размер мини-партии 128 для обучения. Разрешите модели тренироваться не более 50 эпох. Если потеря проверки не улучшается в течение 5 последовательных периодов, завершите процесс обучения. Снижайте уровень обучения в 10 раз каждые 15 эпох.
Определите следующие варианты обучения. Измените среду выполнения и укажите, следует ли выполнять фоновую диспетчеризацию в зависимости от доступности оборудования и наличия доступа к Parallel Computing Toolbox™.
initialLearnRate = 8e-4; miniBatchSize = 64; options = trainingOptions("adam", ... "MaxEpochs", 50, ... "InitialLearnRate",initialLearnRate, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Plots","training-progress", ... "Verbose",false, ... "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ... "ValidationPatience",5, ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",15, ... "ExecutionEnvironment","gpu", ... "DispatchInBackground",true, ... "ValidationData",{valReverb,valClean});
Обучение сети.
dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);

Спрогнозировать спектры логарифмической величины проверочного набора.
predictedSTFT4D = predict(dereverbNet,valReverb);
Используйте функцию helperReconstreadtAudios для восстановления предсказанной речи. Эта функция выполняет действия, описанные в первом разделе.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;
dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);Визуализация логарифмических сигналов STFT чистого, реверберативного и соответствующих речевых сигналов с пониженным уровнем сигнала.
figure('Position',[100,100,1024,1200]) subplot(3,1,1) imagesc(squeeze(allValCleanFeatures{1})) set(gca,'Ydir','normal') subtitle('Clean') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,2) imagesc(squeeze(allValReverbFeatures{1})) set(gca,'Ydir','normal') subtitle('Reverberated') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,3) imagesc(squeeze(predictedSTFT4D(:,:,:,1))) set(gca,'Ydir','normal') subtitle('Predicted (Dereverberated)') xlabel('Time') ylabel('Frequency') caxis([-1,1]) colorbar

Для оценки производительности сети используется подмножество объективных показателей, используемых в [1]. Эти метрики вычисляются на сигналах временной области.
Cepstrum distance (CD) - предоставляет оценку логарифмического спектрального расстояния между двумя спектрами (прогнозируемым и чистым). Меньшие значения указывают на лучшее качество.
Логарифмическое отношение правдоподобия (LLR) - линейное прогнозирующее кодирование (LPC) на основе объективного измерения. Меньшие значения указывают на лучшее качество.
Вычисляют эти измерения для реверберативной речи и исказившихся речевых сигналов.
[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs); [summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs); disp(summaryMeasuresReconstructed)
avgCdMean: 3.8386
avgCdMedian: 3.3671
avgLlrMean: 0.9152
avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
avgCdMean: 4.2591
avgCdMedian: 3.6336
avgLlrMean: 0.9726
avgLlrMedian: 0.8714
Гистограммы иллюстрируют распределение среднего CD, среднего SRMR и среднего LLR реверберативных и дееверберизованных данных.
figure('position',[50,50,1100,1300]) subplot(2,1,1) histogram(allMeasuresReverb.cdMean,10) hold on histogram(allMeasuresReconstructed.cdMean, 10) subtitle('Mean Cepstral Distance Distribution') ylabel('count') xlabel('mean CD') legend('Reverberant (Original)','Dereverberated (Predicted)') subplot(2,1,2) histogram(allMeasuresReverb.llrMean,10) hold on histogram(allMeasuresReconstructed.llrMean,10) subtitle('Mean Log Likelihood Ratio Distribution') ylabel('Count') xlabel('Mean LLR') legend('Reverberant (Original)','Dereverberated (Predicted)')

[1] Эрнст, О., Шазан, С. Э., Ганнот, С., и Голдбергер, Дж. (2018). Дееверберация речи с использованием полностью сверточных сетей. 26-я Европейская конференция по обработке сигналов (EUSIPCO), 390-394.
[2] https://datashare.is.ed.ac.uk/handle/10283/2031
[3] https://datashare.is.ed.ac.uk/handle/10283/2791
[4] https://github.com/MuSAELab/SRMRToolbox
function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs) % This function generates reverberant speech data using the reverberator % object % % inputs: % y - clean speech sample % preDelay, decayFactor, wetDryMix - reverberation parameters % fs - sampling rate of y % % outputs: % yOut - corresponding reveberated speech sample revObj = reverberator('SampleRate',fs, ... 'DecayFactor',decayFactor, ... 'WetDryMix',wetDryMix, ... 'PreDelay',preDelay); yOut = revObj(y); yOut = yOut(1:length(y),1); end
function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ... = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params) % This function performs the preprocessing and features extraction task on % the audio files used for dereverberation model training and testing. % % inputs: % cleanAudio - the clean audio file (reference) % reverbAudio - corresponding reverberant speech file % isVal - Boolean flag indicating if it is the validation set % params - a structure containing feature extraction parameters % % outputs: % featuresClean - log-magnitude STFT features of clean audio % featuresReverb - log-magnitude STFT features of reverberant audio % phaseReverb - phase of STFT of reverberant audio % cleanAudios - 2.072s-segments of clean audio file used for feature extraction % reverbAudios - 2.072s-segments of corresponding reverberant audio assert(length(cleanAudio) == length(reverbAudio)); nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage); featuresClean = {}; featuresReverb = {}; phaseReverb = {}; cleanAudios = {}; reverbAudios = {}; nGood = 0; nonSilentRegions = detectSpeech(reverbAudio, params.fs); nonSilentRegionIdx = 1; totalRegions = size(nonSilentRegions, 1); for cid = 1:nSegments start = (cid - 1)*params.shiftImage + 1; en = start + params.samplesPerImage - 1; nonSilentSamples = 0; while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start nonSilentRegionIdx = nonSilentRegionIdx + 1; end nonSilentStart = nonSilentRegionIdx; while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1; nonSilentSamples = nonSilentSamples + nonSilentDuration; nonSilentStart = nonSilentStart + 1; end nonSilentPerc = nonSilentSamples * 100 / (en - start + 1); silent = nonSilentPerc < 50; reverbAudioSegment = reverbAudio(start:en); if ~silent nGood = nGood + 1; cleanAudioSegment = cleanAudio(start:en); assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking') % Clean Audio [featsUnit, ~] = featureExtract(cleanAudioSegment, params); featuresClean{nGood} = featsUnit; %#ok % Reverb Audio [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params); featuresReverb{nGood} = featsUnit; %#ok if isVal phaseReverb{nGood} = phaseUnit; %#ok reverbAudios{nGood} = reverbAudioSegment;%#ok cleanAudios{nGood} = cleanAudioSegment;%#ok end end end end
function [features, phase, lastFBin] = featureExtract(audio, params) % Function to extract features for a speech file audio = single(audio); audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided'); phase = single(angle(audioSTFT(1:end-1,:))); features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); lastFBin = audioSTFT(end,:); end
function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats) % function to normalize features - range [-1, 1] and reshape to 4 % dimensions % % inputs: % feats - 3-dimensional array of extracted features % % outputs: % featNorm - normalized and reshaped features % minMaxPairs - array of original min and max pairs used for normalization nSamples = length(feats); minMaxPairs = zeros(nSamples,2,'single'); featNorm = zeros([size(feats{1}),nSamples],'single'); parfor i = 1:nSamples feat = feats{i}; maxFeat = max(feat,[],'all'); minFeat = min(feat,[],'all'); featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1; minMaxPairs(i,:) = [minFeat,maxFeat]; end featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3)); end
function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params) % This function will reconstruct the 2.072s long audios predicted by the % model using the predicted log-magnitude spectrogram and the phase of the % reverberant audio file % % inputs: % predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features % minMaxPairs - Original minimum/maximum value pairs used in normalization % reverbPhase - Array of phases of STFT of reverberant audio files % reverbAudios - 2.072s-segments of corresponding reverberant audios % params - Structure containing feature extraction parameters predictedSTFT = squeeze(predictedSTFT4D); denormalizedFeatures = zeros(size(predictedSTFT),'single'); for i = 1:size(predictedSTFT,3) feat = predictedSTFT(:,:,i); maxFeat = minMaxPairs(i,2); minFeat = minMaxPairs(i,1); denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat; end predictedSTFT = exp(denormalizedFeatures); nCount = size(predictedSTFT,3); dereverbedAudioAll = cell(1,nCount); nSeg = params.NumSegments; win = params.Window; ovrlp = params.OverlapLength; FFTLength = params.FFTLength; parfor ii = 1:nCount % Append zeros to the highest frequency bin stftUnit = predictedSTFT(:,:,ii); stftUnit = cat(1,stftUnit, zeros(1,nSeg)); phase = reverbPhase{ii}; phase = cat(1,phase,zeros(1,nSeg)); oneSidedSTFT = stftUnit.*exp(1j*phase); dereverbedAudio= istft(oneSidedSTFT, ... 'Window', win,'OverlapLength', ovrlp, ... 'FFTLength',FFTLength,'ConjugateSymmetric',true,... 'FrequencyRange','onesided'); dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii}))); end end
function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs) % This function computes the objective measures on time-domain signals. % % inputs: % reconstructedAudios - An array of audio files to evaluate. % cleanAudios - An array of reference audio files % fs - Sampling rate of audio files % % outputs: % summaryMeasures - Global means of CD, LLR individual mean and median values % allMeasures - Individual mean and median values nAudios = length(reconstructedAudios); cdMean = zeros(nAudios,1); cdMedian = zeros(nAudios,1); llrMean = zeros(nAudios,1); llrMedian = zeros(nAudios,1); parfor k = 1 : nAudios y = reconstructedAudios{k}; x = cleanAudios{k}; y = y./max(abs(y)); x = x./max(abs(x)); [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs); [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs); end summaryMeasures.avgCdMean = mean(cdMean); summaryMeasures.avgCdMedian = mean(cdMedian); summaryMeasures.avgLlrMean = mean(llrMean); summaryMeasures.avgLlrMedian = mean(llrMedian); allMeasures.cdMean = cdMean; allMeasures.llrMean = llrMean; end
function [meanVal, medianVal] = cepstralDistance(x,y,fs) x = x / sqrt(sum(x.^2)); y = y / sqrt(sum(y.^2)); width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex).*win; yFrames = y(winIndex).*win; xCeps = cepstralReal(xFrames,width); yCeps = cepstralReal(yFrames,width); dist = (xCeps - yCeps).^2; cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:)); cepsD = max(min(cepsD,10),0); meanVal = mean(cepsD); medianVal = median(cepsD); end
function realC = cepstralReal(x, width) width2p = 2 ^ nextpow2(width); powX = abs(fft(x, width2p)); lowCutoff = max(powX(:)) * 10^-5; powX = max(powX, lowCutoff); realC = real(ifft(log(powX))); order = 24; realC = realC(1 : order + 1, :); realC = realC - mean(realC, 2); end
function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs) order = 12; width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex) .* win; yFrames = y(winIndex) .* win; lpcX = realLpc(xFrames, width, order); [lpcY,realY] = realLpc(yFrames, width, order); llr = zeros(nFrames, 1); for n = 1 : nFrames R = toeplitz(realY(1:order+1,n)); num = lpcX(:,n)'*R*lpcX(:,n); den = lpcY(:,n)'*R*lpcY(:,n); llr(n) = log(num/den); end llr = sort(llr); llr = llr(1:ceil(nFrames*0.95)); llr = max(min(llr,2),0); meanLlr = mean(llr); medianLlr = median(llr); end
function [lpcCoeffs, realX] = realLpc(xFrames, width, order) width2p = 2 ^ nextpow2(width); X = fft(xFrames, width2p); Rx = ifft(abs(X).^2); Rx = Rx./width; realX = real(Rx); lpcX = levinson(realX, order); lpcCoeffs = real(lpcX'); end