Речь Dereverberate Используя нейронные сети для глубокого обучения

В этом примере показано, как обучить полностью сверточную сеть (FCN) U-Net [1] к dereverberate, речь сигнализирует.

Введение

Реверберация происходит, когда речевой сигнал отражается от объектов на пробеле, заставляя несколько отражений расти и в конечном счете приводит к ухудшению речевого качества. Dereverberation является процессом сокращения эффектов реверберации в речевом сигнале.

Речевой сигнал Dereverberate Используя предварительно обученную сеть

Перед входом в учебный процесс подробно, используйте предварительно обученную сеть для dereverberate речевой сигнал.

Загрузите предварительно обученную сеть. Эта сеть была обучена на версиях с 56 динамиками обучающих наборов данных. Пример идет посредством обучения на версии с 28 динамиками.

url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip';
downloadFolder = tempdir;
networkDataFolder = fullfile(downloadFolder,'derevernet');

if ~exist(networkDataFolder,'dir')
    disp('Downloading pretrained network ...')
    unzip(url,downloadFolder)
end
load(fullfile(networkDataFolder,'dereverbNet.mat'))

Слушайте чистый речевой сигнал, произведенный на уровне 16 кГц.

[cleanAudio,fs] = audioread('clean_speech_signal.wav');

sound(cleanAudio,fs)

Акустический путь может быть смоделирован с помощью импульсной характеристики помещения. Можно смоделировать реверберацию путем свертки к безэховому сигналу с импульсной характеристикой помещения.

Загрузите и постройте импульсную характеристику помещения.

[rirAudio,fsR] = audioread('room_impulse_response.wav');

tAxis = (1/fsR)*(0:numel(rirAudio)-1);

figure
plot(tAxis,rirAudio)
xlabel('Time (s)')
ylabel('Amplitude')
grid on

Примените операцию свертки к чистой речи с импульсной характеристикой помещения, чтобы получить, отразился речь. Выровняйте длины и амплитуды отраженных и чистых речевых сигналов.

revAudio = conv(cleanAudio,rirAudio);

revAudio = revAudio(1:numel(cleanAudio));
revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));

Слушайте отраженный речевой сигнал.

sound(revAudio,fs)

Вход к предварительно обученной сети является кратковременным преобразованием Фурье (STFT) логарифмической величины отражающего аудио. Сеть предсказывает логарифмическую величину STFT входа dereverberated. Чтобы оценить исходный звуковой сигнал временного интервала, вы выполняете обратный STFT и принимаете фазу отражающего аудио.

Используйте следующие параметры, чтобы вычислить STFT.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;

Используйте stft вычислить одностороннюю логарифмическую величину STFT. Используйте одинарную точность при вычислении функций, чтобы лучше использовать использование памяти и ускорить обучение. Даже при том, что односторонний STFT дает к 257 интервалам частоты, рассмотрите только 256 интервалов и проигнорируйте самый высокий интервал частоты.

revAudio = single(revAudio);    
audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                'FFTLength',params.FFTLength,'FrequencyRange','onesided'); 
Eps = realmin('single');
reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);

Извлеките фазу STFT.

phaseOriginal = angle(audioSTFT(1:end-1,:));

Каждый вход будет иметь размерности 256 256 (интервалы частоты временными шагами). Разделите логарифмическую величину STFT в сегменты 256 тактов.

params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
    [params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);

Масштабируйте сегментированные функции к области значений [-1,1]. Сохраните минимальные и максимальные значения, используемые, чтобы масштабироваться для восстановления сигнала dereverberated.

minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments));
maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments));

featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ...
    reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);

Измените функции так, чтобы фрагменты приехали четвертая размерность.

featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);

Предскажите спектры логарифмической величины отраженного речевого сигнала использование предварительно обученной сети.

predictedSTFT4D = predict(dereverbNet,featNorm);

Изменитесь к 3 размерностям и масштабируйте предсказанный STFTs к исходной области значений с помощью сохраненных минимально-максимальных пар.

predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))';
featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ...
    predictedSTFT,minVals,maxVals,'UniformOutput',false);

Инвертируйте масштабирование журнала.

predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);

Конкатенация предсказанного 256 256 сегменты STFT величины, чтобы получить спектрограмму величины исходной длины.

predictedSTFTAll = predictedSTFT(1:chunks - 1);
predictedSTFTAll = cat(2,predictedSTFTAll{:});
predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};

Прежде, чем взять обратный STFT, добавьте нули к предсказанному спектру логарифмической величины и фазе вместо самого высокого интервала частоты, который был исключен, когда вход подготовки показывает.

nCount = size(predictedSTFTAll,3);
predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount));
phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));

Используйте обратную функцию STFT, чтобы восстановить dereverberated речевой сигнал временного интервала использование предсказанной логарифмической величины STFT и фаза отражающего речевого сигнала.

oneSidedSTFT = predictedSTFTAll.*exp(1j*phase);
dereverbedAudio = istft(oneSidedSTFT, ...
    'Window',params.Window,'OverlapLength',params.OverlapLength, ...
    'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ...
    'FrequencyRange','onesided');

dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio]));
dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];

Слушайте dereverberated звуковой сигнал.

sound(dereverbedAudio,fs)

Постройте чистые, отражающие, и dereverberated речевые сигналы.

t = (1/fs)*(0:numel(cleanAudio)-1);

figure

subplot(3,1,1)
plot(t,cleanAudio)
xlabel('Time (s)')
grid on
subtitle('Clean Speech Signal')

subplot(3,1,2)
plot(t,revAudio)
xlabel('Time (s)')
grid on
subtitle('Revereberated Speech Signal')

subplot(3,1,3)
plot(t,dereverbedAudio)
xlabel('Time (s)')
grid on
subtitle('Derevereberated Speech Signal')

Визуализируйте спектрограммы чистых, отражающих, и dereverberated речевых сигналов.

figure('Position',[100,100,800,800])

subplot(3,1,1)
spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');
subtitle('Clean')

subplot(3,1,2)
spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Reverberated')
 
subplot(3,1,3)
spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Predicted (Dereverberated)')

Загрузите набор данных

Этот пример использует Отражающую Речевую Базу данных [2] и соответствующую Чистую Речевую Базу данных [3], чтобы обучить сеть.

Загрузите чистый речевой набор данных.

url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip';
url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip';
downloadFolder = tempdir;
cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791');

if ~exist(cleanDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url1,cleanDataFolder)
    unzip(url2,cleanDataFolder)
end

Загрузите отраженный речевой набор данных.

url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip';
url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip';
downloadFolder = tempdir;
reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031');

if ~exist(reverbDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url3,reverbDataFolder)
    unzip(url4,reverbDataFolder)
end

Предварительная обработка данных и извлечение признаков

Если данные загружаются, предварительно обработайте загруженные данные и извлеките функции перед обучением модель DNN:

  1. Искусственно сгенерируйте отражающие данные с помощью reverberator объект

  2. Разделите каждый речевой сигнал в маленькие сегменты 2,072 длительности с

  3. Отбросьте сегменты, которые содержат значительные тихие области

  4. Извлеките логарифмическую величину STFTs как предиктор и предназначайтесь для функций

  5. Масштабируйте и измените функции

Во-первых, создайте два audioDatastore объекты, которые указывают на чистые и отражающие речевые наборы данных.

adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true);
adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);

Синтетическая отражающая речевая генерация данных

Количество реверберации в исходных данных относительно мало. Вы увеличите отражающие речевые данные со значительными эффектами реверберации с помощью reverberator объект.

Создайте audioDatastore это указывает на чистый речевой набор данных, выделенный для синтетической отражающей генерации данных.

adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files));
adsCleanTrain = subset(adsCleanTrain,1:10e3);
adsReverbTrain = subset(adsReverbTrain,1:10e3);

Передискретизируйте от 48 кГц до 16 кГц.

adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3));
adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3));
adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));

Объедините два аудио хранилища данных, обеспечив соответствие между чистыми и отражающими речевыми выборками.

adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);

Функция applyReverb создает reverberator объект, обновляется пред задержка, фактор затухания и влажно-сухие параметры соединения, как задано, и затем применяет реверберацию. Используйте audioDataAugmenter создать искусственно сгенерированные отражающие данные.

augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ...
    'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0);
algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ...
    applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate);

addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ...
    'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ...
    'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]})

augmenter.ReverbProbability = 1;
disp(augmenter)
  audioDataAugmenter with properties:

               AugmentationMode: 'independent'
    AugmentationParameterSource: 'random'
               NumAugmentations: 1
               ApplyTimeStretch: 0
                ApplyPitchShift: 0
             ApplyVolumeControl: 0
                  ApplyAddNoise: 0
                 ApplyTimeShift: 0
                    ApplyReverb: 1
                  PreDelayRange: [0.1500 0.2500]
               DecayFactorRange: [0.2000 0.5000]
                 WetDryMixRange: [0.3000 0.4500]
              SamplingRateRange: [16000 16000]

Создайте новый audioDatastore соответствие искусственно сгенерировало отражающие данные путем вызова transform применять увеличение данных.

adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));

Объедините два аудио хранилища данных.

adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);

Затем на основе размерностей входных функций к сети, сегментируйте аудио на фрагменты 2,072 длительности с с перекрытием 50%.

Наличие слишком многих тихих сегментов может оказать негативное влияние на обучение модели DNN. Удалите сегменты, которые в основном тихи (больше чем 50% длительности) и исключают тех из обучения модели. Не полностью удаляйте тишину, потому что модель не будет устойчива в тихие области, и небольшие эффекты реверберации могли быть идентифицированы как тишина. detectSpeech может идентифицировать начальные и конечные точки тихих областей. После этих двух шагов процесс извлечения признаков может быть выполнен, как объяснено в первом разделе. helperFeatureExtract реализует эти шаги.

Задайте параметры извлечения признаков. Установкой reduceDataSet к истине вы выбираете небольшое подмножество наборов данных, чтобы выполнить последующие шаги.

reduceDataSet = true;
params.fs = 16000;
params.WindowdowLength = 512;
params.Window = hamming (params.WindowdowLength,"periodic");
params.OverlapLength = раунд (0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
samplesPerMs = params.fs/1000;
params.samplesPerImage = (24+256*8) *samplesPerMs;
params.shiftImage = params.samplesPerImage/2;
params.NumSegments = 256;
params.NumFeatures = 256
params = struct with fields:
    WindowdowLength: 512
             Window: [512×1 double]
      OverlapLength: 384
          FFTLength: 512
        NumSegments: 256
        NumFeatures: 256
                 fs: 16000
    samplesPerImage: 33152
         shiftImage: 16576

Чтобы ускорить обработку, распределите задачу предварительной обработки и извлечения признаков на нескольких рабочих, использующих parfor.

Определите количество разделов для набора данных. Если у вас нет Parallel Computing Toolbox™, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(adsCombinedTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Для каждого раздела читайте из datastore, предварительно обработайте звуковой сигнал, и затем извлеките функции.

if reduceDataSet
    adsCombinedTrain = shuffle(adsCombinedTrain); %#ok
    adsCombinedTrain = subset(adsCombinedTrain,1:200);
    
    adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain);
    adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200);
end

allCleanFeatures = cell(1,numPar);
allReverbFeatures = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedTrain,numPar,iPartition);
    combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition);
        
    cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    partitionSize = cPartitionSize + cSyntheticPartitionSize;
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    
    for idx = 1:partitionSize
        if idx <= cPartitionSize
            audios = read(combinedPartition);
        else
            audios = read(combinedSyntheticPartition);
        end
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params);
        cleanFeaturesPartition{idx} = featuresClean;
        reverbFeaturesPartition{idx} = featuresReverb;
    end
    allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
end

allCleanFeatures = cat(2,allCleanFeatures{:});
allReverbFeatures = cat(2,allReverbFeatures{:});

Нормируйте извлеченные функции к области значений [-1,1] и затем изменитесь, как объяснено в первом разделе, с помощью функции featureNormalizeAndReshape.

trainClean = featureNormalizeAndReshape(allCleanFeatures);
trainReverb = featureNormalizeAndReshape(allReverbFeatures);

Теперь, когда вы извлекли логарифмическую величину функции STFT из обучающих наборов данных, выполните ту же процедуру, чтобы извлечь функции из наборов данных валидации. В целях реконструкции сохраните фазу отражающих речевых выборок набора данных валидации. Кроме того, сохраните аудиоданные и для чистых и для отражающих речевых выборок в наборе валидации, чтобы использовать в процессе оценки (следующий раздел).

adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true);
adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);

Передискретизируйте от 48 кГц до 16 кГц.

adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3));
adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3));

adsCombinedVal = combine(adsCleanVal,adsReverbVal); 
if reduceDataSet
    adsCombinedVal = shuffle(adsCombinedVal);%#ok
    adsCombinedVal = subset(adsCombinedVal,1:50);
end

allValCleanFeatures = cell(1,numPar);
allValReverbFeatures = cell(1,numPar);
allValReverbPhase = cell(1,numPar);
allValCleanAudios = cell(1,numPar);
allValReverbAudios = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedVal,numPar,iPartition);
    
    partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    reverbPhasePartition = cell(1,partitionSize); 
    cleanAudiosPartition = cell(1,partitionSize); 
    reverbAudiosPartition = cell(1,partitionSize);

    for idx = 1:partitionSize
        audios = read(combinedPartition);
        
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        
        [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params);
        
        cleanFeaturesPartition{idx} = a;
        reverbFeaturesPartition{idx} = b;  
        reverbPhasePartition{idx} = c;
        cleanAudiosPartition{idx} = d;
        reverbAudiosPartition{idx} = e;
    end
    allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
    allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:});
    allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:});
    allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:});
end

allValCleanFeatures = cat(2,allValCleanFeatures{:});
allValReverbFeatures = cat(2,allValReverbFeatures{:});
allValReverbPhase = cat(2,allValReverbPhase{:});
allValCleanAudios = cat(2,allValCleanAudios{:});
allValReverbAudios = cat(2,allValReverbAudios{:});

valClean = featureNormalizeAndReshape(allValCleanFeatures);

Сохраните минимальные и максимальные значения каждой функции отражающего набора валидации. Вы будете использовать эти значения в процессе реконструкции.

[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);

Задайте архитектуру нейронной сети

Полностью сверточная сетевая архитектура по имени U-Net была адаптирована к этой речи dereverberation задача, как предложено в [1]. "U-Net" является сетью декодера энкодера со связями пропуска. В U-сетевой-модели каждый слой прореживает свой вход (шаг 2), пока слой узкого места не достигнут (путь к кодированию). В последующих слоях вход сверхдискретизирован каждым слоем, пока выходной параметр не возвращен к исходной форме (декодирующий путь). Чтобы минимизировать потерю низкоуровневой информации во время процесса субдискретизации, связи установлены между зеркальными слоями путем прямой конкатенации выходных параметров соответствующих слоев (связи пропуска).

Задайте сетевую архитектуру и возвратите график слоев со связями.

params.WindowdowLength = 512;
params.FFTLength = params.WindowdowLength;
params.NumFeatures = params.FFTLength/2;
params.NumSegments = 256;
    
filterH = 6;
filterW = 6;
numChannels = 1;
nFilters = [64,128,256,512,512,512,512,512];

inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ...
    'Normalization','none','Name','input');
layers = inputLayer;

% U-Net squeezing path
layers = [layers;
    convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1));
    leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))];
        
for i = 2:8
    layers =  [layers;
        convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i));
        batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok
    if i ~= 8
        layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok
    else
        layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok
    end
end

% U-Net expanding path
for i = 7:-1:0
    nChannels = numChannels;
    if i > 0
        nChannels = nFilters(i);
    end
    layers = [layers;
        transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok
    if i > 0
        layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok
    end
    if i > 4
        layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok
    end
    if i > 0
        layers = [layers;
            reluLayer('Name',"de-relu"+string(i));
            concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok
    else
        layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok
    end
end

layers = [layers;regressionLayer('Name','output')];

unetLayerGraph = layerGraph(layers); 

% Define skip-connections
for i = 1:7
    unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2');
end

Используйте analyzeNetwork просмотреть архитектуру модели. Это - хороший способ визуализировать связи между слоями.

analyzeNetwork(unetLayerGraph); 

Обучите сеть

Вы будете использовать среднеквадратическую ошибку (MSE) между спектрами логарифмической величины dereverberated речевой выборки (выход модели) и соответствующей чистой речевой выборки (цель) как функция потерь. Используйте adam оптимизатор и мини-пакетный размер 128 для обучения. Позвольте модели обучаться максимум для 50 эпох. Если потеря валидации не улучшается в течение 5 эпох подряд, отключает учебный процесс. Уменьшайте скорость обучения на коэффициент 10 каждых 15 эпох.

Задайте опции обучения как ниже. Измените среду выполнения и выполнить ли фоновую диспетчеризацию в зависимости от вашей аппаратной доступности и есть ли у вас доступ к Parallel Computing Toolbox™.

initialLearnRate = 8e-4;
miniBatchSize = 64;

options = trainingOptions("adam", ...
        "MaxEpochs", 50, ...
        "InitialLearnRate",initialLearnRate, ...
        "MiniBatchSize",miniBatchSize, ...
        "Shuffle","every-epoch", ...
        "Plots","training-progress", ...
        "Verbose",false, ...
        "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ...
        "ValidationPatience",5, ...
        "LearnRateSchedule","piecewise", ...
        "LearnRateDropFactor",0.1, ... 
        "LearnRateDropPeriod",15, ...
        "ExecutionEnvironment","gpu", ...
        "DispatchInBackground",true, ...
        "ValidationData",{valReverb,valClean});

Обучите сеть.

dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);

Оцените производительность сети

Предсказание и реконструкция

Предскажите спектры логарифмической величины набора валидации.

predictedSTFT4D = predict(dereverbNet,valReverb);

Используйте функцию helperReconstructPredictedAudios, чтобы восстановить предсказанную речь. Эта функция выполняет действия, обрисованные в общих чертах в первом разделе.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;

dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);

Визуализируйте логарифмическую величину STFTs чистых, отражающих, и соответствующих dereverberated речевых сигналов.

figure('Position',[100,100,1024,1200])

subplot(3,1,1)
imagesc(squeeze(allValCleanFeatures{1}))    
set(gca,'Ydir','normal')
subtitle('Clean')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,2)
imagesc(squeeze(allValReverbFeatures{1}))
set(gca,'Ydir','normal')
subtitle('Reverberated')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,3)
imagesc(squeeze(predictedSTFT4D(:,:,:,1)))
set(gca,'Ydir','normal')
subtitle('Predicted (Dereverberated)')
xlabel('Time')
ylabel('Frequency')
caxis([-1,1])
colorbar

Метрики оценки

Вы будете использовать подмножество объективных мер, используемых в [1], чтобы оценить эффективность сети. Эти метрики вычисляются на сигналах временной области.

  • Расстояние кепстра (CD) - Обеспечивает оценку журнала спектральное расстояние между двумя спектрами (предсказанный и чистый). Меньшие значения указывают на лучшее качество.

  • Логарифмическое отношение правдоподобия (LLR) - основанное на линейном предсказательном кодировании (LPC) объективное измерение. Меньшие значения указывают на лучшее качество.

Вычислите эти измерения для отражающей речи и dereverberated речевых сигналов.

[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs);
[summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs);
disp(summaryMeasuresReconstructed)
       avgCdMean: 3.8386
     avgCdMedian: 3.3671
      avgLlrMean: 0.9152
    avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
       avgCdMean: 4.2591
     avgCdMedian: 3.6336
      avgLlrMean: 0.9726
    avgLlrMedian: 0.8714

Гистограммы иллюстрируют распределение среднего CD, означают SRMR и означают LLR отражающих и dereverberated данных.

figure('position',[50,50,1100,1300])

subplot(2,1,1)
histogram(allMeasuresReverb.cdMean,10)
hold on
histogram(allMeasuresReconstructed.cdMean, 10)
subtitle('Mean Cepstral Distance Distribution')
ylabel('count')
xlabel('mean CD')
legend('Reverberant (Original)','Dereverberated (Predicted)')

subplot(2,1,2)
histogram(allMeasuresReverb.llrMean,10)
hold on
histogram(allMeasuresReconstructed.llrMean,10)
subtitle('Mean Log Likelihood Ratio Distribution')
ylabel('Count')
xlabel('Mean LLR')
legend('Reverberant (Original)','Dereverberated (Predicted)')

Ссылки

[1] Эрнст, O., Chazan, S.E., Gannot, S., & Голдбергер, J. (2018). Речь Dereverberation Используя полностью Сверточные сети. 2 018 26-х европейских конференций по обработке сигналов (EUSIPCO), 390-394.

[2] https://datashare.is.ed.ac.uk/handle/10283/2031

[3] https://datashare.is.ed.ac.uk/handle/10283/2791

[4] https://github.com/MuSAELab/SRMRToolbox

Вспомогательные Функции

Примените реверберацию

function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs)
% This function generates reverberant speech data using the reverberator
% object
%
% inputs: 
% y                                - clean speech sample
% preDelay, decayFactor, wetDryMix - reverberation parameters
% fs                               - sampling rate of y
%
% outputs:
% yOut - corresponding reveberated speech sample

    revObj = reverberator('SampleRate',fs, ...
        'DecayFactor',decayFactor, ...
        'WetDryMix',wetDryMix, ...
        'PreDelay',preDelay);
    yOut = revObj(y);
    yOut = yOut(1:length(y),1);
end

Извлеките пакет функций

function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ...
    = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params)
% This function performs the preprocessing and features extraction task on 
% the audio files used for dereverberation model training and testing.
%
% inputs:
% cleanAudio  - the clean audio file (reference)
% reverbAudio - corresponding reverberant speech file
% isVal       - Boolean flag indicating if it is the validation set
% params      - a structure containing feature extraction parameters
%
% outputs:
% featuresClean  - log-magnitude STFT features of clean audio
% featuresReverb - log-magnitude STFT features of reverberant audio
% phaseReverb    - phase of STFT of reverberant audio
% cleanAudios    - 2.072s-segments of clean audio file used for feature extraction
% reverbAudios   - 2.072s-segments of corresponding reverberant audio

    assert(length(cleanAudio) == length(reverbAudio));
    nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage);

    featuresClean = {};
    featuresReverb = {};
    phaseReverb = {};
    cleanAudios = {};
    reverbAudios = {};
    nGood = 0;
    nonSilentRegions = detectSpeech(reverbAudio, params.fs);
    nonSilentRegionIdx = 1;
    totalRegions = size(nonSilentRegions, 1);
    for cid = 1:nSegments
        start = (cid - 1)*params.shiftImage + 1;
        en = start + params.samplesPerImage - 1;

        nonSilentSamples = 0;
        while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start
            nonSilentRegionIdx = nonSilentRegionIdx + 1;
        end
        
        nonSilentStart = nonSilentRegionIdx;
        while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en
            nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1;
            nonSilentSamples = nonSilentSamples + nonSilentDuration; 
            nonSilentStart = nonSilentStart + 1;
        end
        
        nonSilentPerc = nonSilentSamples * 100 / (en - start + 1);
        silent = nonSilentPerc < 50;
        
        reverbAudioSegment = reverbAudio(start:en);
        if ~silent
            nGood = nGood + 1;
            cleanAudioSegment = cleanAudio(start:en);
            assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking')
            
            % Clean Audio
            [featsUnit, ~] = featureExtract(cleanAudioSegment, params);
            featuresClean{nGood} = featsUnit; %#ok

            % Reverb Audio
            [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params);
            featuresReverb{nGood} = featsUnit; %#ok
            if isVal
                phaseReverb{nGood} = phaseUnit; %#ok
                reverbAudios{nGood} = reverbAudioSegment;%#ok
                cleanAudios{nGood} = cleanAudioSegment;%#ok
            end
        end
    end
end

Выделить признаки

function [features, phase, lastFBin] = featureExtract(audio, params)
% Function to extract features for a speech file
    audio = single(audio);
    
    audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                    'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided');
    
    phase = single(angle(audioSTFT(1:end-1,:)));     
    features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); 
    lastFBin = audioSTFT(end,:);

end

Нормируйте и измените функции

function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats)
% function to normalize features - range [-1, 1] and reshape to 4
% dimensions
%
% inputs:
% feats - 3-dimensional array of extracted features
%
% outputs:
% featNorm - normalized and reshaped features
% minMaxPairs - array of original min and max pairs used for normalization

    nSamples = length(feats);
    minMaxPairs = zeros(nSamples,2,'single');
    featNorm = zeros([size(feats{1}),nSamples],'single');
    parfor i = 1:nSamples
        feat = feats{i};
        maxFeat = max(feat,[],'all');
        minFeat = min(feat,[],'all');
        featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1;
        minMaxPairs(i,:) = [minFeat,maxFeat];
    end
    featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3));
end

Восстановите предсказанное аудио

function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params)
% This function will reconstruct the 2.072s long audios predicted by the 
% model using the predicted log-magnitude spectrogram and the phase of the 
% reverberant audio file
%
% inputs:
% predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features
% minMaxPairs     - Original minimum/maximum value pairs used in normalization
% reverbPhase     - Array of phases of STFT of reverberant audio files
% reverbAudios    - 2.072s-segments of corresponding reverberant audios
% params          - Structure containing feature extraction parameters

    predictedSTFT = squeeze(predictedSTFT4D);
    denormalizedFeatures = zeros(size(predictedSTFT),'single');
    for i = 1:size(predictedSTFT,3)
        feat = predictedSTFT(:,:,i);
        maxFeat = minMaxPairs(i,2);
        minFeat = minMaxPairs(i,1);
        denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat;        
    end
    
    predictedSTFT = exp(denormalizedFeatures);
    
    nCount = size(predictedSTFT,3);
    dereverbedAudioAll = cell(1,nCount);

    nSeg = params.NumSegments;
    win = params.Window;
    ovrlp = params.OverlapLength;
    FFTLength = params.FFTLength;
    parfor ii = 1:nCount
        % Append zeros to the highest frequency bin
        stftUnit = predictedSTFT(:,:,ii);
        stftUnit = cat(1,stftUnit, zeros(1,nSeg)); 
        phase = reverbPhase{ii};
        phase = cat(1,phase,zeros(1,nSeg));

        oneSidedSTFT = stftUnit.*exp(1j*phase);
        dereverbedAudio= istft(oneSidedSTFT, ...
            'Window', win,'OverlapLength', ovrlp, ...
                                    'FFTLength',FFTLength,'ConjugateSymmetric',true,...
                                    'FrequencyRange','onesided');

        dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii})));
    end
end

Вычислите объективные меры

function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs)
% This function computes the objective measures on time-domain signals.
%
% inputs:
% reconstructedAudios - An array of audio files to evaluate.
% cleanAudios - An array of reference audio files
% fs - Sampling rate of audio files
%
% outputs:
% summaryMeasures - Global means of CD, LLR individual mean and median values
% allMeasures - Individual mean and median values

    nAudios = length(reconstructedAudios);
    cdMean = zeros(nAudios,1);
    cdMedian = zeros(nAudios,1);
    llrMean = zeros(nAudios,1);
    llrMedian = zeros(nAudios,1);

    parfor k = 1 : nAudios
      y = reconstructedAudios{k};
      x = cleanAudios{k};

      y = y./max(abs(y));
      x = x./max(abs(x));

      [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs);
      [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs);
    end
    
    summaryMeasures.avgCdMean = mean(cdMean);
    summaryMeasures.avgCdMedian = mean(cdMedian);
    summaryMeasures.avgLlrMean = mean(llrMean);
    summaryMeasures.avgLlrMedian = mean(llrMedian);   
    
    allMeasures.cdMean = cdMean;
    allMeasures.llrMean = llrMean;
end

Расстояние Cepstral

function [meanVal, medianVal] = cepstralDistance(x,y,fs)
    x = x / sqrt(sum(x.^2));
    y = y / sqrt(sum(y.^2));

    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex).*win;
    yFrames = y(winIndex).*win;

    xCeps = cepstralReal(xFrames,width);
    yCeps = cepstralReal(yFrames,width);

    dist = (xCeps - yCeps).^2;
    cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:));
    cepsD = max(min(cepsD,10),0);

    meanVal = mean(cepsD);
    medianVal = median(cepsD);
end

Действительный кепстр

function realC = cepstralReal(x, width)
    width2p = 2 ^ nextpow2(width);
    powX = abs(fft(x, width2p));

    lowCutoff = max(powX(:)) * 10^-5;
    powX  = max(powX, lowCutoff);

    realC = real(ifft(log(powX)));
    order = 24;
    realC = realC(1 : order + 1, :);
    realC = realC - mean(realC, 2);
end

Отношение логарифмической правдоподобности LPC

function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs)
    order = 12;
    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames  = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex) .* win;
    yFrames = y(winIndex) .* win;

    lpcX = realLpc(xFrames, width, order);
    [lpcY,realY] = realLpc(yFrames, width, order);

    llr = zeros(nFrames, 1);
    for n = 1 : nFrames
      R = toeplitz(realY(1:order+1,n));
      num = lpcX(:,n)'*R*lpcX(:,n);
      den = lpcY(:,n)'*R*lpcY(:,n);  
      llr(n) = log(num/den);
    end

    llr = sort(llr);
    llr = llr(1:ceil(nFrames*0.95));
    llr = max(min(llr,2),0);

    meanLlr = mean(llr);
    medianLlr = median(llr);
end

Действительные линейные коэффициенты Prection

function [lpcCoeffs, realX] = realLpc(xFrames, width, order)
    width2p = 2 ^ nextpow2(width);
    X = fft(xFrames, width2p);

    Rx = ifft(abs(X).^2);
    Rx = Rx./width; 
    realX = real(Rx);

    lpcX = levinson(realX, order);
    lpcCoeffs = real(lpcX');
end