Определение ключевого слова в шуме с использованием сетей MFCC и LSTM

Этот пример показывает, как идентифицировать ключевое слово в шумной речи с помощью нейронной сети для глубокого обучения. В частности, в примере используются двунаправленная сеть долгой краткосрочной памяти (BiLSTM) и мел-частотные кепстральные коэффициенты (MFCC).

Введение

Spotting (KWS) является важным компонентом технологий голосовой помощи, где пользователь говорит предопределенное ключевое слово, чтобы разбудить систему, прежде чем говорить полную команду или запрос к устройству.

Этот пример обучает глубокую сеть KWS с характерными последовательностями мел-частотных кепстральных коэффициентов (MFCC). Пример также демонстрирует, как точность сети в зашумленном окружении может быть улучшена с помощью увеличения данных.

Этот пример использует сети долгой краткосрочной памяти (LSTM), которые являются типом рекуррентной нейронной сети (RNN), хорошо подходящим для изучения данных последовательности и timeseries. Сеть LSTM может изучать долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer (Deep Learning Toolbox)) может смотреть на временную последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer (Deep Learning Toolbox)) может посмотреть на временную последовательность как в прямом, так и в обратном направлениях. Этот пример использует двунаправленный слой LSTM.

Пример использует Google Speech Commands Dataset, чтобы обучить модель глубокого обучения. Чтобы запустить пример, необходимо сначала загрузить набор данных. Если вы не хотите загружать набор данных или обучать сеть, то можно скачать и использовать предварительно обученную сеть, открывая этот пример в MATLAB ® и выполняя линии 3-10 примера.

Ключевое слово Spot с предварительно обученной сетью

Прежде чем подробно войти в процесс обучения, вы будете загружать и использовать предварительно обученную сеть споттинга ключевых слов для идентификации ключевого слова.

В этом примере ключевым словом для определения является YES.

Считайте тестовый сигнал, где произнесено ключевое слово.

[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)

Загрузите и загрузите предварительно обученную сеть, средние векторы (M) и стандартное отклонение (S), используемые для нормализации функции, а также 2 аудио файлов, используемые для валидации сети позже в примере.

url = 'http://ssd.mathworks.com/supportfiles/audio/KeywordSpotting.zip';
downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'KeywordSpotting');
if ~exist(netFolder,'dir')
    disp('Downloading pretrained network and audio files (4 files - 7 MB) ...')
    unzip(url,downloadNetFolder)
end
load(fullfile(netFolder,'KWSNet.mat'));

Создайте audioFeatureExtractor объект для выполнения редукции данных.

WindowLength = 512;
OverlapLength = 384;
afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta', true, ...
                            'mfccDeltaDelta',true);                   

Извлеките функции из тестового сигнала и нормализуйте их.

features = extract(afe,audioIn);   

features = (features - M)./S;

Вычислите ключевое слово spotting двоичную маску. Значение маски единицы соответствует сегменту, в котором было замечено ключевое слово.

mask = classify(KWSNet,features.');

Каждая выборка в маске соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).

Разверните маску до длины сигнала.

mask = repmat(mask, WindowLength-OverlapLength, 1);
mask = double(mask) - 1;
mask = mask(:);

Постройте график тестового сигнала и маски.

figure
audioIn = audioIn(1:length(mask));
t = (0:length(audioIn)-1)/fs;
plot(t, audioIn)
grid on
hold on
plot(t, mask)
legend('Speech', 'YES')

Послушайте заметное ключевое слово.

sound(audioIn(mask==1),fs)

Обнаружение команд, использующих передачу потокового аудио с микрофона

Протестируйте предварительно обученную сеть обнаружения команд на передаче потокового аудио с микрофона. Попробуйте произнести случайные слова, включая ключевое слово (YES).

Функции generateMATLABFunction на audioFeatureExtractor объект, чтобы создать функцию редукции данных. Вы будете использовать эту функцию в цикле обработки.

generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);

Задайте устройство чтения аудио устройства, которое может считывать аудио с микрофона. Установите длину системы координат равную длине скачка. Это позволяет вам вычислять новый набор функций для каждого нового аудио системы координат из микрофона.

HopLength = WindowLength - OverlapLength;
FrameLength = HopLength;
adr = audioDeviceReader('SampleRate',fs, ...
                        'SamplesPerFrame',FrameLength);

Создайте возможности для визуализации речевого сигнала и предполагаемой маски.

scope = timescope('SampleRate',fs, ...
                  'TimeSpanSource','property', ...
                  'TimeSpan',5, ...
                  'TimeSpanOverrunAction','Scroll', ...
                  'BufferLength',fs*5*2, ...
                  'ShowLegend',true, ...
                  'ChannelNames',{'Speech','Keyword Mask'}, ...
                  'YLimits',[-1.2 1.2], ...
                  'Title','Keyword Spotting');

Определите скорость, с которой вы оцениваете маску. Вы будете генерировать маску один раз в NumHopsPerUpdate аудио систем координат.

NumHopsPerUpdate = 16;

Инициализируйте буфер для аудио.

dataBuff = dsp.AsyncBuffer(WindowLength);

Инициализируйте буфер для вычисляемых функций.

featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);

Инициализируйте буфер, чтобы управлять графическим изображением аудио и маски.

plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);

Чтобы запустить цикл бессрочно, установите timeLimit на Inf. Чтобы остановить симуляцию, закройте возможности.

timeLimit = 20;

tic
while toc < timeLimit
    
    data = adr();
    write(dataBuff, data);
    write(plotBuff, data);

    frame = read(dataBuff,WindowLength,OverlapLength);
    features = generateKeywordFeatures(frame,fs);
    write(featureBuff,features.');

    if featureBuff.NumUnreadSamples == NumHopsPerUpdate
        featureMatrix = read(featureBuff);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        featureMatrix = (featureMatrix - M)./S;
        
        [keywordNet, v] = classifyAndUpdateState(KWSNet,featureMatrix.');
        v = double(v) - 1;
        v = repmat(v, HopLength, 1);
        v = v(:);
        v = mode(v);
        v = repmat(v, NumHopsPerUpdate * HopLength,1);
        
        data = read(plotBuff);
        scope([data, v]);
        
        if ~isVisible(scope)
            break;
        end
    end
end
hide(scope)

В остальном примере вы научитесь обучать сеть споттинга по ключевым словам.

Сводные данные процесса обучения

Процесс обучения проходит следующие шаги:

  1. Смотрите ключевое слово «золотой стандарт», определяющее базовый уровень по сигналу валидации.

  2. Создайте обучающие высказывания из набора данных без шума.

  3. Обучите ключевое слово spotting LSTM network с помощью последовательностей MFCC, извлеченных из этих высказываний.

  4. Проверьте точность сети путем сравнения базового уровня валидации с выходом сети при применении к сигналу валидации.

  5. Проверьте точность сети на наличие сигнала валидации, поврежденного шумом.

  6. Увеличение обучающего набора данных путем введения шума в речевые данные с помощью audioDataAugmenter.

  7. Переобучите сеть с помощью дополненного набора данных.

  8. Проверьте, что переобученная сеть теперь приводит к более высокой точности при применении к шумному сигналу валидации.

Смотрите сигнал валидации

Вы используете пример речевого сигнала, чтобы подтвердить сеть KWS. Сигнал валидации состоит из 34 секунд речи с ключевым словом YES, появляющимся периодически.

Загрузите сигнал валидации.

[audioIn,fs] = audioread(fullfile(netFolder,'KeywordSpeech-16-16-mono-34secs.flac'));

Слушайте сигнал.

sound(audioIn,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioIn)-1);
plot(t,audioIn);
grid on
xlabel('Time (s)')
title('Validation Speech Signal')

Осмотрите базовый уровень KWS

Загрузите базовый уровень KWS. Эта базовая линия была получена с использованием speech2text: Создайте маску споттинга ключевых слов с помощью Audio Labeler.

load('KWSBaseline.mat','KWSBaseline')

Базовая линия является логическим вектором той же длины, что и аудиосигнал валидации. Сегменты в audioIn где произнесено ключевое слово, задается единица KWSBaseline.

Визуализируйте речевой сигнал вместе с базовой линией KWS.

fig = figure;
plot(t,[audioIn,KWSBaseline'])
grid on
xlabel('Time (s)')
legend('Speech','KWS Baseline','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Validation Signal")

Прослушайте сегменты речи, идентифицированные как ключевые слова.

sound(audioIn(KWSBaseline),fs)

Цель сети, которую вы обучаете, состоит в том, чтобы вывести KWS-маску нулей и таковых, подобных этой базовой линии.

Загрузка набора данных речевых команд

Загрузите и извлеките набор данных речевых команд Google [1].

url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'google_speech');

if ~exist(datasetFolder,'dir')
    disp('Downloading Google speech commands data set (1.5 GB)...')
    unzip(url,datasetFolder)
end

Создайте audioDatastore это указывает на набор данных.

ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true);
ads = shuffle(ads);

Набор данных содержит файлы фонового шума, которые не используются в этом примере. Использование subset чтобы создать новый datastore, который не имеет файлов фонового шума.

isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);

Набор данных имеет приблизительно 65 000 односекундных длинных высказываний из 30 коротких слов (включая ключевое слово YES). Получите разбивку распределения слов в datastore.

countEachLabel(ads)
ans=30×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

Разделение ads в два хранилища данных: Первый datastore содержит файлы, соответствующие ключевому слову. Второй datastore содержит все другие слова.

keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);

Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset на false. Чтобы запустить этот пример быстро, установите reduceDataset на true.

reduceDataset = false; 
if reduceDataset 
    % Reduce the dataset by a factor of 20 
    ads_keyword = splitEachLabel (ads_keyword,round (numel (ads_keyword.Files )/20 ));
    numUniqueLabels = numel (уникальный (ads_other.Labels)); 
    ads_other = splitEachLabel (ads_other,round (numel (ads_other.Files )/numUniqueLabels/20)); 
end

Получите разбивку распределения слов в каждом datastore. Тасуйте ads_other datastore, чтобы последовательные чтения возвращали различные слова.

countEachLabel(ads_keyword)
ans=1×2 table
    Label    Count
    _____    _____

     yes     2377 

countEachLabel(ads_other)
ans=29×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

ads_other = shuffle(ads_other);

Создание обучающих предложений и меток

Обучающие хранилища данных содержат речевые сигналы на одну секунду, где произносится одно слово. Вы создадите более сложные обучающие речевые высказывания, которые содержат смесь ключевого слова наряду с другими словами.

Вот пример построенного высказывания. Прочитайте одно ключевое слово из ключевого слова datastore и нормализуйте его, чтобы иметь максимальное значение единицы.

yes = read(ads_keyword);
yes = yes / max(abs(yes));

Сигнал имеет неречевые фрагменты (тишина, фоновый шум и т.д.), которые не содержат полезной речевой информации. Этот пример удаляет молчание с помощью detectSpeech.

Получите начало и конец индексы полезного фрагмента сигнала.

speechIndices = detectSpeech(yes,fs);

Случайным образом выберите количество слов для использования в синтезированном предложении обучения. Используйте не более 10 слов.

numWords = randi([0 10]);

Случайным образом выберите место, в котором находится ключевое слово.

keywordLocation = randi([1 numWords+1]);

Прочитайте желаемое количество высказываний, не являющихся ключевыми словами, и создайте обучающее предложение и маску.

sentence = [];
mask = [];
for index = 1:numWords+1
    if index == keywordLocation
        sentence = [sentence;yes]; %#ok
        newMask = zeros(size(yes));
        newMask(speechIndices(1,1):speechIndices(1,2)) = 1;
        mask = [mask;newMask]; %#ok
    else
        other = read(ads_other);
        other = other ./ max(abs(other));
        sentence = [sentence;other]; %#ok
        mask = [mask;zeros(size(other))]; %#ok
    end
end

Постройте график обучающего предложения вместе с маской.

figure
t = (1/fs) * (0:length(sentence)-1);
fig = figure;
plot(t,[sentence,mask])
grid on
xlabel('Time (s)')
legend('Training Signal','Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Example Utterance")

Послушайте обучающее предложение.

sound(sentence,fs)

Извлечение функций

Этот пример обучает нейронную сеть для глубокого обучения, используя 39 коэффициентов MFCC (13 коэффициентов MFCC, 13 дельта и 13 коэффициентов дельта-дельта).

Определите параметры, необходимые для извлечения MFCC.

WindowLength = 512;
OverlapLength = 384;

Создайте объект audioFeatureExtractor, чтобы выполнить редукцию данных.

afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta',true, ...
                            'mfccDeltaDelta',true);                      

Извлеките функции.

featureMatrix = extract(afe,sentence);
size(featureMatrix)
ans = 1×2

   478    39

Обратите внимание, что вы вычисляете MFCC путем скольжения окна через вход, поэтому матрица функций короче, чем входной речевой сигнал. Каждая строка в featureMatrix соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).

Вычислите маску той же длины, что и featureMatrix.

HopLength = WindowLength - OverlapLength;
range = HopLength * (1:size(featureMatrix,1)) + HopLength;
featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end

Извлечение функций из обучающего набора данных

Синтез предложения и редукция данных для всего обучающего набора данных могут быть довольно длительными. Чтобы ускорить обработку, если у вас есть Parallel Computing Toolbox™, разделите обучающий datastore и обработайте каждый раздел на отдельном рабочем месте.

Выберите несколько разделов datastore.

numPartitions = 6;

Инициализируйте массивы ячеек для матриц функций и масок.

TrainingFeatures = {};
TrainingMasks= {};

Выполните синтез предложения, редукцию данных и создание маски с помощью parfor.

emptyCategories = categorical([1 0]);
emptyCategories(:) = [];

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        % Create a training sentence
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        % Create mask
        hopLength = WindowLength - OverlapLength;
        range = (hopLength) * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end

        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 33.656404 seconds.

Это хорошая практика, чтобы нормализовать все функции, чтобы иметь нулевое среднее и единичное стандартное отклонение. Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормализовать данные.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Извлечение функций валидации

Извлеките функции MFCC из сигнала валидации.

featureMatrix = extract(afe, audioIn);
featureMatrix(~isfinite(featureMatrix)) = 0;

Нормализуйте функции валидации.

FeaturesValidationClean = (featureMatrix - M)./S;
range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;

Создайте маску KWS валидации.

featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end
BaselineV = categorical(featureMask);

Определите сетевую архитектуру LSTM

Сети LSTM могут изучать долгосрочные зависимости между временными шагами данных последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer (Deep Learning Toolbox), чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.

Задайте размер входа, чтобы быть последовательностями размера numFeatures. Задайте два скрытых двунаправленных слоя LSTM с размером выходом 150 и выведите последовательность. Эта команда предписывает двунаправленному слою LSTM преобразовать входные временные ряды в 150 функции, которые передаются следующему слою. Задайте два класса путем включения полносвязного слоя размера 2, а затем слоя softmax и слоя классификации.

layers = [ ...
    sequenceInputLayer(numFeatures)
    bilstmLayer(150,"OutputMode","sequence")
    bilstmLayer(150,"OutputMode","sequence")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];

Задайте опции обучения

Задайте опции обучения для классификатора. Задайте MaxEpochs 10 так, чтобы сеть 10 прошла через обучающие данные. Задайте MiniBatchSize на 64 чтобы сеть рассматривала 64 обучающих сигнала за раз. Задайте Plots на "training-progress" чтобы сгенерировать графики, которые показывают процесс обучения с увеличениями количества итераций. Задайте Verbose на false чтобы отключить печать выхода таблицы, который соответствует данным, показанным на графике. Задайте Shuffle на "every-epoch" тасовать обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule на "piecewise" уменьшить скорость обучения на заданный коэффициент (0,1) каждый раз, когда прошло определенное количество эпох (5). Задайте ValidationData к предикторам и целям валидации.

Этот пример использует решатель адаптивной оценки момента (ADAM). ADAM работает лучше с рекуррентными нейронными сетями (RNNs), такими как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решателя.

maxEpochs = 10;
miniBatchSize = 64;
options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationClean.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучите сеть LSTM

Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork (Deep Learning Toolbox). Поскольку набор обучающих данных большая, процесс обучения может занять несколько минут.

[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
end

Проверяйте точность сети на наличие сигнала валидации без шума

Оцените маску KWS для сигнала валидации с помощью обученной сети.

v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');

Вычислите и постройте матрицу неточностей валидации из векторов фактических и оцененных меток.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразуйте выход сети из категориального в двойной.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);

Прослушайте области ключевых слов, определенные сетью.

sound(audioIn(logical(v)),fs)

Визуализация предполагаемых и ожидаемых масок KWS.

baseline = double(BaselineV) - 1;
baseline = repmat(baseline,HopLength,1);
baseline = baseline(:);

t = (1/fs) * (0:length(v)-1);
fig = figure;
plot(t,[audioIn(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noise-Free Speech')

Проверяйте точность сети на наличие шумного сигнала валидации

Теперь вы проверите точность сети на наличие шумного речевого сигнала. Сигналы с шумом получали путем повреждения сигнала чистой валидации аддитивным белым Гауссовым шумом.

Загрузите сигнал с шумом.

[audioInNoisy,fs] = audioread(fullfile(netFolder,'NoisyKeywordSpeech-16-16-mono-34secs.flac'));
sound(audioInNoisy,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioInNoisy)-1);
plot(t,audioInNoisy)
grid on
xlabel('Time (s)')
title('Noisy Validation Speech Signal')

Извлеките матрицу функции из сигнала с шумом.

featureMatrixV = extract(afe, audioInNoisy);
featureMatrixV(~isfinite(featureMatrixV)) = 0;
FeaturesValidationNoisy = (featureMatrixV - M)./S;

Передайте матрицу функций в сеть.

v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');

Сравните выходы сети с базовым уровнем. Обратите внимание, что точность ниже, чем та, что вы получили для чистого сигнала.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразуйте выход сети из категориального в двойной.

 v = double(v) - 1;
 v = repmat(v,HopLength,1);
 v = v(:);

Прослушайте области ключевых слов, определенные сетью.

sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и базовые маски.

t = (1/fs)*(0:length(v)-1);
fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - No Data Augmentation')

Выполните Увеличение Данных

Обученная сеть не показала хороших результатов на сигнал с шумом, потому что обученный набор данных содержал только предложения без шума. Вы исправите это, увеличив свой набор данных, чтобы включить шумные предложения.

Использование audioDataAugmenter чтобы увеличить ваш набор данных.

ada = audioDataAugmenter('TimeStretchProbability',0, ...
                         'PitchShiftProbability',0, ...
                         'VolumeControlProbability',0, ...
                         'TimeShiftProbability',0, ...
                         'SNRRange',[-1, 1], ...
                         'AddNoiseProbability',0.85);

С этими настройками audioDataAugmenter объект повреждает вход аудиосигнал с белым Гауссовым шумом с вероятностью 85%. ОСШ выбирается случайным образом из области значений [-1 1] (в дБ). Существует 15% вероятность того, что augmenter не изменит ваш входной сигнал.

В качестве примера передайте аудиосигнал в augmenter.

reset(ads_keyword)
x = read(ads_keyword);
data = augment(ada,x,fs)
data=1×2 table
         Audio          AugmentationInfo
    ________________    ________________

    {16000×1 double}      [1×1 struct]  

Осмотрите AugmentationInfo переменная в data чтобы проверить, как был изменен сигнал.

data.AugmentationInfo
ans = struct with fields:
    SNR: 0.3410

Сбросьте хранилища данных.

reset(ads_keyword)
reset(ads_other)

Инициализируйте функцию и маскируйте ячейки.

TrainingFeatures = {};
TrainingMasks = {};

Еще раз выполните редукцию данных. Каждый сигнал повреждается шумом с вероятностью 85%, поэтому ваш дополненный набор данных имеет приблизительно 85% зашумленных данных и 15% бесшумных данных.

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Corrupt with noise
        augmentedData = augment(ada,sentence,fs);
        sentence = augmentedData.Audio{1};
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        hopLength = WindowLength - OverlapLength;
        range = hopLength * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end
    
        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.090923 seconds.

Вычислите среднее и стандартное отклонение для каждого коэффициента; используйте их для нормализации данных.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'),'KWSNet','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Нормализуйте функции валидации с помощью новых средних и стандартных значений отклонения.

FeaturesValidationNoisy = (featureMatrixV - M)./S;

Переобучение сети с помощью дополненного набора данных

Создайте параметры обучения заново. Используйте шумные базовые функции и маску для валидации.

options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучите сеть.

[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'));
end

Проверьте точность сети по сигналу валидации.

v = classify(KWSNet,FeaturesValidationNoisy.');

Сравните предполагаемые и ожидаемые маски KWS.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Прослушайте идентифицированные области ключевых слов.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);
 
sound(audioIn(logical(v)),fs)

Визуализация предполагаемых и ожидаемых масок.

fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - With Data Augmentation')

Ссылки

[1] Warden P. «Speech Commands: A public dataset for single-word speech recognition», 2017. Доступно из https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Копирайт Google 2017. Набор данных Speech Commands лицензируется лицензией Creative Commons Attribution 4.0.

Приложение - Вспомогательные функции

function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength)

% Read one keyword
keyword = read(ads_keyword);
keyword = keyword ./ max(abs(keyword));

% Identify region of interest
speechIndices = detectSpeech(keyword,fs);
if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength
    speechIndices = [1,length(keyword)];
end
keyword = keyword(speechIndices(1,1):speechIndices(1,2));

% Pick a random number of other words (between 0 and 10)
numWords = randi([0 10]);
% Pick where to insert keyword
loc = randi([1 numWords+1]);
sentence = [];
mask = [];
for index = 1:numWords+1
    if index==loc
        sentence = [sentence;keyword];
        newMask = ones(size(keyword));
        mask = [mask ;newMask];
    else
        
        other = read(ads_other);      
        other = other ./ max(abs(other));
        sentence = [sentence;other];
        mask = [mask;zeros(size(other))];
    end
end
end