exponenta event banner

Отказ от речи с использованием сетей глубокого обучения

В этом примере показано, как обучить полностью сверточную сеть U-Net (FCN) [1] отключать речевые сигналы.

Введение

Реверберация происходит, когда речевой сигнал отражается от объектов в пространстве, вызывая нарастание множества отражений и в конечном итоге приводит к ухудшению качества речи. Дереверберация - это процесс уменьшения эффектов реверберации в речевом сигнале.

Отказ от речевого сигнала с использованием предварительно обученной сети

Перед подробным началом процесса обучения используйте предварительно обученную сеть для удаления речевого сигнала.

Загрузите предварительно обученную сеть. Эта сеть была обучена на 56-спикерских версиях учебных наборов данных. Пример состоит в обучении 28-громкоговорящей версии.

url = 'https://ssd.mathworks.com/supportfiles/audio/dereverbnet.zip';
downloadFolder = tempdir;
networkDataFolder = fullfile(downloadFolder,'derevernet');

if ~exist(networkDataFolder,'dir')
    disp('Downloading pretrained network ...')
    unzip(url,downloadFolder)
end
load(fullfile(networkDataFolder,'dereverbNet.mat'))

Прослушивание чистого речевого сигнала, дискретизированного на частоте 16 кГц.

[cleanAudio,fs] = audioread('clean_speech_signal.wav');

sound(cleanAudio,fs)

Акустический путь может быть смоделирован с использованием импульсной характеристики помещения. Можно моделировать реверберацию, свернув анехойный сигнал с импульсной характеристикой помещения.

Загрузите и постройте график импульсной характеристики помещения.

[rirAudio,fsR] = audioread('room_impulse_response.wav');

tAxis = (1/fsR)*(0:numel(rirAudio)-1);

figure
plot(tAxis,rirAudio)
xlabel('Time (s)')
ylabel('Amplitude')
grid on

Сверните чистую речь с импульсной характеристикой комнаты для получения ревербрированной речи. Выровняйте длины и амплитуды реверберных и чистых речевых сигналов.

revAudio = conv(cleanAudio,rirAudio);

revAudio = revAudio(1:numel(cleanAudio));
revAudio = revAudio.*(max(abs(cleanAudio))/max(abs(revAudio)));

Прослушать реверберный речевой сигнал.

sound(revAudio,fs)

Вход в предварительно обученную сеть представляет собой логарифмическое кратковременное преобразование Фурье (STFT) реверсивного звука. Сеть прогнозирует логарифмическую величину STFT деинвертированного входного сигнала. Для оценки исходного звукового сигнала во временной области выполняется обратный STFT и принимается фаза реверберативного звука.

Используйте следующие параметры для вычисления STFT.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;

Использовать stft для вычисления односторонней логарифмической величины STFT. Используйте одну точность при вычислении функций, чтобы лучше использовать память и ускорить обучение. Несмотря на то, что односторонний STFT дает 257 частотные ячейки, следует учитывать только 256 ячейки и игнорировать самую высокую частотную ячейку.

revAudio = single(revAudio);    
audioSTFT = stft(revAudio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                'FFTLength',params.FFTLength,'FrequencyRange','onesided'); 
Eps = realmin('single');
reverbFeats = log(abs(audioSTFT(1:end-1,:)) + Eps);

Извлеките фазу STFT.

phaseOriginal = angle(audioSTFT(1:end-1,:));

Каждый вход будет иметь размеры 256 на 256 (частотные ячейки по временным шагам). Разбейте логарифмическую величину STFT на сегменты 256 временных шагов.

params.NumSegments = 256;
params.NumFeatures = 256;
totalFrames = size(reverbFeats,2);
chunks = ceil(totalFrames/params.NumSegments);
reverbSTFTSegments = mat2cell(reverbFeats,params.NumFeatures, ...
    [params.NumSegments*ones(1,chunks - 1),(totalFrames - (chunks-1)*params.NumSegments)]);
reverbSTFTSegments{chunks} = reverbFeats(:,end-params.NumSegments + 1:end);

Масштабируйте сегментированные элементы до диапазона [-1,1]. Сохранение минимального и максимального значений, используемых для масштабирования для восстановления сигнала с пониженным уровнем мощности.

minVals = num2cell(cellfun(@(x)min(x,[],'all'),reverbSTFTSegments));
maxVals = num2cell(cellfun(@(x)max(x,[],'all'),reverbSTFTSegments));

featNorm = cellfun(@(feat,minFeat,maxFeat)2.*(feat - minFeat)./(maxFeat - minFeat) - 1, ...
    reverbSTFTSegments,minVals,maxVals,'UniformOutput',false);

Измените форму элементов так, чтобы порции располагались вдоль четвертого размера.

featNorm = reshape(cell2mat(featNorm),params.NumFeatures,params.NumSegments,1,chunks);

Предсказать спектры логарифмической величины реверберного речевого сигнала с использованием предварительно обученной сети.

predictedSTFT4D = predict(dereverbNet,featNorm);

Измените форму на 3 размера и масштабируйте прогнозируемые STFT до исходного диапазона с помощью сохраненных пар «минимум-максимум».

predictedSTFT = squeeze(mat2cell(predictedSTFT4D,params.NumFeatures,params.NumSegments,1,ones(1,chunks)))';
featDeNorm = cellfun(@(feat,minFeat,maxFeat) (feat + 1).*(maxFeat-minFeat)./2 + minFeat, ...
    predictedSTFT,minVals,maxVals,'UniformOutput',false);

Сторнируйте масштабирование журнала.

predictedSTFT = cellfun(@exp,featDeNorm,'UniformOutput',false);

Соединяют предсказанные STFT-сегменты 256 на 256 величины для получения спектрограммы величины исходной длины.

predictedSTFTAll = predictedSTFT(1:chunks - 1);
predictedSTFTAll = cat(2,predictedSTFTAll{:});
predictedSTFTAll(:,totalFrames - params.NumSegments + 1:totalFrames) = predictedSTFT{chunks};

Перед принятием обратного STFT добавьте нули к прогнозируемому спектру логарифмических величин и фазе вместо бина с самой высокой частотой, который был исключен при подготовке входных признаков.

nCount = size(predictedSTFTAll,3);
predictedSTFTAll = cat(1,predictedSTFTAll,zeros(1,totalFrames,nCount));
phase = cat(1,phaseOriginal,zeros(1,totalFrames,nCount));

Используйте обратную функцию STFT для восстановления уменьшенного речевого сигнала во временной области с использованием предсказанной логарифмической величины STFT и фазы реверберного речевого сигнала.

oneSidedSTFT = predictedSTFTAll.*exp(1j*phase);
dereverbedAudio = istft(oneSidedSTFT, ...
    'Window',params.Window,'OverlapLength',params.OverlapLength, ...
    'FFTLength',params.FFTLength,'ConjugateSymmetric',true, ...
    'FrequencyRange','onesided');

dereverbedAudio = dereverbedAudio./max(abs([dereverbedAudio;revAudio]));
dereverbedAudio = [dereverbedAudio;zeros(length(revAudio) - numel(dereverbedAudio), 1)];

Прослушать аудиосигнал с пониженной скоростью.

sound(dereverbedAudio,fs)

Постройте график чистых, реверберативных и очищенных речевых сигналов.

t = (1/fs)*(0:numel(cleanAudio)-1);

figure

subplot(3,1,1)
plot(t,cleanAudio)
xlabel('Time (s)')
grid on
subtitle('Clean Speech Signal')

subplot(3,1,2)
plot(t,revAudio)
xlabel('Time (s)')
grid on
subtitle('Revereberated Speech Signal')

subplot(3,1,3)
plot(t,dereverbedAudio)
xlabel('Time (s)')
grid on
subtitle('Derevereberated Speech Signal')

Визуализируйте спектрограммы чистых, реверберативных и подавленных речевых сигналов.

figure('Position',[100,100,800,800])

subplot(3,1,1)
spectrogram(cleanAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');
subtitle('Clean')

subplot(3,1,2)
spectrogram(revAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Reverberated')
 
subplot(3,1,3)
spectrogram(dereverbedAudio,params.Window,params.OverlapLength,params.FFTLength,fs,'yaxis');  
subtitle('Predicted (Dereverberated)')

Загрузить набор данных

В этом примере для обучения сети используется база данных реверберной речи [2] и соответствующая база данных чистой речи [3].

Загрузите набор чистых речевых данных.

url1 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip';
url2 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip';
downloadFolder = tempdir;
cleanDataFolder = fullfile(downloadFolder,'DS_10283_2791');

if ~exist(cleanDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url1,cleanDataFolder)
    unzip(url2,cleanDataFolder)
end

Загрузите реверсированный набор речевых данных.

url3 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_trainset_28spk_wav.zip';
url4 = 'https://datashare.is.ed.ac.uk/bitstream/handle/10283/2031/reverb_testset_wav.zip';
downloadFolder = tempdir;
reverbDataFolder = fullfile(downloadFolder,'DS_10283_2031');

if ~exist(reverbDataFolder,'dir')
    disp('Downloading data set (6 GB) ...')
    unzip(url3,reverbDataFolder)
    unzip(url4,reverbDataFolder)
end

Предварительная обработка данных и извлечение признаков

После загрузки данных выполните предварительную обработку загруженных данных и извлеките функции перед обучением модели DNN:

  1. Синтетически генерировать реверберационные данные с помощью reverberator объект

  2. Разбить каждый речевой сигнал на небольшие сегменты длительностью 2,072с

  3. Отбрасывать сегменты, содержащие значительные тихие области

  4. Извлечение STFT логарифмической величины в качестве предиктора и целевых функций

  5. Масштабирование и изменение формы элементов

Сначала создайте два audioDatastore объекты, указывающие на чистые и реверберные наборы речевых данных.

adsCleanTrain = audioDatastore(fullfile(cleanDataFolder,'clean_trainset_28spk_wav'),'IncludeSubfolders',true);
adsReverbTrain = audioDatastore(fullfile(reverbDataFolder,'reverb_trainset_28spk_wav'),'IncludeSubfolders',true);

Генерация синтетических реверберативных речевых данных

Величина реверберации в исходных данных относительно мала. Вы будете увеличивать реверберационные речевые данные со значительными эффектами реверберации с помощью reverberator объект.

Создание audioDatastore это указывает на набор данных чистой речи, выделенный для генерации синтетических ревербераторных данных.

adsSyntheticCleanTrain = subset(adsCleanTrain,10e3+1:length(adsCleanTrain.Files));
adsCleanTrain = subset(adsCleanTrain,1:10e3);
adsReverbTrain = subset(adsReverbTrain,1:10e3);

Повторная выборка от 48 кГц до 16 кГц.

adsSyntheticCleanTrain = transform(adsSyntheticCleanTrain,@(x)resample(x,16e3,48e3));
adsCleanTrain = transform(adsCleanTrain,@(x)resample(x,16e3,48e3));
adsReverbTrain = transform(adsReverbTrain,@(x)resample(x,16e3,48e3));

Объедините два хранилища аудиоданных, сохранив соответствие между чистыми и реверберными образцами речи.

adsCombinedTrain = combine(adsCleanTrain,adsReverbTrain);

Функция applyReverb создает reverberator объект обновляет параметры предварительной задержки, коэффициента затухания и влажно-сухой смеси, как указано, а затем применяет реверберацию. Использовать audioDataAugmenter для создания синтетически сгенерированных реверберных данных.

augmenter = audioDataAugmenter('AugmentationMode','independent','NumAugmentations', 1,'ApplyAddNoise',0, ...
    'ApplyTimeStretch',0,'ApplyPitchShift',0,'ApplyVolumeControl',0,'ApplyTimeShift',0);
algorithmHandle = @(y,preDelay,decayFactor,wetDryMix,samplingRate) ...
    applyReverb(y,preDelay,decayFactor,wetDryMix,samplingRate);

addAugmentationMethod(augmenter,'Reverb',algorithmHandle, ...
    'AugmentationParameter',{'PreDelay','DecayFactor','WetDryMix','SamplingRate'}, ...
    'ParameterRange',{[0.15,0.25],[0.2,0.5],[0.3,0.45],[16000,16000]})

augmenter.ReverbProbability = 1;
disp(augmenter)
  audioDataAugmenter with properties:

               AugmentationMode: 'independent'
    AugmentationParameterSource: 'random'
               NumAugmentations: 1
               ApplyTimeStretch: 0
                ApplyPitchShift: 0
             ApplyVolumeControl: 0
                  ApplyAddNoise: 0
                 ApplyTimeShift: 0
                    ApplyReverb: 1
                  PreDelayRange: [0.1500 0.2500]
               DecayFactorRange: [0.2000 0.5000]
                 WetDryMixRange: [0.3000 0.4500]
              SamplingRateRange: [16000 16000]

Создание нового audioDatastore соответствующие синтетически сгенерированным реверберативным данным путем вызова transform для применения увеличения данных.

adsSyntheticReverbTrain = transform(adsSyntheticCleanTrain,@(y)deal(augment(augmenter,y,16e3).Audio{1}));

Объединение двух хранилищ аудиоданных.

adsSyntheticCombinedTrain = combine(adsSyntheticCleanTrain,adsSyntheticReverbTrain);

Далее, исходя из размеров входных функций в сеть, сегментируйте звук на порции длительностью 2,072 с с перекрытием 50%.

Наличие слишком большого количества бесшумных сегментов может отрицательно повлиять на обучение модели DNN. Удалите сегменты, которые в основном являются бесшумными (более 50% продолжительности), и исключите их из обучения модели. Не следует полностью удалять молчание, поскольку модель не будет устойчива к тихим областям, и небольшие эффекты реверберации могут быть идентифицированы как молчание. detectSpeech может идентифицировать начальную и конечную точки областей молчания. После этих двух стадий процесс извлечения признаков может быть выполнен, как описано в первом разделе. helperFeatureExtract реализует эти шаги.

Определите параметры извлечения элемента. По настройке reduceDataSet true, для выполнения последующих шагов выбирается небольшое подмножество наборов данных.

reduceDataSet = true;
params.fs = 16000;
params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
samplesPerMs = params.fs/1000;
params.samplesPerImage = (24+256*8)*samplesPerMs;
params.shiftImage = params.samplesPerImage/2;
params.NumSegments = 256;
params.NumFeatures = 256
params = struct with fields:
    WindowdowLength: 512
             Window: [512×1 double]
      OverlapLength: 384
          FFTLength: 512
        NumSegments: 256
        NumFeatures: 256
                 fs: 16000
    samplesPerImage: 33152
         shiftImage: 16576

Чтобы ускорить обработку, распределите задачу предварительной обработки и извлечения компонентов между несколькими работниками с помощью parfor.

Определите количество разделов для набора данных. Если у вас нет Toolbox™ Parallel Computing, используйте один раздел.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(adsCombinedTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Для каждого раздела считывайте из хранилища данных, предварительно обрабатывайте аудиосигнал и извлекайте функции.

if reduceDataSet
    adsCombinedTrain = shuffle(adsCombinedTrain); %#ok
    adsCombinedTrain = subset(adsCombinedTrain,1:200);
    
    adsSyntheticCombinedTrain = shuffle(adsSyntheticCombinedTrain);
    adsSyntheticCombinedTrain = subset(adsSyntheticCombinedTrain,1:200);
end

allCleanFeatures = cell(1,numPar);
allReverbFeatures = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedTrain,numPar,iPartition);
    combinedSyntheticPartition = partition(adsSyntheticCombinedTrain,numPar,iPartition);
        
    cPartitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    cSyntheticPartitionSize = numel(combinedSyntheticPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    partitionSize = cPartitionSize + cSyntheticPartitionSize;
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    
    for idx = 1:partitionSize
        if idx <= cPartitionSize
            audios = read(combinedPartition);
        else
            audios = read(combinedSyntheticPartition);
        end
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        [featuresClean,featuresReverb] = helperFeatureExtract(cleanAudio,reverbAudio,false,params);
        cleanFeaturesPartition{idx} = featuresClean;
        reverbFeaturesPartition{idx} = featuresReverb;
    end
    allCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
end

allCleanFeatures = cat(2,allCleanFeatures{:});
allReverbFeatures = cat(2,allReverbFeatures{:});

Нормализуйте извлеченные элементы в диапазоне [-1,1], а затем измените форму, как описано в первом разделе, с помощью функции FeatureNormalityAndReshape.

trainClean = featureNormalizeAndReshape(allCleanFeatures);
trainReverb = featureNormalizeAndReshape(allReverbFeatures);

Теперь, когда вы извлекли элементы STFT логарифмической величины из учебных наборов данных, выполните ту же процедуру, чтобы извлечь элементы из наборов данных проверки. Для целей реконструкции сохраните фазу реверберативных речевых выборок набора данных проверки. Кроме того, сохранить аудиоданные как для чистых, так и для реверберных речевых выборок в наборе проверки достоверности, который будет использоваться в процессе оценки (следующий раздел).

adsCleanVal = audioDatastore(fullfile(cleanDataFolder,'clean_testset_wav'),'IncludeSubfolders',true);
adsReverbVal = audioDatastore(fullfile(reverbDataFolder,'reverb_testset_wav'),'IncludeSubfolders',true);

Повторная выборка от 48 кГц до 16 кГц.

adsCleanVal = transform(adsCleanVal,@(x)resample(x,16e3,48e3));
adsReverbVal = transform(adsReverbVal,@(x)resample(x,16e3,48e3));

adsCombinedVal = combine(adsCleanVal,adsReverbVal); 
if reduceDataSet
    adsCombinedVal = shuffle(adsCombinedVal);%#ok
    adsCombinedVal = subset(adsCombinedVal,1:50);
end

allValCleanFeatures = cell(1,numPar);
allValReverbFeatures = cell(1,numPar);
allValReverbPhase = cell(1,numPar);
allValCleanAudios = cell(1,numPar);
allValReverbAudios = cell(1,numPar);

parfor iPartition = 1:numPar
    combinedPartition = partition(adsCombinedVal,numPar,iPartition);
    
    partitionSize = numel(combinedPartition.UnderlyingDatastores{1}.UnderlyingDatastores{1}.Files);
    
    cleanFeaturesPartition = cell(1,partitionSize);    
    reverbFeaturesPartition = cell(1,partitionSize);  
    reverbPhasePartition = cell(1,partitionSize); 
    cleanAudiosPartition = cell(1,partitionSize); 
    reverbAudiosPartition = cell(1,partitionSize);

    for idx = 1:partitionSize
        audios = read(combinedPartition);
        
        cleanAudio = single(audios(:,1));
        reverbAudio = single(audios(:,2));
        
        [a,b,c,d,e] = helperFeatureExtract(cleanAudio,reverbAudio,true,params);
        
        cleanFeaturesPartition{idx} = a;
        reverbFeaturesPartition{idx} = b;  
        reverbPhasePartition{idx} = c;
        cleanAudiosPartition{idx} = d;
        reverbAudiosPartition{idx} = e;
    end
    allValCleanFeatures{iPartition} = cat(2,cleanFeaturesPartition{:});
    allValReverbFeatures{iPartition} = cat(2,reverbFeaturesPartition{:});
    allValReverbPhase{iPartition} = cat(2,reverbPhasePartition{:});
    allValCleanAudios{iPartition} = cat(2,cleanAudiosPartition{:});
    allValReverbAudios{iPartition} = cat(2,reverbAudiosPartition{:});
end

allValCleanFeatures = cat(2,allValCleanFeatures{:});
allValReverbFeatures = cat(2,allValReverbFeatures{:});
allValReverbPhase = cat(2,allValReverbPhase{:});
allValCleanAudios = cat(2,allValCleanAudios{:});
allValReverbAudios = cat(2,allValReverbAudios{:});

valClean = featureNormalizeAndReshape(allValCleanFeatures);

Сохраните минимальное и максимальное значения каждого элемента набора проверки ревербератора. Эти значения будут использоваться в процессе реконструкции.

[valReverb,valMinMaxPairs] = featureNormalizeAndReshape(allValReverbFeatures);

Определение архитектуры нейронной сети

Полностью сверточная сетевая архитектура, названная U-Net, была адаптирована для этой задачи дееверберации речи, как предложено в [1]. «U-Net» - это сеть кодера-декодера с пропущенными соединениями. В модели U-Net каждый уровень понижает свой входной сигнал (шаг 2) до тех пор, пока не будет достигнут узкий уровень (путь кодирования). В последующих слоях входной сигнал усиливается каждым уровнем до тех пор, пока выходной сигнал не вернется в исходную форму (путь декодирования). Чтобы минимизировать потерю низкоуровневой информации во время процесса понижающей дискретизации, соединения между зеркально отраженными слоями создаются посредством прямого конкатенации выходов соответствующих уровней (соединения пропуска).

Определите архитектуру сети и верните график уровней с подключениями.

params.WindowdowLength = 512;
params.FFTLength = params.WindowdowLength;
params.NumFeatures = params.FFTLength/2;
params.NumSegments = 256;
    
filterH = 6;
filterW = 6;
numChannels = 1;
nFilters = [64,128,256,512,512,512,512,512];

inputLayer = imageInputLayer([params.NumFeatures,params.NumSegments,numChannels], ...
    'Normalization','none','Name','input');
layers = inputLayer;

% U-Net squeezing path
layers = [layers;
    convolution2dLayer([filterH,filterW],nFilters(1),'Stride',2,'Padding','same','Name',"conv"+string(1));
    leakyReluLayer(0.2,'Name',"leaky-relu"+string(1))];
        
for i = 2:8
    layers =  [layers;
        convolution2dLayer([filterH,filterW],nFilters(i),'Stride',2,'Padding','same','Name',"conv"+string(i));
        batchNormalizationLayer('Name',"batchnorm"+string(i))];%#ok
    if i ~= 8
        layers = [layers;leakyReluLayer(0.2,'Name',"leaky-relu"+string(i))];%#ok
    else
        layers = [layers;reluLayer('Name',"relu"+string(i))];%#ok
    end
end

% U-Net expanding path
for i = 7:-1:0
    nChannels = numChannels;
    if i > 0
        nChannels = nFilters(i);
    end
    layers = [layers;
        transposedConv2dLayer([filterH,filterW],nChannels,'Stride',2,'Cropping','same','Name',"deconv"+string(i))];%#ok
    if i > 0
        layers = [layers; batchNormalizationLayer('Name',"de-batchnorm" +string(i))];%#ok
    end
    if i > 4
        layers = [layers;dropoutLayer(0.5,'Name',"de-dropout"+string(i))];%#ok
    end
    if i > 0
        layers = [layers;
            reluLayer('Name',"de-relu"+string(i));
            concatenationLayer(3,2,'Name',"concat"+string(i))];%#ok
    else
        layers = [layers;tanhLayer('Name',"de-tanh"+string(i))];%#ok
    end
end

layers = [layers;regressionLayer('Name','output')];

unetLayerGraph = layerGraph(layers); 

% Define skip-connections
for i = 1:7
    unetLayerGraph = connectLayers(unetLayerGraph,'leaky-relu'+string(i),'concat'+string(i)+'/in2');
end

Использовать analyzeNetwork для просмотра архитектуры модели. Это хороший способ визуализации соединений между слоями.

analyzeNetwork(unetLayerGraph); 

Обучение сети

В качестве функции потерь будет использоваться среднеквадратичная ошибка (MSE) между спектрами логарифмических величин выборок речи (выходных данных модели) и соответствующей выборок чистой речи (цели). Используйте adam оптимизатор и размер мини-партии 128 для обучения. Разрешите модели тренироваться не более 50 эпох. Если потеря проверки не улучшается в течение 5 последовательных периодов, завершите процесс обучения. Снижайте уровень обучения в 10 раз каждые 15 эпох.

Определите следующие варианты обучения. Измените среду выполнения и укажите, следует ли выполнять фоновую диспетчеризацию в зависимости от доступности оборудования и наличия доступа к Parallel Computing Toolbox™.

initialLearnRate = 8e-4;
miniBatchSize = 64;

options = trainingOptions("adam", ...
        "MaxEpochs", 50, ...
        "InitialLearnRate",initialLearnRate, ...
        "MiniBatchSize",miniBatchSize, ...
        "Shuffle","every-epoch", ...
        "Plots","training-progress", ...
        "Verbose",false, ...
        "ValidationFrequency",max(1,floor(size(trainReverb,4)/miniBatchSize)), ...
        "ValidationPatience",5, ...
        "LearnRateSchedule","piecewise", ...
        "LearnRateDropFactor",0.1, ... 
        "LearnRateDropPeriod",15, ...
        "ExecutionEnvironment","gpu", ...
        "DispatchInBackground",true, ...
        "ValidationData",{valReverb,valClean});

Обучение сети.

dereverbNet = trainNetwork(trainReverb,trainClean,unetLayerGraph,options);

Оценка производительности сети

Прогнозирование и реконструкция

Спрогнозировать спектры логарифмической величины проверочного набора.

predictedSTFT4D = predict(dereverbNet,valReverb);

Используйте функцию helperReconstreadtAudios для восстановления предсказанной речи. Эта функция выполняет действия, описанные в первом разделе.

params.WindowdowLength = 512;
params.Window = hamming(params.WindowdowLength,"periodic");
params.OverlapLength = round(0.75*params.WindowdowLength);
params.FFTLength = params.WindowdowLength;
params.fs = 16000;

dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,valMinMaxPairs,allValReverbPhase,allValReverbAudios,params);

Визуализация логарифмических сигналов STFT чистого, реверберативного и соответствующих речевых сигналов с пониженным уровнем сигнала.

figure('Position',[100,100,1024,1200])

subplot(3,1,1)
imagesc(squeeze(allValCleanFeatures{1}))    
set(gca,'Ydir','normal')
subtitle('Clean')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,2)
imagesc(squeeze(allValReverbFeatures{1}))
set(gca,'Ydir','normal')
subtitle('Reverberated')
xlabel('Time')
ylabel('Frequency')
colorbar

subplot(3,1,3)
imagesc(squeeze(predictedSTFT4D(:,:,:,1)))
set(gca,'Ydir','normal')
subtitle('Predicted (Dereverberated)')
xlabel('Time')
ylabel('Frequency')
caxis([-1,1])
colorbar

Оценочные метрики

Для оценки производительности сети используется подмножество объективных показателей, используемых в [1]. Эти метрики вычисляются на сигналах временной области.

  • Cepstrum distance (CD) - предоставляет оценку логарифмического спектрального расстояния между двумя спектрами (прогнозируемым и чистым). Меньшие значения указывают на лучшее качество.

  • Логарифмическое отношение правдоподобия (LLR) - линейное прогнозирующее кодирование (LPC) на основе объективного измерения. Меньшие значения указывают на лучшее качество.

Вычисляют эти измерения для реверберативной речи и исказившихся речевых сигналов.

[summaryMeasuresReconstructed,allMeasuresReconstructed] = calculateObjectiveMeasures(dereverbedAudioAll,allValCleanAudios,params.fs);
[summaryMeasuresReverb,allMeasuresReverb] = calculateObjectiveMeasures(allValReverbAudios,allValCleanAudios,params.fs);
disp(summaryMeasuresReconstructed)
       avgCdMean: 3.8386
     avgCdMedian: 3.3671
      avgLlrMean: 0.9152
    avgLlrMedian: 0.8096
disp(summaryMeasuresReverb)
       avgCdMean: 4.2591
     avgCdMedian: 3.6336
      avgLlrMean: 0.9726
    avgLlrMedian: 0.8714

Гистограммы иллюстрируют распределение среднего CD, среднего SRMR и среднего LLR реверберативных и дееверберизованных данных.

figure('position',[50,50,1100,1300])

subplot(2,1,1)
histogram(allMeasuresReverb.cdMean,10)
hold on
histogram(allMeasuresReconstructed.cdMean, 10)
subtitle('Mean Cepstral Distance Distribution')
ylabel('count')
xlabel('mean CD')
legend('Reverberant (Original)','Dereverberated (Predicted)')

subplot(2,1,2)
histogram(allMeasuresReverb.llrMean,10)
hold on
histogram(allMeasuresReconstructed.llrMean,10)
subtitle('Mean Log Likelihood Ratio Distribution')
ylabel('Count')
xlabel('Mean LLR')
legend('Reverberant (Original)','Dereverberated (Predicted)')

Ссылки

[1] Эрнст, О., Шазан, С. Э., Ганнот, С., и Голдбергер, Дж. (2018). Дееверберация речи с использованием полностью сверточных сетей. 26-я Европейская конференция по обработке сигналов (EUSIPCO), 390-394.

[2] https://datashare.is.ed.ac.uk/handle/10283/2031

[3] https://datashare.is.ed.ac.uk/handle/10283/2791

[4] https://github.com/MuSAELab/SRMRToolbox

Вспомогательные функции

Применить реверберацию

function yOut = applyReverb(y,preDelay,decayFactor,wetDryMix,fs)
% This function generates reverberant speech data using the reverberator
% object
%
% inputs: 
% y                                - clean speech sample
% preDelay, decayFactor, wetDryMix - reverberation parameters
% fs                               - sampling rate of y
%
% outputs:
% yOut - corresponding reveberated speech sample

    revObj = reverberator('SampleRate',fs, ...
        'DecayFactor',decayFactor, ...
        'WetDryMix',wetDryMix, ...
        'PreDelay',preDelay);
    yOut = revObj(y);
    yOut = yOut(1:length(y),1);
end

Извлечение пакета функций

function [featuresClean,featuresReverb,phaseReverb,cleanAudios,reverbAudios] ...
    = helperFeatureExtract(cleanAudio,reverbAudio,isVal,params)
% This function performs the preprocessing and features extraction task on 
% the audio files used for dereverberation model training and testing.
%
% inputs:
% cleanAudio  - the clean audio file (reference)
% reverbAudio - corresponding reverberant speech file
% isVal       - Boolean flag indicating if it is the validation set
% params      - a structure containing feature extraction parameters
%
% outputs:
% featuresClean  - log-magnitude STFT features of clean audio
% featuresReverb - log-magnitude STFT features of reverberant audio
% phaseReverb    - phase of STFT of reverberant audio
% cleanAudios    - 2.072s-segments of clean audio file used for feature extraction
% reverbAudios   - 2.072s-segments of corresponding reverberant audio

    assert(length(cleanAudio) == length(reverbAudio));
    nSegments = floor((length(reverbAudio) - (params.samplesPerImage - params.shiftImage))/params.shiftImage);

    featuresClean = {};
    featuresReverb = {};
    phaseReverb = {};
    cleanAudios = {};
    reverbAudios = {};
    nGood = 0;
    nonSilentRegions = detectSpeech(reverbAudio, params.fs);
    nonSilentRegionIdx = 1;
    totalRegions = size(nonSilentRegions, 1);
    for cid = 1:nSegments
        start = (cid - 1)*params.shiftImage + 1;
        en = start + params.samplesPerImage - 1;

        nonSilentSamples = 0;
        while nonSilentRegionIdx < totalRegions && nonSilentRegions(nonSilentRegionIdx, 2) < start
            nonSilentRegionIdx = nonSilentRegionIdx + 1;
        end
        
        nonSilentStart = nonSilentRegionIdx;
        while nonSilentStart <= totalRegions && nonSilentRegions(nonSilentStart, 1) <= en
            nonSilentDuration = min(en, nonSilentRegions(nonSilentStart,2)) - max(start,nonSilentRegions(nonSilentStart,1)) + 1;
            nonSilentSamples = nonSilentSamples + nonSilentDuration; 
            nonSilentStart = nonSilentStart + 1;
        end
        
        nonSilentPerc = nonSilentSamples * 100 / (en - start + 1);
        silent = nonSilentPerc < 50;
        
        reverbAudioSegment = reverbAudio(start:en);
        if ~silent
            nGood = nGood + 1;
            cleanAudioSegment = cleanAudio(start:en);
            assert(length(cleanAudioSegment)==length(reverbAudioSegment), 'Lengths do not match after chunking')
            
            % Clean Audio
            [featsUnit, ~] = featureExtract(cleanAudioSegment, params);
            featuresClean{nGood} = featsUnit; %#ok

            % Reverb Audio
            [featsUnit, phaseUnit] = featureExtract(reverbAudioSegment, params);
            featuresReverb{nGood} = featsUnit; %#ok
            if isVal
                phaseReverb{nGood} = phaseUnit; %#ok
                reverbAudios{nGood} = reverbAudioSegment;%#ok
                cleanAudios{nGood} = cleanAudioSegment;%#ok
            end
        end
    end
end

Извлечь элементы

function [features, phase, lastFBin] = featureExtract(audio, params)
% Function to extract features for a speech file
    audio = single(audio);
    
    audioSTFT = stft(audio,'Window',params.Window,'OverlapLength',params.OverlapLength, ...
                    'FFTLength', params.FFTLength, 'FrequencyRange', 'onesided');
    
    phase = single(angle(audioSTFT(1:end-1,:)));     
    features = single(log(abs(audioSTFT(1:end-1,:)) + 10e-30)); 
    lastFBin = audioSTFT(end,:);

end

Нормализация и изменение формы элементов

function [featNorm,minMaxPairs] = featureNormalizeAndReshape(feats)
% function to normalize features - range [-1, 1] and reshape to 4
% dimensions
%
% inputs:
% feats - 3-dimensional array of extracted features
%
% outputs:
% featNorm - normalized and reshaped features
% minMaxPairs - array of original min and max pairs used for normalization

    nSamples = length(feats);
    minMaxPairs = zeros(nSamples,2,'single');
    featNorm = zeros([size(feats{1}),nSamples],'single');
    parfor i = 1:nSamples
        feat = feats{i};
        maxFeat = max(feat,[],'all');
        minFeat = min(feat,[],'all');
        featNorm(:,:,i) = 2.*(feat - minFeat)./(maxFeat - minFeat) - 1;
        minMaxPairs(i,:) = [minFeat,maxFeat];
    end
    featNorm = reshape(featNorm,size(featNorm,1),size(featNorm,2),1,size(featNorm,3));
end

Восстановить прогнозируемый звук

function dereverbedAudioAll = helperReconstructPredictedAudios(predictedSTFT4D,minMaxPairs,reverbPhase,reverbAudios,params)
% This function will reconstruct the 2.072s long audios predicted by the 
% model using the predicted log-magnitude spectrogram and the phase of the 
% reverberant audio file
%
% inputs:
% predictedSTFT4D - Predicted 4-dimensional STFT log-magnitude features
% minMaxPairs     - Original minimum/maximum value pairs used in normalization
% reverbPhase     - Array of phases of STFT of reverberant audio files
% reverbAudios    - 2.072s-segments of corresponding reverberant audios
% params          - Structure containing feature extraction parameters

    predictedSTFT = squeeze(predictedSTFT4D);
    denormalizedFeatures = zeros(size(predictedSTFT),'single');
    for i = 1:size(predictedSTFT,3)
        feat = predictedSTFT(:,:,i);
        maxFeat = minMaxPairs(i,2);
        minFeat = minMaxPairs(i,1);
        denormalizedFeatures(:,:,i) = (feat + 1).*(maxFeat-minFeat)./2 + minFeat;        
    end
    
    predictedSTFT = exp(denormalizedFeatures);
    
    nCount = size(predictedSTFT,3);
    dereverbedAudioAll = cell(1,nCount);

    nSeg = params.NumSegments;
    win = params.Window;
    ovrlp = params.OverlapLength;
    FFTLength = params.FFTLength;
    parfor ii = 1:nCount
        % Append zeros to the highest frequency bin
        stftUnit = predictedSTFT(:,:,ii);
        stftUnit = cat(1,stftUnit, zeros(1,nSeg)); 
        phase = reverbPhase{ii};
        phase = cat(1,phase,zeros(1,nSeg));

        oneSidedSTFT = stftUnit.*exp(1j*phase);
        dereverbedAudio= istft(oneSidedSTFT, ...
            'Window', win,'OverlapLength', ovrlp, ...
                                    'FFTLength',FFTLength,'ConjugateSymmetric',true,...
                                    'FrequencyRange','onesided');

        dereverbedAudioAll{ii} = dereverbedAudio./max(max(abs(dereverbedAudio)), max(abs(reverbAudios{ii})));
    end
end

Расчет целевых показателей

function [summaryMeasures,allMeasures] = calculateObjectiveMeasures(reconstructedAudios,cleanAudios,fs)
% This function computes the objective measures on time-domain signals.
%
% inputs:
% reconstructedAudios - An array of audio files to evaluate.
% cleanAudios - An array of reference audio files
% fs - Sampling rate of audio files
%
% outputs:
% summaryMeasures - Global means of CD, LLR individual mean and median values
% allMeasures - Individual mean and median values

    nAudios = length(reconstructedAudios);
    cdMean = zeros(nAudios,1);
    cdMedian = zeros(nAudios,1);
    llrMean = zeros(nAudios,1);
    llrMedian = zeros(nAudios,1);

    parfor k = 1 : nAudios
      y = reconstructedAudios{k};
      x = cleanAudios{k};

      y = y./max(abs(y));
      x = x./max(abs(x));

      [cdMean(k),cdMedian(k)] = cepstralDistance(x,y,fs);
      [llrMean(k),llrMedian(k)] = lpcLogLikelihoodRatio(y,x,fs);
    end
    
    summaryMeasures.avgCdMean = mean(cdMean);
    summaryMeasures.avgCdMedian = mean(cdMedian);
    summaryMeasures.avgLlrMean = mean(llrMean);
    summaryMeasures.avgLlrMedian = mean(llrMedian);   
    
    allMeasures.cdMean = cdMean;
    allMeasures.llrMean = llrMean;
end

Цепстральное расстояние

function [meanVal, medianVal] = cepstralDistance(x,y,fs)
    x = x / sqrt(sum(x.^2));
    y = y / sqrt(sum(y.^2));

    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex).*win;
    yFrames = y(winIndex).*win;

    xCeps = cepstralReal(xFrames,width);
    yCeps = cepstralReal(yFrames,width);

    dist = (xCeps - yCeps).^2;
    cepsD = 10 / log(10)*sqrt(2*sum(dist(2:end,:),1) + dist(1,:));
    cepsD = max(min(cepsD,10),0);

    meanVal = mean(cepsD);
    medianVal = median(cepsD);
end

Настоящий Цепструм

function realC = cepstralReal(x, width)
    width2p = 2 ^ nextpow2(width);
    powX = abs(fft(x, width2p));

    lowCutoff = max(powX(:)) * 10^-5;
    powX  = max(powX, lowCutoff);

    realC = real(ifft(log(powX)));
    order = 24;
    realC = realC(1 : order + 1, :);
    realC = realC - mean(realC, 2);
end

Логарифмическое отношение правдоподобия LPC

function [meanLlr,medianLlr] = lpcLogLikelihoodRatio(x,y,fs)
    order = 12;
    width = round(0.025*fs);
    shift = round(0.01*fs);

    nSamples = length(x);
    nFrames  = floor((nSamples - width + shift)/shift);
    win = window(@hanning,width);

    winIndex = repmat((1:width)',1,nFrames) + repmat((0:nFrames - 1)*shift,width,1);

    xFrames = x(winIndex) .* win;
    yFrames = y(winIndex) .* win;

    lpcX = realLpc(xFrames, width, order);
    [lpcY,realY] = realLpc(yFrames, width, order);

    llr = zeros(nFrames, 1);
    for n = 1 : nFrames
      R = toeplitz(realY(1:order+1,n));
      num = lpcX(:,n)'*R*lpcX(:,n);
      den = lpcY(:,n)'*R*lpcY(:,n);  
      llr(n) = log(num/den);
    end

    llr = sort(llr);
    llr = llr(1:ceil(nFrames*0.95));
    llr = max(min(llr,2),0);

    meanLlr = mean(llr);
    medianLlr = median(llr);
end

Действительные линейные прецизионные коэффициенты

function [lpcCoeffs, realX] = realLpc(xFrames, width, order)
    width2p = 2 ^ nextpow2(width);
    X = fft(xFrames, width2p);

    Rx = ifft(abs(X).^2);
    Rx = Rx./width; 
    realX = real(Rx);

    lpcX = levinson(realX, order);
    lpcCoeffs = real(lpcX');
end