В этом примере показано, как обучить полностью сверточную сеть (FCN) U-Net [1] к dereverberate, речь сигнализирует.
Реверберация происходит, когда речевой сигнал отражается от объектов на пробеле, заставляя несколько отражений расти и в конечном счете приводит к ухудшению речевого качества. Dereverberation является процессом сокращения эффектов реверберации в речевом сигнале.
Перед входом в учебный процесс подробно, используйте предварительно обученную сеть для dereverberate речевой сигнал.
Загрузите предварительно обученную сеть. Эта сеть была обучена на версиях с 56 динамиками обучающих наборов данных. Пример идет посредством обучения на версии с 28 динамиками.
url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip'; downloadFolder = tempdir; networkDataFolder = fullfile(downloadFolder,'derevernet'); if ~exist(networkDataFolder,'dir') disp('Downloading pretrained network ...') unzip(url,downloadFolder) end load(fullfile(networkDataFolder,'dereverbNet.mat'))
Слушайте чистый речевой сигнал, произведенный на уровне 16 кГц.
[cleanAudio,fs] = audioread('clean_speech_signal.wav');
sound(cleanAudio,fs)
Акустический путь может быть смоделирован с помощью импульсной характеристики помещения. Можно смоделировать реверберацию путем свертки к безэховому сигналу с импульсной характеристикой помещения.
Загрузите и постройте импульсную характеристику помещения.
[rirAudio,fsR] = audioread('room_impulse_response.wav'); tAxis = (1/fsR)*(0:numel(rirAudio)-1); figure plot(tAxis,rirAudio) xlabel('Time (s)') ylabel('Amplitude') grid on
Примените операцию свертки к чистой речи с импульсной характеристикой помещения, чтобы получить, отразился речь. Выровняйте длины и амплитуды отраженных и чистых речевых сигналов.
revAudio = conv(cleanAudio,rirAudio); revAudio = revAudio(1:numel(cleanAudio)); revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));
Слушайте отраженный речевой сигнал.
sound(revAudio,fs)
Вход к предварительно обученной сети является кратковременным преобразованием Фурье (STFT) логарифмической величины отражающего аудио. Сеть предсказывает логарифмическую величину STFT входа dereverberated. Чтобы оценить исходный звуковой сигнал временного интервала, вы выполняете обратный STFT и принимаете фазу отражающего аудио.
Используйте следующие параметры, чтобы вычислить STFT.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
Используйте stft
вычислить одностороннюю логарифмическую величину STFT. Используйте одинарную точность при вычислении функций, чтобы лучше использовать использование памяти и ускорить обучение. Даже при том, что односторонний STFT дает к 257 интервалам частоты, рассмотрите только 256 интервалов и проигнорируйте самый высокий интервал частоты.
revAudio = single(revAudio); audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'FrequencyRange','onesided'); Eps = realmin('single'); reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);
Извлеките фазу STFT.
phaseOriginal = angle(audioSTFT(1:end-1,:));
Каждый вход будет иметь размерности 256 256 (интервалы частоты временными шагами). Разделите логарифмическую величину STFT в сегменты 256 тактов.
params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
[params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);
Масштабируйте сегментированные функции к области значений [-1,1]. Сохраните минимальные и максимальные значения, используемые, чтобы масштабироваться для восстановления сигнала dereverberated.
minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments)); maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments)); featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ... reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);
Измените функции так, чтобы фрагменты приехали четвертая размерность.
featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);
Предскажите спектры логарифмической величины отраженного речевого сигнала использование предварительно обученной сети.
predictedSTFT4D = predict(dereverbNet,featNorm);
Изменитесь к 3 размерностям и масштабируйте предсказанный STFTs к исходной области значений с помощью сохраненных минимально-максимальных пар.
predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))'; featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ... predictedSTFT,minVals,maxVals,'UniformOutput',false);
Инвертируйте масштабирование журнала.
predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);
Конкатенация предсказанного 256 256 сегменты STFT величины, чтобы получить спектрограмму величины исходной длины.
predictedSTFTAll = predictedSTFT(1:chunks - 1); predictedSTFTAll = cat(2,predictedSTFTAll{:}); predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};
Прежде, чем взять обратный STFT, добавьте нули к предсказанному спектру логарифмической величины и фазе вместо самого высокого интервала частоты, который был исключен, когда вход подготовки показывает.
nCount = size(predictedSTFTAll,3); predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount)); phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));
Используйте обратную функцию STFT, чтобы восстановить dereverberated речевой сигнал временного интервала использование предсказанной логарифмической величины STFT и фаза отражающего речевого сигнала.
oneSidedSTFT = predictedSTFTAll.*exp(1j*phase); dereverbedAudio = istft(oneSidedSTFT, ... 'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ... 'FrequencyRange','onesided'); dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio])); dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];
Слушайте dereverberated звуковой сигнал.
sound(dereverbedAudio,fs)
Постройте чистые, отражающие, и dereverberated речевые сигналы.
t = (1/fs)*(0:numel(cleanAudio)-1); figure subplot(3,1,1) plot(t,cleanAudio) xlabel('Time (s)') grid on subtitle('Clean Speech Signal') subplot(3,1,2) plot(t,revAudio) xlabel('Time (s)') grid on subtitle('Revereberated Speech Signal') subplot(3,1,3) plot(t,dereverbedAudio) xlabel('Time (s)') grid on subtitle('Derevereberated Speech Signal')
Визуализируйте спектрограммы чистых, отражающих, и dereverberated речевых сигналов.
figure('Position',[100,100,800,800]) subplot(3,1,1) spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Clean') subplot(3,1,2) spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Reverberated') subplot(3,1,3) spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis'); subtitle('Predicted (Dereverberated)')
Этот пример использует Отражающую Речевую Базу данных [2] и соответствующую Чистую Речевую Базу данных [3], чтобы обучить сеть.
Загрузите чистый речевой набор данных.
url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip'; url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip'; downloadFolder = tempdir; cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791'); if ~exist(cleanDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url1,cleanDataFolder) unzip(url2,cleanDataFolder) end
Загрузите отраженный речевой набор данных.
url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip'; url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip'; downloadFolder = tempdir; reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031'); if ~exist(reverbDataFolder,'dir') disp('Downloading data set (6 GB) ...') unzip(url3,reverbDataFolder) unzip(url4,reverbDataFolder) end
Если данные загружаются, предварительно обработайте загруженные данные и извлеките функции перед обучением модель DNN:
Искусственно сгенерируйте отражающие данные с помощью reverberator
объект
Разделите каждый речевой сигнал в маленькие сегменты 2,072 длительности с
Отбросьте сегменты, которые содержат значительные тихие области
Извлеките логарифмическую величину STFTs как предиктор и предназначайтесь для функций
Масштабируйте и измените функции
Во-первых, создайте два audioDatastore
объекты, которые указывают на чистые и отражающие речевые наборы данных.
adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true); adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);
Количество реверберации в исходных данных относительно мало. Вы увеличите отражающие речевые данные со значительными эффектами реверберации с помощью reverberator
объект.
Создайте audioDatastore
это указывает на чистый речевой набор данных, выделенный для синтетической отражающей генерации данных.
adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files)); adsCleanTrain = subset(adsCleanTrain,1:10e3); adsReverbTrain = subset(adsReverbTrain,1:10e3);
Передискретизируйте от 48 кГц до 16 кГц.
adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3)); adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3)); adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));
Объедините два аудио хранилища данных, обеспечив соответствие между чистыми и отражающими речевыми выборками.
adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);
Функция applyReverb создает reverberator
объект, обновляется пред задержка, фактор затухания и влажно-сухие параметры соединения, как задано, и затем применяет реверберацию. Используйте audioDataAugmenter
создать искусственно сгенерированные отражающие данные.
augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ... 'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0); algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ... applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate); addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ... 'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ... 'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]}) augmenter.ReverbProbability = 1; disp(augmenter)
audioDataAugmenter with properties: AugmentationMode: 'independent' AugmentationParameterSource: 'random' NumAugmentations: 1 ApplyTimeStretch: 0 ApplyPitchShift: 0 ApplyVolumeControl: 0 ApplyAddNoise: 0 ApplyTimeShift: 0 ApplyReverb: 1 PreDelayRange: [0.1500 0.2500] DecayFactorRange: [0.2000 0.5000] WetDryMixRange: [0.3000 0.4500] SamplingRateRange: [16000 16000]
Создайте новый audioDatastore
соответствие искусственно сгенерировало отражающие данные путем вызова transform
применять увеличение данных.
adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));
Объедините два аудио хранилища данных.
adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);
Затем на основе размерностей входных функций к сети, сегментируйте аудио на фрагменты 2,072 длительности с с перекрытием 50%.
Наличие слишком многих тихих сегментов может оказать негативное влияние на обучение модели DNN. Удалите сегменты, которые в основном тихи (больше чем 50% длительности) и исключают тех из обучения модели. Не полностью удаляйте тишину, потому что модель не будет устойчива в тихие области, и небольшие эффекты реверберации могли быть идентифицированы как тишина. detectSpeech
может идентифицировать начальные и конечные точки тихих областей. После этих двух шагов процесс извлечения признаков может быть выполнен, как объяснено в первом разделе. helperFeatureExtract реализует эти шаги.
Задайте параметры извлечения признаков. Установкой reduceDataSet
к истине вы выбираете небольшое подмножество наборов данных, чтобы выполнить последующие шаги.
reduceDataSet = true; params.fs = 16000; params.WindowdowLength = 512; params.Window = hamming (params.WindowdowLength,"periodic"); params.OverlapLength = раунд (0.75*params.WindowdowLength); params.FFTLength = params.WindowdowLength; samplesPerMs = params.fs/1000; params.samplesPerImage = (24+256*8) *samplesPerMs; params.shiftImage = params.samplesPerImage/2; params.NumSegments = 256; params.NumFeatures = 256
params = struct with fields:
WindowdowLength: 512
Window: [512×1 double]
OverlapLength: 384
FFTLength: 512
NumSegments: 256
NumFeatures: 256
fs: 16000
samplesPerImage: 33152
shiftImage: 16576
Чтобы ускорить обработку, распределите задачу предварительной обработки и извлечения признаков на нескольких рабочих, использующих parfor
.
Определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.
if ~isempty(ver('parallel')) pool = gcp; numPar = numpartitions(adsCombinedTrain,pool); else numPar = 1; end
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
Для каждого раздела читайте из datastore, предварительно обработайте звуковой сигнал, и затем извлеките функции.
if reduceDataSet adsCombinedTrain = shuffle(adsCombinedTrain); %#ok adsCombinedTrain = subset(adsCombinedTrain,1:200); adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain); adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200); end allCleanFeatures = cell(1,numPar); allReverbFeatures = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedTrain,numPar,iPartition); combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition); cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); partitionSize = cPartitionSize + cSyntheticPartitionSize; cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); for idx = 1:partitionSize if idx <= cPartitionSize audios = read(combinedPartition); else audios = read(combinedSyntheticPartition); end cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params); cleanFeaturesPartition{idx} = featuresClean; reverbFeaturesPartition{idx} = featuresReverb; end allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); end allCleanFeatures = cat(2,allCleanFeatures{:}); allReverbFeatures = cat(2,allReverbFeatures{:});
Нормируйте извлеченные функции к области значений [-1,1] и затем изменитесь, как объяснено в первом разделе, с помощью функции featureNormalizeAndReshape.
trainClean = featureNormalizeAndReshape(allCleanFeatures); trainReverb = featureNormalizeAndReshape(allReverbFeatures);
Теперь, когда вы извлекли логарифмическую величину функции STFT из обучающих наборов данных, выполните ту же процедуру, чтобы извлечь функции из наборов данных валидации. В целях реконструкции сохраните фазу отражающих речевых выборок набора данных валидации. Кроме того, сохраните аудиоданные и для чистых и для отражающих речевых выборок в наборе валидации, чтобы использовать в процессе оценки (следующий раздел).
adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true); adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);
Передискретизируйте от 48 кГц до 16 кГц.
adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3)); adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3)); adsCombinedVal = combine(adsCleanVal,adsReverbVal);
if reduceDataSet adsCombinedVal = shuffle(adsCombinedVal);%#ok adsCombinedVal = subset(adsCombinedVal,1:50); end allValCleanFeatures = cell(1,numPar); allValReverbFeatures = cell(1,numPar); allValReverbPhase = cell(1,numPar); allValCleanAudios = cell(1,numPar); allValReverbAudios = cell(1,numPar); parfor iPartition = 1:numPar combinedPartition = partition(adsCombinedVal,numPar,iPartition); partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files); cleanFeaturesPartition = cell(1,partitionSize); reverbFeaturesPartition = cell(1,partitionSize); reverbPhasePartition = cell(1,partitionSize); cleanAudiosPartition = cell(1,partitionSize); reverbAudiosPartition = cell(1,partitionSize); for idx = 1:partitionSize audios = read(combinedPartition); cleanAudio = single(audios(:,1)); reverbAudio = single(audios(:,2)); [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params); cleanFeaturesPartition{idx} = a; reverbFeaturesPartition{idx} = b; reverbPhasePartition{idx} = c; cleanAudiosPartition{idx} = d; reverbAudiosPartition{idx} = e; end allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:}); allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:}); allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:}); allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:}); allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:}); end allValCleanFeatures = cat(2,allValCleanFeatures{:}); allValReverbFeatures = cat(2,allValReverbFeatures{:}); allValReverbPhase = cat(2,allValReverbPhase{:}); allValCleanAudios = cat(2,allValCleanAudios{:}); allValReverbAudios = cat(2,allValReverbAudios{:}); valClean = featureNormalizeAndReshape(allValCleanFeatures);
Сохраните минимальные и максимальные значения каждой функции отражающего набора валидации. Вы будете использовать эти значения в процессе реконструкции.
[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);
Полностью сверточная сетевая архитектура по имени U-Net была адаптирована к этой речи dereverberation задача, как предложено в [1]. "U-Net" является сетью декодера энкодера со связями пропуска. В U-сетевой-модели каждый слой прореживает свой вход (шаг 2), пока слой узкого места не достигнут (путь к кодированию). В последующих слоях вход сверхдискретизирован каждым слоем, пока выходной параметр не возвращен к исходной форме (декодирующий путь). Чтобы минимизировать потерю низкоуровневой информации во время процесса субдискретизации, связи установлены между зеркальными слоями путем прямой конкатенации выходных параметров соответствующих слоев (связи пропуска).
Задайте сетевую архитектуру и возвратите график слоев со связями.
params.WindowdowLength = 512; params.FFTLength = params.WindowdowLength; params.NumFeatures = params.FFTLength/2; params.NumSegments = 256; filterH = 6; filterW = 6; numChannels = 1; nFilters = [64,128,256,512,512,512,512,512]; inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ... 'Normalization','none','Name','input'); layers = inputLayer; % U-Net squeezing path layers = [layers; convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1)); leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))]; for i = 2:8 layers = [layers; convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i)); batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok if i ~= 8 layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok else layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok end end % U-Net expanding path for i = 7:-1:0 nChannels = numChannels; if i > 0 nChannels = nFilters(i); end layers = [layers; transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok if i > 0 layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok end if i > 4 layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok end if i > 0 layers = [layers; reluLayer('Name',"de-relu"+string(i)); concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok else layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok end end layers = [layers;regressionLayer('Name','output')]; unetLayerGraph = layerGraph(layers); % Define skip-connections for i = 1:7 unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2'); end
Используйте analyzeNetwork
просмотреть архитектуру модели. Это - хороший способ визуализировать связи между слоями.
analyzeNetwork(unetLayerGraph);
Вы будете использовать среднеквадратическую ошибку (MSE) между спектрами логарифмической величины dereverberated речевой выборки (выход модели) и соответствующей чистой речевой выборки (цель) как функция потерь. Используйте adam
оптимизатор и мини-пакетный размер 128 для обучения. Позвольте модели обучаться максимум для 50 эпох. Если потеря валидации не улучшается в течение 5 эпох подряд, отключает учебный процесс. Уменьшайте скорость обучения на коэффициент 10 каждых 15 эпох.
Задайте опции обучения как ниже. Измените среду выполнения и выполнить ли фоновую диспетчеризацию в зависимости от вашей аппаратной доступности и есть ли у вас доступ к Parallel Computing Toolbox™.
initialLearnRate = 8e-4; miniBatchSize = 64; options = trainingOptions("adam", ... "MaxEpochs", 50, ... "InitialLearnRate",initialLearnRate, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Plots","training-progress", ... "Verbose",false, ... "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ... "ValidationPatience",5, ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",15, ... "ExecutionEnvironment","gpu", ... "DispatchInBackground",true, ... "ValidationData",{valReverb,valClean});
Обучите сеть.
dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);
Предскажите спектры логарифмической величины набора валидации.
predictedSTFT4D = predict(dereverbNet,valReverb);
Используйте функцию helperReconstructPredictedAudios, чтобы восстановить предсказанную речь. Эта функция выполняет действия, обрисованные в общих чертах в первом разделе.
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;
dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);
Визуализируйте логарифмическую величину STFTs чистых, отражающих, и соответствующих dereverberated речевых сигналов.
figure('Position',[100,100,1024,1200]) subplot(3,1,1) imagesc(squeeze(allValCleanFeatures{1})) set(gca,'Ydir','normal') subtitle('Clean') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,2) imagesc(squeeze(allValReverbFeatures{1})) set(gca,'Ydir','normal') subtitle('Reverberated') xlabel('Time') ylabel('Frequency') colorbar subplot(3,1,3) imagesc(squeeze(predictedSTFT4D(:,:,:,1))) set(gca,'Ydir','normal') subtitle('Predicted (Dereverberated)') xlabel('Time') ylabel('Frequency') caxis([-1,1]) colorbar
Вы будете использовать подмножество объективных мер, используемых в [1], чтобы оценить эффективность сети. Эти метрики вычисляются на сигналах временной области.
Расстояние кепстра (CD) - Обеспечивает оценку журнала спектральное расстояние между двумя спектрами (предсказанный и чистый). Меньшие значения указывают на лучшее качество.
Логарифмическое отношение правдоподобия (LLR) - основанное на линейном предсказательном кодировании (LPC) объективное измерение. Меньшие значения указывают на лучшее качество.
Вычислите эти измерения для отражающей речи и dereverberated речевых сигналов.
[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs); [summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs); disp(summaryMeasuresReconstructed)
avgCdMean: 3.8386 avgCdMedian: 3.3671 avgLlrMean: 0.9152 avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
avgCdMean: 4.2591 avgCdMedian: 3.6336 avgLlrMean: 0.9726 avgLlrMedian: 0.8714
Гистограммы иллюстрируют распределение среднего CD, означают SRMR и означают LLR отражающих и dereverberated данных.
figure('position',[50,50,1100,1300]) subplot(2,1,1) histogram(allMeasuresReverb.cdMean,10) hold on histogram(allMeasuresReconstructed.cdMean, 10) subtitle('Mean Cepstral Distance Distribution') ylabel('count') xlabel('mean CD') legend('Reverberant (Original)','Dereverberated (Predicted)') subplot(2,1,2) histogram(allMeasuresReverb.llrMean,10) hold on histogram(allMeasuresReconstructed.llrMean,10) subtitle('Mean Log Likelihood Ratio Distribution') ylabel('Count') xlabel('Mean LLR') legend('Reverberant (Original)','Dereverberated (Predicted)')
[1] Эрнст, O., Chazan, S.E., Gannot, S., & Голдбергер, J. (2018). Речь Dereverberation Используя полностью Сверточные сети. 2 018 26-х европейских конференций по обработке сигналов (EUSIPCO), 390-394.
[2] https://datashare.is.ed.ac.uk/handle/10283/2031
[3] https://datashare.is.ed.ac.uk/handle/10283/2791
[4] https://github.com/MuSAELab/SRMRToolbox
function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs) % This function generates reverberant speech data using the reverberator % object % % inputs: % y - clean speech sample % preDelay, decayFactor, wetDryMix - reverberation parameters % fs - sampling rate of y % % outputs: % yOut - corresponding reveberated speech sample revObj = reverberator('SampleRate',fs, ... 'DecayFactor',decayFactor, ... 'WetDryMix',wetDryMix, ... 'PreDelay',preDelay); yOut = revObj(y); yOut = yOut(1:length(y),1); end
function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ... = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params) % This function performs the preprocessing and features extraction task on % the audio files used for dereverberation model training and testing. % % inputs: % cleanAudio - the clean audio file (reference) % reverbAudio - corresponding reverberant speech file % isVal - Boolean flag indicating if it is the validation set % params - a structure containing feature extraction parameters % % outputs: % featuresClean - log-magnitude STFT features of clean audio % featuresReverb - log-magnitude STFT features of reverberant audio % phaseReverb - phase of STFT of reverberant audio % cleanAudios - 2.072s-segments of clean audio file used for feature extraction % reverbAudios - 2.072s-segments of corresponding reverberant audio assert(length(cleanAudio) == length(reverbAudio)); nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage); featuresClean = {}; featuresReverb = {}; phaseReverb = {}; cleanAudios = {}; reverbAudios = {}; nGood = 0; nonSilentRegions = detectSpeech(reverbAudio, params.fs); nonSilentRegionIdx = 1; totalRegions = size(nonSilentRegions, 1); for cid = 1:nSegments start = (cid - 1)*params.shiftImage + 1; en = start + params.samplesPerImage - 1; nonSilentSamples = 0; while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start nonSilentRegionIdx = nonSilentRegionIdx + 1; end nonSilentStart = nonSilentRegionIdx; while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1; nonSilentSamples = nonSilentSamples + nonSilentDuration; nonSilentStart = nonSilentStart + 1; end nonSilentPerc = nonSilentSamples * 100 / (en - start + 1); silent = nonSilentPerc < 50; reverbAudioSegment = reverbAudio(start:en); if ~silent nGood = nGood + 1; cleanAudioSegment = cleanAudio(start:en); assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking') % Clean Audio [featsUnit, ~] = featureExtract(cleanAudioSegment, params); featuresClean{nGood} = featsUnit; %#ok % Reverb Audio [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params); featuresReverb{nGood} = featsUnit; %#ok if isVal phaseReverb{nGood} = phaseUnit; %#ok reverbAudios{nGood} = reverbAudioSegment;%#ok cleanAudios{nGood} = cleanAudioSegment;%#ok end end end end
function [features, phase, lastFBin] = featureExtract(audio, params) % Function to extract features for a speech file audio = single(audio); audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ... 'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided'); phase = single(angle(audioSTFT(1:end-1,:))); features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); lastFBin = audioSTFT(end,:); end
function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats) % function to normalize features - range [-1, 1] and reshape to 4 % dimensions % % inputs: % feats - 3-dimensional array of extracted features % % outputs: % featNorm - normalized and reshaped features % minMaxPairs - array of original min and max pairs used for normalization nSamples = length(feats); minMaxPairs = zeros(nSamples,2,'single'); featNorm = zeros([size(feats{1}),nSamples],'single'); parfor i = 1:nSamples feat = feats{i}; maxFeat = max(feat,[],'all'); minFeat = min(feat,[],'all'); featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1; minMaxPairs(i,:) = [minFeat,maxFeat]; end featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3)); end
function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params) % This function will reconstruct the 2.072s long audios predicted by the % model using the predicted log-magnitude spectrogram and the phase of the % reverberant audio file % % inputs: % predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features % minMaxPairs - Original minimum/maximum value pairs used in normalization % reverbPhase - Array of phases of STFT of reverberant audio files % reverbAudios - 2.072s-segments of corresponding reverberant audios % params - Structure containing feature extraction parameters predictedSTFT = squeeze(predictedSTFT4D); denormalizedFeatures = zeros(size(predictedSTFT),'single'); for i = 1:size(predictedSTFT,3) feat = predictedSTFT(:,:,i); maxFeat = minMaxPairs(i,2); minFeat = minMaxPairs(i,1); denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat; end predictedSTFT = exp(denormalizedFeatures); nCount = size(predictedSTFT,3); dereverbedAudioAll = cell(1,nCount); nSeg = params.NumSegments; win = params.Window; ovrlp = params.OverlapLength; FFTLength = params.FFTLength; parfor ii = 1:nCount % Append zeros to the highest frequency bin stftUnit = predictedSTFT(:,:,ii); stftUnit = cat(1,stftUnit, zeros(1,nSeg)); phase = reverbPhase{ii}; phase = cat(1,phase,zeros(1,nSeg)); oneSidedSTFT = stftUnit.*exp(1j*phase); dereverbedAudio= istft(oneSidedSTFT, ... 'Window', win,'OverlapLength', ovrlp, ... 'FFTLength',FFTLength,'ConjugateSymmetric',true,... 'FrequencyRange','onesided'); dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii}))); end end
function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs) % This function computes the objective measures on time-domain signals. % % inputs: % reconstructedAudios - An array of audio files to evaluate. % cleanAudios - An array of reference audio files % fs - Sampling rate of audio files % % outputs: % summaryMeasures - Global means of CD, LLR individual mean and median values % allMeasures - Individual mean and median values nAudios = length(reconstructedAudios); cdMean = zeros(nAudios,1); cdMedian = zeros(nAudios,1); llrMean = zeros(nAudios,1); llrMedian = zeros(nAudios,1); parfor k = 1 : nAudios y = reconstructedAudios{k}; x = cleanAudios{k}; y = y./max(abs(y)); x = x./max(abs(x)); [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs); [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs); end summaryMeasures.avgCdMean = mean(cdMean); summaryMeasures.avgCdMedian = mean(cdMedian); summaryMeasures.avgLlrMean = mean(llrMean); summaryMeasures.avgLlrMedian = mean(llrMedian); allMeasures.cdMean = cdMean; allMeasures.llrMean = llrMean; end
function [meanVal, medianVal] = cepstralDistance(x,y,fs) x = x / sqrt(sum(x.^2)); y = y / sqrt(sum(y.^2)); width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex).*win; yFrames = y(winIndex).*win; xCeps = cepstralReal(xFrames,width); yCeps = cepstralReal(yFrames,width); dist = (xCeps - yCeps).^2; cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:)); cepsD = max(min(cepsD,10),0); meanVal = mean(cepsD); medianVal = median(cepsD); end
function realC = cepstralReal(x, width) width2p = 2 ^ nextpow2(width); powX = abs(fft(x, width2p)); lowCutoff = max(powX(:)) * 10^-5; powX = max(powX, lowCutoff); realC = real(ifft(log(powX))); order = 24; realC = realC(1 : order + 1, :); realC = realC - mean(realC, 2); end
function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs) order = 12; width = round(0.025*fs); shift = round(0.01*fs); nSamples = length(x); nFrames = floor((nSamples - width + shift)/shift); win = window(@hanning,width); winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1); xFrames = x(winIndex) .* win; yFrames = y(winIndex) .* win; lpcX = realLpc(xFrames, width, order); [lpcY,realY] = realLpc(yFrames, width, order); llr = zeros(nFrames, 1); for n = 1 : nFrames R = toeplitz(realY(1:order+1,n)); num = lpcX(:,n)'*R*lpcX(:,n); den = lpcY(:,n)'*R*lpcY(:,n); llr(n) = log(num/den); end llr = sort(llr); llr = llr(1:ceil(nFrames*0.95)); llr = max(min(llr,2),0); meanLlr = mean(llr); medianLlr = median(llr); end
function [lpcCoeffs, realX] = realLpc(xFrames, width, order) width2p = 2 ^ nextpow2(width); X = fft(xFrames, width2p); Rx = ifft(abs(X).^2); Rx = Rx./width; realX = real(Rx); lpcX = levinson(realX, order); lpcCoeffs = real(lpcX'); end