Этот пример показывает, как обучить U-Net полностью сверточную сеть (FCN) [1] для удаления речевых сигналов.
Реверберация происходит, когда речевой сигнал отражается от объектов в пространстве, заставляя множественные отражения накапливаться и в конечном счете приводит к ухудшению качества речи. Дереверберация является процессом уменьшения эффектов реверберации в речевом сигнале.
Перед углублением в процесс обучения используйте предварительно обученную сеть для удаления речевого сигнала.
Загрузите предварительно обученную сеть. Эта сеть была обучена на 56-динамических версиях обучающих наборов данных. Пример проходит обучение на 28-динамической версии.
url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip'; downloadFolder = tempdir; networkDataFolder = fullfile(downloadFolder,'derevernet'); if ~exist(networkDataFolder,'dir') disp('Downloading pretrained network ...') unzip(url,downloadFolder) end load(fullfile(networkDataFolder,'dereverbNet.mat'))
Прослушайте чистый речевой сигнал, дискретизированный с частотой дискретизации 16 кГц.
[cleanAudio,fs] = audioread('clean_speech_signal.wav');
sound(cleanAudio,fs)
Акустический путь может быть смоделирован с помощью комнатной импульсной характеристики. Можно смоделировать реверберацию путем свертки анехойного сигнала с комнатной импульсной характеристикой.
Загрузка и построение графика импульсной характеристики помещения.
[rirAudio,fsR] = audioread('room_impulse_response.wav'); tAxis = (1/fsR)*(0:numel(rirAudio)-1); figure plot(tAxis,rirAudio) xlabel('Time (s)') ylabel('Amplitude') grid on
Сверните чистую речь с комнатной импульсной характеристикой, чтобы получить ревербрированную речь. Выравнивание длин и амплитуд ревербированных и чистых речевых сигналов.
revAudio = conv(cleanAudio,rirAudio); revAudio = revAudio(1:numel(cleanAudio)); revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));
Слушайте речевой сигнал.
sound(revAudio,fs)
Вход предварительно обученной сети является логарифмическим коротковременным преобразованием Фурье (STFT) реверберативного звука. Сеть предсказывает логарифмическую величину STFT деревертированного входа. Чтобы оценить исходный аудиосигнал временной области, вы выполняете обратный STFT и принимаете фазу реверберативного аудиосигнала.
Используйте следующие параметры для вычисления STFT.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
Использование stft
для вычисления односторонней логарифмической величины STFT. Используйте одинарную точность при вычислении функций, чтобы лучше использовать память и ускорить обучение. Несмотря на то, что односторонний STFT дает 257 интервалов частоты, учитывайте только 256 интервалов и игнорируйте самый высокий интервал частоты.
revAudio = single(revAudio); audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'FrequencyRange','onesided'); Eps = realmin('single'); reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);
Извлеките фазу STFT.
phaseOriginal = angle(audioSTFT(1:end-1,:));
Каждый вход будет иметь размерности 256 на 256 (частотные интервалы по временным шагам). Разделите логарифмическую величину STFT на сегменты с 256 временными шагами.
params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
[params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);
Масштабируйте сегментированные функции в области значений [-1,1]. Сохраните минимальное и максимальное значения, используемые для масштабирования для восстановления деревертированного сигнала.
minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments)); maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments)); featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ... reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);
Измените форму функций так, чтобы фрагменты находились вдоль четвертой размерности.
featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);
Спрогнозируйте спектры сигнала с помощью предварительно обученной сети.
predictedSTFT4D = predict(dereverbNet,featNorm);
Измените форму на 3 измерения и масштабируйте предсказанные STFT до исходной области значений с помощью сохраненных пар минимум-максимум.
predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))'; featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ... predictedSTFT,minVals,maxVals,'UniformOutput',false);
Сторнируйте логарифмическое масштабирование.
predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);
Конкатенация предсказанные сегменты STFT 256 величин 256, чтобы получить величину спектрограмму исходной длины.
predictedSTFTAll = predictedSTFT(1:chunks - 1); predictedSTFTAll = cat(2,predictedSTFTAll{:}); predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};
Перед взятием обратного STFT добавьте нули к предсказанному логарифмическому спектру и фазе вместо самой высокой частоты интервала которая была исключена при подготовке входа признаков.
nCount = size(predictedSTFTAll,3); predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount)); phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));
Используйте обратную функцию STFT для восстановления деревертированного речевого сигнала временной области, используя предсказанную логарифмическую величину STFT и фазу реверантного речевого сигнала.
oneSidedSTFT = predictedSTFTAll.*exp(1j*phase); dereverbedAudio = istft(oneSidedSTFT, ... 'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ... 'FrequencyRange','onesided'); dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio])); dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];
Прослушайте воспроизведенный аудиосигнал.
sound(dereverbedAudio,fs)
Постройте график чистых, реверберативных и дериверированных речевых сигналов.
t = (1/fs)*(0:numel(cleanAudio)-1); figure subplot(3,1,1) plot(t,cleanAudio) xlabel('Time (s)') grid on subtitle('Clean Speech Signal') subplot(3,1,2) plot(t,revAudio) xlabel('Time (s)') grid on subtitle('Revereberated Speech Signal') subplot(3,1,3) plot(t,dereverbedAudio) xlabel('Time (s)') grid on subtitle('Derevereberated Speech Signal')
Визуализируйте спектрограммы чистых, реверберативных и деревертированных речевых сигналов.
figure('Position',[100,100,800,800]) subplot(3,1,1) spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Clean') subplot(3,1,2) spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Reverberated') subplot(3,1,3) spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Predicted (Dereverberated)')
Этот пример использует речевую базу данных Reverberant [2] и соответствующую базу данных чистой речи [3], чтобы обучить сеть.
Загрузите чистый набор речевых данных.
url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip'; url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip'; downloadFolder = tempdir; cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791'); if ~exist(cleanDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url1,cleanDataFolder) unzip(url2,cleanDataFolder) end
Загрузите ревербированный набор данных.
url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip'; url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip'; downloadFolder = tempdir; reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031'); if ~exist(reverbDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url3,reverbDataFolder) unzip(url4,reverbDataFolder) end
После загрузки данных предварительно обработайте загруженные данные и извлеките функции перед обучением модели DNN:
Синтетически сгенерируйте ревербрантные данные с помощью reverberator
объект
Разделите каждый речевой сигнал на небольшие сегменты длительностью 2.072 с
Сбросьте сегменты, которые содержат значительные тихие области
Извлеките STFT с логарифмической амплитудой как предиктор и целевые функции
Масштабирование и изменение формы функций
Во-первых, создайте две audioDatastore
объекты, которые указывают на чистые и реверберативные наборы данных.
adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true); adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);
Величина реверберации в исходных данных относительно небольшая. Вы увеличите речевые данные со значительными эффектами реверберации, используя reverberator
объект.
Создайте audioDatastore
который указывает на чистый набор данных, выделенный для синтетической генерации реверберантных данных.
adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files)); adsCleanTrain = subset(adsCleanTrain,1:10e3); adsReverbTrain = subset(adsReverbTrain,1:10e3);
Дискретизация от 48 кГц до 16 кГц.
adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3)); adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3)); adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));
Объедините два хранилища аудиоданных, поддерживая соответствие между чистой и ревербрантной выборками речи.
adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);
Функция applyReverb создает reverberator
объект, обновляет параметры предварительной задержки, коэффициента затухания и влажно-сухого смешивания, как задано, и затем применяет реверберацию. Использование audioDataAugmenter
создание синтетически сгенерированных реверберативных данных.
augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ... 'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0); algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ... applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate); addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ... 'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ... 'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]}) augmenter.ReverbProbability = 1; disp(augmenter)
audioDataAugmenter with properties: AugmentationMode: 'independent' AugmentationParameterSource: 'random' NumAugmentations: 1 ApplyTimeStretch: 0 ApplyPitchShift: 0 ApplyVolumeControl: 0 ApplyAddNoise: 0 ApplyTimeShift: 0 ApplyReverb: 1 PreDelayRange: [0.1500 0.2500] DecayFactorRange: [0.2000 0.5000] WetDryMixRange: [0.3000 0.4500] SamplingRateRange: [16000 16000]
Создайте новую audioDatastore
соответствует синтетически сгенерированным ревербрантным данным по вызову transform
для применения увеличения данных.
adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));
Объедините два хранилища аудиоданных.
adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);
Далее, основываясь на размерностях входных функций в сеть, сегментируйте аудио в фрагменты длительностью 2.072 с с перекрытием 50%.
Иметь слишком много молчаливых сегментов может негативно повлиять на обучение модели DNN. Удалите сегменты, которые в основном являются молчаливыми (более 50% от длительности) и исключить из обучения модели. Не удаляйте полностью тишину, потому что модель не будет устойчива к тихим областям, и небольшие эффекты реверберации могут быть определены как тишина. detectSpeech
может идентифицировать начальную и конечную точки тихих областей. После этих двух шагов процесс редукции данных может быть выполнен, как объяснено в первом разделе. helperFeatureExtract реализует эти шаги.
Задайте параметры редукции данных. Путем настройки reduceDataSet
если true, вы выбираете небольшой подмножество наборов данных, чтобы выполнить последующие шаги.
reduceDataSet = true; params.fs = 16000; парамы. WindowdowLength = 512; params.Window = hamming (params.WindowDowLength,"periodic"); params.OverlapLength = round (0,75 * params.WindowDowLength); params.FFTLength = params. WindowDowLength; samplesPerMs = params.fs/1000; params.samplesPerImage = (24 + 256 * 8) * samplesPerMs; params.shiftImage = params.samplesPerImage/2; парамы. NumSegments = 256; парамы. NumFeatures = 256
params = struct with fields:
WindowdowLength: 512
Window: [512×1 double]
OverlapLength: 384
FFTLength: 512
NumSegments: 256
NumFeatures: 256
fs: 16000
samplesPerImage: 33152
shiftImage: 16576
Чтобы ускорить обработку, распределите задачу предварительной обработки и редукции данных между несколькими работниками с помощью parfor
.
Определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.
if ~isempty(ver('parallel')) pool = gcp; numPar = numpartitions(adsCombinedTrain,pool); else numPar = 1; end
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
Для каждого раздела считывайте из datastore, предварительно обрабатывайте аудиосигнал, а затем извлекайте функции.
if reduceDataSet adsCombinedTrain = shuffle(adsCombinedTrain); %#ok adsCombinedTrain = subset(adsCombinedTrain,1:200); adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain); adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200); end allCleanFeatures = cell(1,numPar); allReverbFeatures = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedTrain,numPar,iPartition); combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition); cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); partitionSize = cPartitionSize + cSyntheticPartitionSize; cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); for idx = 1:partitionSize if idx <= cPartitionSize audios = read(combinedPartition); else audios = read(combinedSyntheticPartition); end cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params); cleanFeaturesPartition{idx} = featuresClean; reverbFeaturesPartition{idx} = featuresReverb; end allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); end allCleanFeatures = cat(2,allCleanFeatures{:}); allReverbFeatures = cat(2,allReverbFeatures{:});
Нормализуйте извлеченные признаки в область значений [-1,1], а затем измените форму, как объяснено в первом разделе, с помощью функции featureNormalizeAndReshape.
trainClean = featureNormalizeAndReshape(allCleanFeatures); trainReverb = featureNormalizeAndReshape(allReverbFeatures);
Теперь, когда вы извлекли функции STFT логарифмической величины из обучающих наборов данных, выполните ту же процедуру, чтобы извлечь функции из наборов данных валидации. В целях реконструкции сохраните фазу реверсивных выборок речи набора данных валидации. В сложение сохраните аудио данных как для чистых, так и для ревербрантных речевых выборок в наборе валидации, чтобы использовать в процессе оценки (следующий раздел).
adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true); adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);
Дискретизация от 48 кГц до 16 кГц.
adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3)); adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3)); adsCombinedVal = combine(adsCleanVal,adsReverbVal);
if reduceDataSet adsCombinedVal = shuffle(adsCombinedVal);%#ok adsCombinedVal = subset(adsCombinedVal,1:50); end allValCleanFeatures = cell(1,numPar); allValReverbFeatures = cell(1,numPar); allValReverbPhase = cell(1,numPar); allValCleanAudios = cell(1,numPar); allValReverbAudios = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedVal,numPar,iPartition); partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); reverbPhasePartition = cell(1,partitionSize); cleanAudiosPartition = cell(1,partitionSize); reverbAudiosPartition = cell(1,partitionSize); for idx = 1:partitionSize audios = read(combinedPartition); cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params); cleanFeaturesPartition{idx} = a; reverbFeaturesPartition{idx} = b; reverbPhasePartition{idx} = c; cleanAudiosPartition{idx} = d; reverbAudiosPartition{idx} = e; end allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:}); allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:}); allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:}); end allValCleanFeatures = cat(2,allValCleanFeatures{:}); allValReverbFeatures = cat(2,allValReverbFeatures{:}); allValReverbPhase = cat(2,allValReverbPhase{:}); allValCleanAudios = cat(2,allValCleanAudios{:}); allValReverbAudios = cat(2,allValReverbAudios{:}); valClean = featureNormalizeAndReshape(allValCleanFeatures);
Сохраните минимальные и максимальные значения каждой функции набора реверберантной валидации. Эти значения будут использоваться в процессе реконструкции.
[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);
Полностью сверточная сетевая архитектура U-Net была адаптирована для этой задачи речевой дереверберации, как предложено в [1]. «U-Net» является сетью энкодер-декодер с пропущенными соединениями. В модели U-Net каждый слой понижает свой вход (шаг 2), пока не будет достигнут слой узкого места (путь кодирования). В последующих слоях вход усиливается каждым слоем до тех пор, пока выход не вернется к исходной форме (пути декодирования). Чтобы минимизировать потерю низкоуровневой информации в процессе понижающей дискретизации, соединения выполняются между зеркальными слоями путем непосредственной конкатенации выходов соответствующих слоев (пропуск соединений).
Определите сетевую архитектуру и верните график слоев с соединениями.
params.WindowdowLength = 512; params.FFTLength = params.WindowdowLength; params.NumFeatures = params.FFTLength/2; params.NumSegments = 256; filterH = 6; filterW = 6; numChannels = 1; nFilters = [64,128,256,512,512,512,512,512]; inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ... 'Normalization','none','Name','input'); layers = inputLayer; % U-Net squeezing path layers = [layers; convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1)); leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))]; for i = 2:8 layers = [layers; convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i)); batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok if i ~= 8 layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok else layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok end end % U-Net expanding path for i = 7:-1:0 nChannels = numChannels; if i > 0 nChannels = nFilters(i); end layers = [layers; transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok if i > 0 layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok end if i > 4 layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok end if i > 0 layers = [layers; reluLayer('Name',"de-relu"+string(i)); concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok else layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok end end layers = [layers;regressionLayer('Name','output')]; unetLayerGraph = layerGraph(layers); % Define skip-connections for i = 1:7 unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2'); end
Использование analyzeNetwork
для просмотра архитектуры модели. Это хороший способ визуализировать связи между слоями.
analyzeNetwork(unetLayerGraph);
Вы будете использовать среднюю квадратичную невязку (MSE) между спектрами логарифмической амплитуды деревертированной речевой выборки (выход модели) и соответствующей чистой речевой выборкой (цель) в качестве функции потерь. Используйте adam
оптимизатор и мини-пакет размером 128 для обучения. Разрешить модели обучаться в течение максимум 50 эпох. Если потеря валидации не улучшается в течение 5 последовательных эпох, завершите процесс обучения. Уменьшите скорость обучения в 10 раз каждые 15 эпох.
Определите опции обучения следующим образом. Измените окружение выполнения и возможность выполнения фоновой диспетчеризации в зависимости от доступности оборудования и наличия доступа к Parallel Computing Toolbox™.
initialLearnRate = 8e-4; miniBatchSize = 64; options = trainingOptions("adam", ... "MaxEpochs", 50, ... "InitialLearnRate",initialLearnRate, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Plots","training-progress", ... "Verbose",false, ... "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ... "ValidationPatience",5, ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",15, ... "ExecutionEnvironment","gpu", ... "DispatchInBackground",true, ... "ValidationData",{valReverb,valClean});
Обучите сеть.
dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);
Спрогнозируйте логарифмические спектры набора валидации.
predictedSTFT4D = predict(dereverbNet,valReverb);
Используйте функцию helperReconstructPredictedAudios, чтобы восстановить предсказанную речь. Эта функция выполняет действия, описанные в первом разделе.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;
dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);
Визуализируйте логарифмическую величину STFT чистых, реверберативных и соответствующих деревертированных речевых сигналов.
figure('Position',[100,100,1024,1200]) subplot(3,1,1) imagesc(squeeze(allValCleanFeatures{1})) set(gca,'Ydir','normal') subtitle('Clean') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,2) imagesc(squeeze(allValReverbFeatures{1})) set(gca,'Ydir','normal') subtitle('Reverberated') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,3) imagesc(squeeze(predictedSTFT4D(:,:,:,1))) set(gca,'Ydir','normal') subtitle('Predicted (Dereverberated)') xlabel('Time') ylabel('Frequency') caxis([-1,1]) colorbar
Для оценки эффективности сети используется подмножество объективных измерений, используемых в [1]. Эти метрики вычисляются на сигналах временной области.
Cepstrum distance (CD) - Обеспечивает оценку логарифмического спектрального расстояния между двумя спектрами (предсказанным и чистым). Меньшие значения указывают на лучшее качество.
Логарифмический коэффициент правдоподобия (LLR) - линейное прогнозирующее кодирование (LPC) на основе целевого измерения. Меньшие значения указывают на лучшее качество.
Вычислите эти измерения для реверантной речи и деревертированных речевых сигналов.
[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs); [summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs); disp(summaryMeasuresReconstructed)
avgCdMean: 3.8386 avgCdMedian: 3.3671 avgLlrMean: 0.9152 avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
avgCdMean: 4.2591 avgCdMedian: 3.6336 avgLlrMean: 0.9726 avgLlrMedian: 0.8714
Гистограммы иллюстрируют распределение среднего CD, среднего SRMR и среднего LLR реверберативных и деревертированных данных.
figure('position',[50,50,1100,1300]) subplot(2,1,1) histogram(allMeasuresReverb.cdMean,10) hold on histogram(allMeasuresReconstructed.cdMean, 10) subtitle('Mean Cepstral Distance Distribution') ylabel('count') xlabel('mean CD') legend('Reverberant (Original)','Dereverberated (Predicted)') subplot(2,1,2) histogram(allMeasuresReverb.llrMean,10) hold on histogram(allMeasuresReconstructed.llrMean,10) subtitle('Mean Log Likelihood Ratio Distribution') ylabel('Count') xlabel('Mean LLR') legend('Reverberant (Original)','Dereverberated (Predicted)')
[1] Ernst, O., Chazan, S.E., Gannot, S., & Goldberger, J. (2018). Речевая дереверберация с использованием полностью сверточных сетей. 2018 26-я Европейская конференция по обработке сигналов (EUSIPCO), 390-394.
[2] https://datashare.is.ed.ac.uk/handle/10283/2031
[3] https://datashare.is.ed.ac.uk/handle/10283/2791
[4] https://github.com/MuSAELab/SRMRToolbox
function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs) % This function generates reverberant speech data using the reverberator % object % % inputs: % y - clean speech sample % preDelay, decayFactor, wetDryMix - reverberation parameters % fs - sampling rate of y % % outputs: % yOut - corresponding reveberated speech sample revObj = reverberator('SampleRate',fs, ... 'DecayFactor',decayFactor, ... 'WetDryMix',wetDryMix, ... 'PreDelay',preDelay); yOut = revObj(y); yOut = yOut(1:length(y),1); end
function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ... = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params) % This function performs the preprocessing and features extraction task on % the audio files used for dereverberation model training and testing. % % inputs: % cleanAudio - the clean audio file (reference) % reverbAudio - corresponding reverberant speech file % isVal - Boolean flag indicating if it is the validation set % params - a structure containing feature extraction parameters % % outputs: % featuresClean - log-magnitude STFT features of clean audio % featuresReverb - log-magnitude STFT features of reverberant audio % phaseReverb - phase of STFT of reverberant audio % cleanAudios - 2.072s-segments of clean audio file used for feature extraction % reverbAudios - 2.072s-segments of corresponding reverberant audio assert(length(cleanAudio) == length(reverbAudio)); nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage); featuresClean = {}; featuresReverb = {}; phaseReverb = {}; cleanAudios = {}; reverbAudios = {}; nGood = 0; nonSilentRegions = detectSpeech(reverbAudio, params.fs); nonSilentRegionIdx = 1; totalRegions = size(nonSilentRegions, 1); for cid = 1:nSegments start = (cid - 1)*params.shiftImage + 1; en = start + params.samplesPerImage - 1; nonSilentSamples = 0; while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start nonSilentRegionIdx = nonSilentRegionIdx + 1; end nonSilentStart = nonSilentRegionIdx; while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1; nonSilentSamples = nonSilentSamples + nonSilentDuration; nonSilentStart = nonSilentStart + 1; end nonSilentPerc = nonSilentSamples * 100 / (en - start + 1); silent = nonSilentPerc < 50; reverbAudioSegment = reverbAudio(start:en); if ~silent nGood = nGood + 1; cleanAudioSegment = cleanAudio(start:en); assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking') % Clean Audio [featsUnit, ~] = featureExtract(cleanAudioSegment, params); featuresClean{nGood} = featsUnit; %#ok % Reverb Audio [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params); featuresReverb{nGood} = featsUnit; %#ok if isVal phaseReverb{nGood} = phaseUnit; %#ok reverbAudios{nGood} = reverbAudioSegment;%#ok cleanAudios{nGood} = cleanAudioSegment;%#ok end end end end
function [features, phase, lastFBin] = featureExtract(audio, params) % Function to extract features for a speech file audio = single(audio); audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided'); phase = single(angle(audioSTFT(1:end-1,:))); features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); lastFBin = audioSTFT(end,:); end
function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats) % function to normalize features - range [-1, 1] and reshape to 4 % dimensions % % inputs: % feats - 3-dimensional array of extracted features % % outputs: % featNorm - normalized and reshaped features % minMaxPairs - array of original min and max pairs used for normalization nSamples = length(feats); minMaxPairs = zeros(nSamples,2,'single'); featNorm = zeros([size(feats{1}),nSamples],'single'); parfor i = 1:nSamples feat = feats{i}; maxFeat = max(feat,[],'all'); minFeat = min(feat,[],'all'); featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1; minMaxPairs(i,:) = [minFeat,maxFeat]; end featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3)); end
function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params) % This function will reconstruct the 2.072s long audios predicted by the % model using the predicted log-magnitude spectrogram and the phase of the % reverberant audio file % % inputs: % predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features % minMaxPairs - Original minimum/maximum value pairs used in normalization % reverbPhase - Array of phases of STFT of reverberant audio files % reverbAudios - 2.072s-segments of corresponding reverberant audios % params - Structure containing feature extraction parameters predictedSTFT = squeeze(predictedSTFT4D); denormalizedFeatures = zeros(size(predictedSTFT),'single'); for i = 1:size(predictedSTFT,3) feat = predictedSTFT(:,:,i); maxFeat = minMaxPairs(i,2); minFeat = minMaxPairs(i,1); denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat; end predictedSTFT = exp(denormalizedFeatures); nCount = size(predictedSTFT,3); dereverbedAudioAll = cell(1,nCount); nSeg = params.NumSegments; win = params.Window; ovrlp = params.OverlapLength; FFTLength = params.FFTLength; parfor ii = 1:nCount % Append zeros to the highest frequency bin stftUnit = predictedSTFT(:,:,ii); stftUnit = cat(1,stftUnit, zeros(1,nSeg)); phase = reverbPhase{ii}; phase = cat(1,phase,zeros(1,nSeg)); oneSidedSTFT = stftUnit.*exp(1j*phase); dereverbedAudio= istft(oneSidedSTFT, ... 'Window', win,'OverlapLength', ovrlp, ... 'FFTLength',FFTLength,'ConjugateSymmetric',true,... 'FrequencyRange','onesided'); dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii}))); end end
function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs) % This function computes the objective measures on time-domain signals. % % inputs: % reconstructedAudios - An array of audio files to evaluate. % cleanAudios - An array of reference audio files % fs - Sampling rate of audio files % % outputs: % summaryMeasures - Global means of CD, LLR individual mean and median values % allMeasures - Individual mean and median values nAudios = length(reconstructedAudios); cdMean = zeros(nAudios,1); cdMedian = zeros(nAudios,1); llrMean = zeros(nAudios,1); llrMedian = zeros(nAudios,1); parfor k = 1 : nAudios y = reconstructedAudios{k}; x = cleanAudios{k}; y = y./max(abs(y)); x = x./max(abs(x)); [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs); [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs); end summaryMeasures.avgCdMean = mean(cdMean); summaryMeasures.avgCdMedian = mean(cdMedian); summaryMeasures.avgLlrMean = mean(llrMean); summaryMeasures.avgLlrMedian = mean(llrMedian); allMeasures.cdMean = cdMean; allMeasures.llrMean = llrMean; end
function [meanVal, medianVal] = cepstralDistance(x,y,fs) x = x / sqrt(sum(x.^2)); y = y / sqrt(sum(y.^2)); width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex).*win; yFrames = y(winIndex).*win; xCeps = cepstralReal(xFrames,width); yCeps = cepstralReal(yFrames,width); dist = (xCeps - yCeps).^2; cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:)); cepsD = max(min(cepsD,10),0); meanVal = mean(cepsD); medianVal = median(cepsD); end
function realC = cepstralReal(x, width) width2p = 2 ^ nextpow2(width); powX = abs(fft(x, width2p)); lowCutoff = max(powX(:)) * 10^-5; powX = max(powX, lowCutoff); realC = real(ifft(log(powX))); order = 24; realC = realC(1 : order + 1, :); realC = realC - mean(realC, 2); end
function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs) order = 12; width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex) .* win; yFrames = y(winIndex) .* win; lpcX = realLpc(xFrames, width, order); [lpcY,realY] = realLpc(yFrames, width, order); llr = zeros(nFrames, 1); for n = 1 : nFrames R = toeplitz(realY(1:order+1,n)); num = lpcX(:,n)'*R*lpcX(:,n); den = lpcY(:,n)'*R*lpcY(:,n); llr(n) = log(num/den); end llr = sort(llr); llr = llr(1:ceil(nFrames*0.95)); llr = max(min(llr,2),0); meanLlr = mean(llr); medianLlr = median(llr); end
function [lpcCoeffs, realX] = realLpc(xFrames, width, order) width2p = 2 ^ nextpow2(width); X = fft(xFrames, width2p); Rx = ifft(abs(X).^2); Rx = Rx./width; realX = real(Rx); lpcX = levinson(realX, order); lpcCoeffs = real(lpcX'); end