Определение ключевого слова в шуме Используя MFCC и сети LSTM

В этом примере показано, как идентифицировать ключевое слово в шумной речи с помощью нейронной сети для глубокого обучения. В частности, пример использует сеть Bidirectional Long Short-Term Memory (BiLSTM) и частоту mel cepstral коэффициенты (MFCC).

Введение

Ключевое слово, определяющее (KWS), является важной составляющей речи - помогают технологиям, где пользователь говорит предопределенное ключевое слово с пробуждением система прежде, чем говорить полную команду или запрос к устройству.

Этот пример обучает KWS глубокая сеть с последовательностями функции mel-частоты cepstral коэффициентов (MFCC). Пример также демонстрирует, как сетевая точность в шумной среде может быть улучшена с помощью увеличения данных.

Этот пример использует сети долгой краткосрочной памяти (LSTM), которые являются типом рекуррентной нейронной сети (RNN), подходящей, чтобы изучить данные timeseries и последовательность. Сеть LSTM может изучить долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer (Deep Learning Toolbox)), может посмотреть в то время последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer (Deep Learning Toolbox)), может посмотреть в то время последовательность и во вперед и в обратные направления. Этот пример использует двунаправленный слой LSTM.

Пример использует google Speech Commands Dataset, чтобы обучить модель глубокого обучения. Чтобы запустить пример, необходимо сначала загрузить набор данных. Если вы не хотите загружать набор данных или обучать сеть, то можно загрузить и использовать предварительно обученную сеть путем открытия этого примера в MATLAB® и рабочих линиях 3-10 из примера.

Точечное ключевое слово с предварительно обученной сетью

Перед входом в учебный процесс подробно, вы будете загружать и использовать предварительно обученную сеть определения ключевого слова, чтобы идентифицировать ключевое слово.

В этом примере ключевым словом, чтобы определить является YES.

Считайте тестовый сигнал, где ключевое слово произнесено.

[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)

Загрузите и загрузите предварительно обученную сеть, среднее значение (M) и стандартное отклонение (S) векторы, используемые для нормализации функции, а также 2 звуковых файлов, используемых для того, чтобы проверить сеть позже в примере.

url = 'http://ssd.mathworks.com/supportfiles/audio/KeywordSpotting.zip';
downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'KeywordSpotting');
if ~exist(netFolder,'dir')
    disp('Downloading pretrained network and audio files (4 files - 7 MB) ...')
    unzip(url,downloadNetFolder)
end
load(fullfile(netFolder,'KWSNet.mat'));

Создайте audioFeatureExtractor объект выполнить извлечение признаков.

WindowLength = 512;
OverlapLength = 384;
afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta', true, ...
                            'mfccDeltaDelta',true);                   

Извлеките функции из тестового сигнала и нормируйте их.

features = extract(afe,audioIn);   

features = (features - M)./S;

Вычислите ключевое слово, определяющее бинарную маску. Значение маски каждый соответствует сегменту, где ключевое слово было определено.

mask = classify(KWSNet,features.');

Каждая выборка в маске соответствует 128 выборкам от речевого сигнала (WindowLength - OverlapLength).

Расширьте маску до длины сигнала.

mask = repmat(mask, WindowLength-OverlapLength, 1);
mask = double(mask) - 1;
mask = mask(:);

Постройте тестовый сигнал и маску.

figure
audioIn = audioIn(1:length(mask));
t = (0:length(audioIn)-1)/fs;
plot(t, audioIn)
grid on
hold on
plot(t, mask)
legend('Speech', 'YES')

Слушайте пятнистое ключевое слово.

sound(audioIn(mask==1),fs)

Обнаружьте команды Используя передачу потокового аудио от микрофона

Протестируйте свою предварительно обученную сеть обнаружения команды на передаче потокового аудио от вашего микрофона. Попытайтесь говорить случайные слова, включая ключевое слово (YES).

Вызовите generateMATLABFunction на audioFeatureExtractor объект создать функцию извлечения признаков. Вы будете использовать эту функцию в цикле обработки.

generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);

Задайте читателя аудио устройства, который может считать аудио из вашего микрофона. Установите длину системы координат на длину транзитного участка. Это позволяет вам вычислить новый набор функций каждой новой аудио системы координат от микрофона.

HopLength = WindowLength - OverlapLength;
FrameLength = HopLength;
adr = audioDeviceReader('SampleRate',fs, ...
                        'SamplesPerFrame',FrameLength);

Создайте осциллограф для визуализации речевого сигнала и предполагаемой маски.

scope = timescope('SampleRate',fs, ...
                  'TimeSpanSource','property', ...
                  'TimeSpan',5, ...
                  'TimeSpanOverrunAction','Scroll', ...
                  'BufferLength',fs*5*2, ...
                  'ShowLegend',true, ...
                  'ChannelNames',{'Speech','Keyword Mask'}, ...
                  'YLimits',[-1.2 1.2], ...
                  'Title','Keyword Spotting');

Задайте уровень, на котором вы оцениваете маску. Вы сгенерируете маску один раз в NumHopsPerUpdate аудио системы координат.

NumHopsPerUpdate = 16;

Инициализируйте буфер для аудио.

dataBuff = dsp.AsyncBuffer(WindowLength);

Инициализируйте буфер для вычисленных функций.

featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);

Инициализируйте буфер, чтобы справиться с графическим выводом аудио и маски.

plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);

Чтобы запустить цикл неопределенно, установите timeLimit на Inf. Чтобы остановить симуляцию, закройте осциллограф.

timeLimit = 20;

tic
while toc < timeLimit
    
    data = adr();
    write(dataBuff, data);
    write(plotBuff, data);

    frame = read(dataBuff,WindowLength,OverlapLength);
    features = generateKeywordFeatures(frame,fs);
    write(featureBuff,features.');

    if featureBuff.NumUnreadSamples == NumHopsPerUpdate
        featureMatrix = read(featureBuff);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        featureMatrix = (featureMatrix - M)./S;
        
        [keywordNet, v] = classifyAndUpdateState(KWSNet,featureMatrix.');
        v = double(v) - 1;
        v = repmat(v, HopLength, 1);
        v = v(:);
        v = mode(v);
        v = repmat(v, NumHopsPerUpdate * HopLength,1);
        
        data = read(plotBuff);
        scope([data, v]);
        
        if ~isVisible(scope)
            break;
        end
    end
end
hide(scope)

В остальной части примера вы изучите, как обучить сеть определения ключевого слова.

Учебные сводные данные процесса

Учебный процесс проходит следующие шаги:

  1. Смотрите базовую линию определения ключевого слова "золотого стандарта" на сигнале валидации.

  2. Создайте учебное произнесение из бесшумного набора данных.

  3. Обучите ключевое слово, определяющее сеть LSTM с помощью последовательностей MFCC, извлеченных из того произнесения.

  4. Проверяйте сетевую точность путем сравнения базовой линии валидации с выходом сети, когда применено сигнал валидации.

  5. Проверяйте сетевую точность на сигнал валидации, поврежденный шумом.

  6. Увеличьте обучающий набор данных путем введения шума к речевым данным с помощью audioDataAugmenter.

  7. Переобучите сеть с увеличенным набором данных.

  8. Проверьте, что переобученная сеть теперь дает к более высокой точности, когда применено шумный сигнал валидации.

Смотрите сигнал валидации

Вы используете демонстрационный речевой сигнал проверить сеть KWS. Сигнал валидации состоит 34 секунды речи с ключевым словом YES, появляющийся периодически.

Загрузите сигнал валидации.

[audioIn,fs] = audioread(fullfile(netFolder,'KeywordSpeech-16-16-mono-34secs.flac'));

Слушайте сигнал.

sound(audioIn,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioIn)-1);
plot(t,audioIn);
grid on
xlabel('Time (s)')
title('Validation Speech Signal')

Смотрите базовую линию KWS

Загрузите базовую линию KWS. Эта базовая линия была получена с помощью speech2text: Создайте ключевое слово, определяющее маску Используя Audio Labeler.

load('KWSBaseline.mat','KWSBaseline')

Базовая линия является логическим вектором из той же длины как звуковой сигнал валидации. Сегменты в audioIn где ключевое слово произнесено, установлены в одного в KWSBaseline.

Визуализируйте речевой сигнал наряду с базовой линией KWS.

fig = figure;
plot(t,[audioIn,KWSBaseline'])
grid on
xlabel('Time (s)')
legend('Speech','KWS Baseline','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Validation Signal")

Слушайте речевые сегменты, идентифицированные как ключевые слова.

sound(audioIn(KWSBaseline),fs)

Цель сети, которую вы обучаете, состоит в том, чтобы вывести маску KWS нулей и единиц как эта базовая линия.

Загрузите речевой набор данных команд

Загрузите и извлеките Google Speech Commands Dataset [1].

url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'google_speech');

if ~exist(datasetFolder,'dir')
    disp('Downloading Google speech commands data set (1.5 GB)...')
    unzip(url,datasetFolder)
end

Создайте audioDatastore это указывает на набор данных.

ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true);
ads = shuffle(ads);

Набор данных содержит файлы фонового шума, которые не используются в этом примере. Используйте subset создать новый datastore, который не имеет файлов фонового шума.

isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);

Набор данных имеет приблизительно 65 000 одного второго длинного произнесения 30 коротких слов (включая ключевое слово YES). Получите отказ распределения слова в datastore.

countEachLabel(ads)
ans=30×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

Разделите ads в два хранилища данных: первый datastore содержит файлы, соответствующие ключевому слову. Второй datastore содержит все другие слова.

keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);

Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset к false. Чтобы запустить этот пример быстро, установите reduceDataset к true.

reduceDataset = false; 
if reduceDataset 
    % Reduce the dataset by a factor of 20 
    ads_keyword = splitEachLabel (ads_keyword, вокруг (numel (ads_keyword.Files) / 20)); 
    numUniqueLabels = numel (уникальный (ads_other.Labels)); 
    ads_other = splitEachLabel (ads_other, вокруг (numel (ads_other.Files) / numUniqueLabels / 20)); 
end

Получите отказ распределения слова в каждом datastore. Переставьте ads_other datastore так, чтобы последовательные чтения возвратили различные слова.

countEachLabel(ads_keyword)
ans=1×2 table
    Label    Count
    _____    _____

     yes     2377 

countEachLabel(ads_other)
ans=29×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

ads_other = shuffle(ads_other);

Создайте учебные предложения и метки

Учебные хранилища данных содержат вторые речевые сигналы, где одно слово произнесено. Вы создадите более комплексное учебное речевое произнесение, которое содержит смесь ключевого слова наряду с другими словами.

Вот пример созданного произнесения. Считайте одно ключевое слово из datastore ключевого слова и нормируйте его, чтобы иметь максимальное значение одного.

yes = read(ads_keyword);
yes = yes / max(abs(yes));

Сигнал имеет неречевые фрагменты (тишина, фоновый шум, и т.д.), которые не содержат полезную информацию о речи. Этот пример удаляет тишину с помощью detectSpeech.

Получите индексы начала и конца полезного фрагмента сигнала.

speechIndices = detectSpeech(yes,fs);

Случайным образом выберите количество слов, чтобы использовать в синтезируемом учебном предложении. Используйте максимум 10 слов.

numWords = randi([0 10]);

Случайным образом выберите местоположение, в котором происходит ключевое слово.

keywordLocation = randi([1 numWords+1]);

Считайте желаемое количество произнесения неключевого слова и создайте учебное предложение и маску.

sentence = [];
mask = [];
for index = 1:numWords+1
    if index == keywordLocation
        sentence = [sentence;yes]; %#ok
        newMask = zeros(size(yes));
        newMask(speechIndices(1,1):speechIndices(1,2)) = 1;
        mask = [mask;newMask]; %#ok
    else
        other = read(ads_other);
        other = other ./ max(abs(other));
        sentence = [sentence;other]; %#ok
        mask = [mask;zeros(size(other))]; %#ok
    end
end

Постройте учебное предложение наряду с маской.

figure
t = (1/fs) * (0:length(sentence)-1);
fig = figure;
plot(t,[sentence,mask])
grid on
xlabel('Time (s)')
legend('Training Signal','Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Example Utterance")

Слушайте учебное предложение.

sound(sentence,fs)

Выделить признаки

Этот пример обучает нейронную сеть для глубокого обучения с помощью 39 коэффициентов MFCC (13 MFCC, 13 дельт и 13 коэффициентов дельты дельты).

Задайте параметры, требуемые для экстракции MFCC.

WindowLength = 512;
OverlapLength = 384;

Создайте объект audioFeatureExtractor выполнить извлечение признаков.

afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta',true, ...
                            'mfccDeltaDelta',true);                      

Извлеките функции.

featureMatrix = extract(afe,sentence);
size(featureMatrix)
ans = 1×2

   478    39

Обратите внимание на то, что вы вычисляете MFCC путем скольжения окна через вход, таким образом, матрица функции короче, чем входной речевой сигнал. Каждая строка в featureMatrix соответствует 128 выборкам от речевого сигнала (WindowLength - OverlapLength).

Вычислите маску той же длины как featureMatrix.

HopLength = WindowLength - OverlapLength;
range = HopLength * (1:size(featureMatrix,1)) + HopLength;
featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end

Извлеките функции из обучающего набора данных

Синтез предложения и извлечение признаков для целого обучающего набора данных могут быть довольно длительными. Чтобы ускорить обработку, если у вас есть Parallel Computing Toolbox™, делят учебный datastore и процесс каждый раздел на отдельном рабочем.

Выберите много разделов datastore.

numPartitions = 6;

Инициализируйте массивы ячеек для матриц функции и масок.

TrainingFeatures = {};
TrainingMasks= {};

Выполните синтез предложения, извлечение признаков и создание маски с помощью parfor.

emptyCategories = categorical([1 0]);
emptyCategories(:) = [];

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        % Create a training sentence
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        % Create mask
        hopLength = WindowLength - OverlapLength;
        range = (hopLength) * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end

        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 33.656404 seconds.

Это - хорошая практика, чтобы нормировать все функции, чтобы иметь нулевое среднее значение и стандартное отклонение единицы. Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормировать данные.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Извлеките функции валидации

Извлеките функции MFCC из сигнала валидации.

featureMatrix = extract(afe, audioIn);
featureMatrix(~isfinite(featureMatrix)) = 0;

Нормируйте функции валидации.

FeaturesValidationClean = (featureMatrix - M)./S;
range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;

Создайте валидацию маска KWS.

featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end
BaselineV = categorical(featureMask);

Задайте сетевую архитектуру LSTM

Сети LSTM могут изучить долгосрочные зависимости между временными шагами данных о последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer (Deep Learning Toolbox), чтобы посмотреть на последовательность и во вперед и в обратные направления.

Задайте входной размер, чтобы быть последовательностями размера numFeatures. Задайте два скрытых двунаправленных слоя LSTM с выходным размером 150 и выведите последовательность. Эта команда дает двунаправленному слою LSTM команду сопоставлять входные временные ряды в 150 функций, которые передаются следующему слою. Задайте два класса включением полносвязного слоя размера 2, сопровождаемый softmax слоем и слоем классификации.

layers = [ ...
    sequenceInputLayer(numFeatures)
    bilstmLayer(150,"OutputMode","sequence")
    bilstmLayer(150,"OutputMode","sequence")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];

Задайте опции обучения

Задайте опции обучения для классификатора. Установите MaxEpochs к 10 так, чтобы сеть сделала 10, проходит через обучающие данные. Установите MiniBatchSize к 64 так, чтобы сеть посмотрела на 64 учебных сигнала за один раз. Установите Plots к "training-progress" сгенерировать графики, которые показывают процесс обучения количеством увеличений итераций. Установите Verbose к false отключить печать таблицы выход, который соответствует данным, показанным в графике. Установите Shuffle к "every-epoch" переставить обучающую последовательность в начале каждой эпохи. Установите LearnRateSchedule к "piecewise" чтобы уменьшить скорость обучения заданным фактором (0.1) каждый раз, определенное число эпох (5) передало. Установите ValidationData к предикторам валидации и целям.

Этот пример использует адаптивную оценку момента (ADAM) решатель. ADAM выполняет лучше с рекуррентными нейронными сетями (RNNs) как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решатель.

maxEpochs = 10;
miniBatchSize = 64;
options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationClean.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучите сеть LSTM

Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork (Deep Learning Toolbox). Поскольку набор обучающих данных является большим, учебный процесс может занять несколько минут.

[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
end

Проверяйте сетевую точность на бесшумный сигнал валидации

Оцените маску KWS для сигнала валидации использование обучившего сеть.

v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');

Вычислите и постройте матрицу беспорядка валидации от векторов из фактических и предполагаемых меток.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразуйте сетевой выход от категориального, чтобы удвоиться.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);

Слушайте области ключевого слова, идентифицированные сетью.

sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и ожидаемые маски KWS.

baseline = double(BaselineV) - 1;
baseline = repmat(baseline,HopLength,1);
baseline = baseline(:);

t = (1/fs) * (0:length(v)-1);
fig = figure;
plot(t,[audioIn(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noise-Free Speech')

Проверяйте сетевую точность на шумный сигнал валидации

Вы будете теперь проверять сетевую точность на шумный речевой сигнал. Сигнал с шумом был получен путем повреждения чистого сигнала валидации аддитивным белым Гауссовым шумом.

Загрузите сигнал с шумом.

[audioInNoisy,fs] = audioread(fullfile(netFolder,'NoisyKeywordSpeech-16-16-mono-34secs.flac'));
sound(audioInNoisy,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioInNoisy)-1);
plot(t,audioInNoisy)
grid on
xlabel('Time (s)')
title('Noisy Validation Speech Signal')

Извлеките матрицу функции из сигнала с шумом.

featureMatrixV = extract(afe, audioInNoisy);
featureMatrixV(~isfinite(featureMatrixV)) = 0;
FeaturesValidationNoisy = (featureMatrixV - M)./S;

Передайте матрицу функции сети.

v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');

Сравните сетевой выход с базовой линией. Обратите внимание на то, что точность ниже, чем та, которую вы получили для чистого сигнала.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразуйте сетевой выход от категориального, чтобы удвоиться.

 v = double(v) - 1;
 v = repmat(v,HopLength,1);
 v = v(:);

Слушайте области ключевого слова, идентифицированные сетью.

sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и базовые маски.

t = (1/fs)*(0:length(v)-1);
fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - No Data Augmentation')

Выполните увеличение данных

Обучивший сеть не выполнял хорошо на сигнале с шумом, потому что обученный набор данных содержал только бесшумные предложения. Вы исправите это путем увеличения набора данных, чтобы включать шумные предложения.

Используйте audioDataAugmenter увеличивать ваш набор данных.

ada = audioDataAugmenter('TimeStretchProbability',0, ...
                         'PitchShiftProbability',0, ...
                         'VolumeControlProbability',0, ...
                         'TimeShiftProbability',0, ...
                         'SNRRange',[-1, 1], ...
                         'AddNoiseProbability',0.85);

С этими настройками, audioDataAugmenter возразите повреждает входной звуковой сигнал с белым Гауссовым шумом с вероятностью 85%. ОСШ случайным образом выбран из области значений [-1 1] (в дБ). Существует 15%-я вероятность, что увеличение не изменяет ваш входной сигнал.

Как пример, передайте звуковой сигнал увеличению.

reset(ads_keyword)
x = read(ads_keyword);
data = augment(ada,x,fs)
data=1×2 table
         Audio          AugmentationInfo
    ________________    ________________

    {16000×1 double}      [1×1 struct]  

Смотрите AugmentationInfo переменная в data проверять, как сигнал был изменен.

data.AugmentationInfo
ans = struct with fields:
    SNR: 0.3410

Сбросьте хранилища данных.

reset(ads_keyword)
reset(ads_other)

Инициализируйте ячейки маски и функция.

TrainingFeatures = {};
TrainingMasks = {};

Выполните извлечение признаков снова. Каждый сигнал повреждается шумом с вероятностью 85%, таким образом, ваш увеличенный набор данных имеет приблизительно 85%-е зашумленные данные и 15%-е бесшумные данные.

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Corrupt with noise
        augmentedData = augment(ada,sentence,fs);
        sentence = augmentedData.Audio{1};
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        hopLength = WindowLength - OverlapLength;
        range = hopLength * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end
    
        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.090923 seconds.

Вычислите среднее и стандартное отклонение для каждого коэффициента; используйте их, чтобы нормировать данные.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'),'KWSNet','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Нормируйте функции валидации с новыми значениями среднего и стандартного отклонения.

FeaturesValidationNoisy = (featureMatrixV - M)./S;

Переобучите сеть с увеличенным набором данных

Воссоздайте опции обучения. Используйте шумные базовые функции и маску для валидации.

options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучите сеть.

[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'));
end

Проверьте сетевую точность на сигнале валидации.

v = classify(KWSNet,FeaturesValidationNoisy.');

Сравните предполагаемые и ожидаемые маски KWS.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Слушайте идентифицированные области ключевого слова.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);
 
sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и ожидаемые маски.

fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - With Data Augmentation')

Ссылки

[1] Начальник П. "Речевые Команды: общедоступный набор данных для распознавания речи однословного", 2017. Доступный от https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Речевой Набор данных Команд лицензируется при Приписывании Creative Commons 4,0 лицензии.

Приложение - функции помощника

function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength)

% Read one keyword
keyword = read(ads_keyword);
keyword = keyword ./ max(abs(keyword));

% Identify region of interest
speechIndices = detectSpeech(keyword,fs);
if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength
    speechIndices = [1,length(keyword)];
end
keyword = keyword(speechIndices(1,1):speechIndices(1,2));

% Pick a random number of other words (between 0 and 10)
numWords = randi([0 10]);
% Pick where to insert keyword
loc = randi([1 numWords+1]);
sentence = [];
mask = [];
for index = 1:numWords+1
    if index==loc
        sentence = [sentence;keyword];
        newMask = ones(size(keyword));
        mask = [mask ;newMask];
    else
        
        other = read(ads_other);      
        other = other ./ max(abs(other));
        sentence = [sentence;other];
        mask = [mask;zeros(size(other))];
    end
end
end