exponenta event banner

Определение ключевых слов в шуме с использованием сетей MFCC и LSTM

В этом примере показано, как идентифицировать ключевое слово в шумной речи с помощью сети глубокого обучения. В частности, в примере используется сеть двунаправленной долговременной памяти (BiLSTM) и кепстральные коэффициенты частоты (MFCC).

Введение

Ключевое слово spotting (KWS) является важным компонентом технологий voice-assist, где пользователь говорит предопределенное ключевое слово для пробуждения системы, прежде чем передать полную команду или запрос устройству.

В этом примере обучается глубокая сеть KWS с характерными последовательностями кепстральных коэффициентов (MFCC). Пример также демонстрирует, как можно повысить точность сети в шумной среде с помощью увеличения объема данных.

В этом примере используются сети долговременной кратковременной памяти (LSTM), которые являются типом рецидивирующей нейронной сети (RNN), хорошо подходящей для изучения последовательности и данных временных рядов. Сеть LSTM может изучать долгосрочные зависимости между временными шагами последовательности. Уровень LSTM (lstmLayer) может просматривать временную последовательность в прямом направлении, в то время как двунаправленный уровень LSTM (bilstmLayer) может смотреть на временную последовательность как в прямом, так и в обратном направлениях. В этом примере используется двунаправленный уровень LSTM.

В примере используется набор данных речевых команд google для обучения модели глубокого обучения. Для выполнения примера необходимо сначала загрузить набор данных. Если не требуется загружать набор данных или обучать сеть, можно загрузить и использовать предварительно обученную сеть, открыв этот пример в MATLAB ® и выполняя строки 3-10 примера.

Ключевое слово Spot с предварительно обученной сетью

Перед подробным началом процесса обучения необходимо загрузить и использовать предварительно обученную сеть определения ключевых слов для идентификации ключевого слова.

В этом примере ключевое слово spot - YES.

Считывайте тестовый сигнал, где произносится ключевое слово.

[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)

Загрузите и загрузите предварительно обученную сеть, векторы среднего (M) и стандартного отклонения (S), используемые для нормализации функций, а также 2 аудиофайла, используемые для проверки правильности сети позже в примере.

url = 'http://ssd.mathworks.com/supportfiles/audio/KeywordSpotting.zip';
downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'KeywordSpotting');
if ~exist(netFolder,'dir')
    disp('Downloading pretrained network and audio files (4 files - 7 MB) ...')
    unzip(url,downloadNetFolder)
end
load(fullfile(netFolder,'KWSNet.mat'));

Создание audioFeatureExtractor Объект (Audio Toolbox) для извлечения элементов.

WindowLength = 512;
OverlapLength = 384;
afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta', true, ...
                            'mfccDeltaDelta',true);                   

Извлеките элементы из тестового сигнала и нормализуйте их.

features = extract(afe,audioIn);   

features = (features - M)./S;

Вычислите ключевое слово spotting binary mask. Значение маски, равное единице, соответствует сегменту, в котором было замечено ключевое слово.

mask = classify(KWSNet,features.');

Каждая выборка в маске соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).

Разверните маску до длины сигнала.

mask = repmat(mask, WindowLength-OverlapLength, 1);
mask = double(mask) - 1;
mask = mask(:);

Постройте график тестового сигнала и маски.

figure
audioIn = audioIn(1:length(mask));
t = (0:length(audioIn)-1)/fs;
plot(t, audioIn)
grid on
hold on
plot(t, mask)
legend('Speech', 'YES')

Прослушать ключевое слово «» Пятно «».

sound(audioIn(mask==1),fs)

Команды обнаружения с помощью потокового звука с микрофона

Протестируйте предварительно обученную сеть обнаружения команд на потоковом аудио с микрофона. Попробуйте произнести случайные слова, включая ключевое слово (YES).

Звонить generateMATLABFunction (Панель звуковых инструментов) на audioFeatureExtractor объект для создания функции извлечения элемента. Эта функция используется в цикле обработки.

generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);

Определите устройство чтения аудио, которое может считывать аудио с микрофона. Задайте длину кадра, равную длине транзитного участка. Это позволяет вычислить новый набор функций для каждого нового аудиокадра с микрофона.

HopLength = WindowLength - OverlapLength;
FrameLength = HopLength;
adr = audioDeviceReader('SampleRate',fs, ...
                        'SamplesPerFrame',FrameLength);

Создайте область для визуализации речевого сигнала и предполагаемой маски.

scope = timescope('SampleRate',fs, ...
                  'TimeSpanSource','property', ...
                  'TimeSpan',5, ...
                  'TimeSpanOverrunAction','Scroll', ...
                  'BufferLength',fs*5*2, ...
                  'ShowLegend',true, ...
                  'ChannelNames',{'Speech','Keyword Mask'}, ...
                  'YLimits',[-1.2 1.2], ...
                  'Title','Keyword Spotting');

Определите скорость, с которой оценивается маска. Маска генерируется один раз NumHopsPerUpdate аудиокадры.

NumHopsPerUpdate = 16;

Инициализируйте буфер для звука.

dataBuff = dsp.AsyncBuffer(WindowLength);

Инициализируйте буфер для вычисляемых элементов.

featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);

Инициализируйте буфер для управления печатью звука и маски.

plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);

Для бесконечного запуска цикла установите timeLimit в значение Inf. Чтобы остановить моделирование, закройте область.

timeLimit = 20;

tic
while toc < timeLimit
    
    data = adr();
    write(dataBuff, data);
    write(plotBuff, data);

    frame = read(dataBuff,WindowLength,OverlapLength);
    features = generateKeywordFeatures(frame,fs);
    write(featureBuff,features.');

    if featureBuff.NumUnreadSamples == NumHopsPerUpdate
        featureMatrix = read(featureBuff);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        featureMatrix = (featureMatrix - M)./S;
        
        [keywordNet, v] = classifyAndUpdateState(KWSNet,featureMatrix.');
        v = double(v) - 1;
        v = repmat(v, HopLength, 1);
        v = v(:);
        v = mode(v);
        v = repmat(v, NumHopsPerUpdate * HopLength,1);
        
        data = read(plotBuff);
        scope([data, v]);
        
        if ~isVisible(scope)
            break;
        end
    end
end
hide(scope)

В остальном примере вы узнаете, как обучить ключевое слово spotting network.

Краткое описание процесса обучения

Процесс обучения проходит следующие этапы:

  1. Проверьте ключевое слово «gold standard», указывающее базовую линию на сигнале проверки.

  2. Создание обучающих высказываний из набора данных без шума.

  3. Обучение сети LSTM с помощью ключевого слова spotting с использованием последовательностей MFCC, извлеченных из этих высказываний.

  4. Проверьте точность сети путем сравнения базовой линии проверки с выходом сети при применении к сигналу проверки.

  5. Проверьте точность сети на наличие сигнала проверки подлинности, поврежденного шумом.

  6. Увеличение обучающего набора данных путем введения шума в речевые данные с использованием audioDataAugmenter (Панель звуковых инструментов).

  7. Переподготовьте сеть с помощью дополненного набора данных.

  8. Убедитесь, что переученная сеть теперь дает более высокую точность при применении к шумному сигналу проверки подлинности.

Проверка сигнала проверки

Для проверки правильности сети KWS используется образец речевого сигнала. Сигнал проверки подлинности состоит из 34 секунд речи с ключевым словом YES, появляющимся с перерывами.

Загрузите сигнал проверки.

[audioIn,fs] = audioread(fullfile(netFolder,'KeywordSpeech-16-16-mono-34secs.flac'));

Слушай сигнал.

sound(audioIn,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioIn)-1);
plot(t,audioIn);
grid on
xlabel('Time (s)')
title('Validation Speech Signal')

Проверка базового уровня KWS

Загрузите базовую линию KWS. Этот базовый уровень был получен с использованием speech2text: Создать ключевое слово Spotting Mask с помощью Audio Labeler (Audio Toolbox).

load('KWSBaseline.mat','KWSBaseline')

Базовая линия является логическим вектором той же длины, что и звуковой сигнал проверки подлинности. Сегменты в audioIn где произносится ключевое слово, устанавливаются на единицу в KWSBaseline.

Визуализируйте речевой сигнал вместе с базовой линией KWS.

fig = figure;
plot(t,[audioIn,KWSBaseline'])
grid on
xlabel('Time (s)')
legend('Speech','KWS Baseline','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Validation Signal")

Прослушивание речевых сегментов, определенных как ключевые слова.

sound(audioIn(KWSBaseline),fs)

Целью обучаемой сети является вывод маски KWS из нулей и единиц, подобных этой базовой линии.

Загрузить набор данных речевых команд

Загрузите и извлеките набор данных речевых команд Google [1].

url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'google_speech');

if ~exist(datasetFolder,'dir')
    disp('Downloading Google speech commands data set (1.5 GB)...')
    unzip(url,datasetFolder)
end

Создание audioDatastore (Audio Toolbox), указывающий на набор данных.

ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true);
ads = shuffle(ads);

Набор данных содержит файлы фонового шума, которые не используются в данном примере. Использовать subset (Audio Toolbox) для создания нового хранилища данных, в котором отсутствуют файлы фонового шума.

isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);

Набор данных содержит приблизительно 65 000 длинных слов длиной в одну секунду из 30 коротких слов (включая ключевое слово YES). Получите разбивку распределения слов в хранилище данных.

countEachLabel(ads)
ans=30×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

Разделение ads на два хранилища данных: первое хранилище данных содержит файлы, соответствующие ключевому слову. Второе хранилище данных содержит все остальные слова.

keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);

Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset кому false. Для быстрого выполнения этого примера установите reduceDataset кому true.

reduceDataset = false; 
if reduceDataset 
    % Reduce the dataset by a factor of 20 
    ads_keyword = splitEachLabel(ads_keyword,round(numel(ads_keyword.Files)  / 20)); 
    numUniqueLabels = numel(unique(ads_other.Labels)); 
    ads_other = splitEachLabel(ads_other,round(numel(ads_other.Files) / numUniqueLabels / 20)); 
end

Получите разбивку распределения слов в каждом хранилище данных. Перетасовать ads_other , чтобы последовательные чтения возвращали разные слова.

countEachLabel(ads_keyword)
ans=1×2 table
    Label    Count
    _____    _____

     yes     2377 

countEachLabel(ads_other)
ans=29×2 table
    Label     Count
    ______    _____

    bed       1713 
    bird      1731 
    cat       1733 
    dog       1746 
    down      2359 
    eight     2352 
    five      2357 
    four      2372 
    go        2372 
    happy     1742 
    house     1750 
    left      2353 
    marvin    1746 
    nine      2364 
    no        2375 
    off       2357 
      ⋮

ads_other = shuffle(ads_other);

Создание предложений и меток для обучения

Обучающие хранилища данных содержат односекундные речевые сигналы, в которых произносится одно слово. Вы создадите более сложные обучающие речи, которые содержат смесь ключевого слова и других слов.

Вот пример построенного высказывания. Прочитайте одно ключевое слово из хранилища данных ключевых слов и нормализуйте его, чтобы оно имело максимальное значение, равное единице.

yes = read(ads_keyword);
yes = yes / max(abs(yes));

Сигнал имеет неречевые части (тишина, фоновый шум и т.д.), которые не содержат полезной речевой информации. Этот пример удаляет молчание с помощью detectSpeech (Панель звуковых инструментов).

Получите начальный и конечный индексы полезной части сигнала.

speechIndices = detectSpeech(yes,fs);

Случайным образом выберите количество слов для использования в синтезированном обучающем предложении. Используйте максимум 10 слов.

numWords = randi([0 10]);

Случайный выбор местоположения, в котором встречается ключевое слово.

keywordLocation = randi([1 numWords+1]);

Прочитайте требуемое количество высказываний без ключевых слов и создайте обучающее предложение и маску.

sentence = [];
mask = [];
for index = 1:numWords+1
    if index == keywordLocation
        sentence = [sentence;yes]; %#ok
        newMask = zeros(size(yes));
        newMask(speechIndices(1,1):speechIndices(1,2)) = 1;
        mask = [mask;newMask]; %#ok
    else
        other = read(ads_other);
        other = other ./ max(abs(other));
        sentence = [sentence;other]; %#ok
        mask = [mask;zeros(size(other))]; %#ok
    end
end

Постройте график учебного предложения вместе с маской.

figure
t = (1/fs) * (0:length(sentence)-1);
fig = figure;
plot(t,[sentence,mask])
grid on
xlabel('Time (s)')
legend('Training Signal','Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
title("Example Utterance")

Послушай тренировочное предложение.

sound(sentence,fs)

Извлечь элементы

В этом примере обучается сеть глубокого обучения с использованием 39 коэффициентов MFCC (13 коэффициентов MFCC, 13 дельта и 13 дельта-дельта).

Определите параметры, необходимые для извлечения MFCC.

WindowLength = 512;
OverlapLength = 384;

Создайте объект audioFeatureExtractor для извлечения компонента.

afe = audioFeatureExtractor('SampleRate',fs, ...
                            'Window',hann(WindowLength,'periodic'), ...
                            'OverlapLength',OverlapLength, ...
                            'mfcc',true, ...
                            'mfccDelta',true, ...
                            'mfccDeltaDelta',true);                      

Извлеките элементы.

featureMatrix = extract(afe,sentence);
size(featureMatrix)
ans = 1×2

   478    39

Обратите внимание, что MFCC вычисляется путем перемещения окна через вход, поэтому матрица признаков короче входного речевого сигнала. Каждая строка в featureMatrix соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).

Вычислить маску той же длины, что и featureMatrix.

HopLength = WindowLength - OverlapLength;
range = HopLength * (1:size(featureMatrix,1)) + HopLength;
featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end

Извлечение функций из набора учебных данных

Синтез предложений и извлечение признаков для всего набора обучающих данных может занять довольно много времени. Чтобы ускорить обработку, при наличии Toolbox™ Parallel Computing разделите хранилище данных обучения и обработайте каждый раздел на отдельном работнике.

Выберите несколько разделов хранилища данных.

numPartitions = 6;

Инициализируйте массивы ячеек для матриц и масок элементов.

TrainingFeatures = {};
TrainingMasks= {};

Выполнение синтеза предложений, извлечения элементов и создания маски с помощью parfor.

emptyCategories = categorical([1 0]);
emptyCategories(:) = [];

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        % Create a training sentence
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        % Create mask
        hopLength = WindowLength - OverlapLength;
        range = (hopLength) * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end

        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 33.656404 seconds.

Рекомендуется нормализовать все элементы так, чтобы они имели нулевое среднее и единичное стандартное отклонение. Вычислите среднее значение и стандартное отклонение для каждого коэффициента и используйте их для нормализации данных.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Извлечь элементы проверки

Извлеките функции MFCC из сигнала проверки.

featureMatrix = extract(afe, audioIn);
featureMatrix(~isfinite(featureMatrix)) = 0;

Нормализуйте функции проверки.

FeaturesValidationClean = (featureMatrix - M)./S;
range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;

Создайте маску KWS проверки.

featureMask = zeros(size(range));
for index = 1:numel(range)
    featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength ));
end
BaselineV = categorical(featureMask);

Определение сетевой архитектуры LSTM

Сети LSTM могут изучать долгосрочные зависимости между временными шагами данных последовательности. В этом примере используется двунаправленный уровень LSTM bilstmLayer чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.

Укажите входной размер, который будет последовательностью размера numFeatures. Укажите два скрытых двунаправленных уровня LSTM с размером вывода 150 и выведите последовательность. Эта команда дает команду двунаправленному уровню LSTM отобразить входной временной ряд в 150 признаков, которые передаются на следующий уровень. Укажите два класса, включив полностью подключенный слой размера 2, а затем слой softmax и слой классификации.

layers = [ ...
    sequenceInputLayer(numFeatures)
    bilstmLayer(150,"OutputMode","sequence")
    bilstmLayer(150,"OutputMode","sequence")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];

Определение параметров обучения

Укажите параметры обучения для классификатора. Набор MaxEpochs 10, так что сеть 10 проходит через обучающие данные. Набор MiniBatchSize кому 64 чтобы сеть просматривала 64 обучающих сигнала одновременно. Набор Plots кому "training-progress" для создания графиков, показывающих ход обучения по мере увеличения числа итераций. Набор Verbose кому false для отключения печати выходных данных таблицы, соответствующих данным, отображаемым на графике. Набор Shuffle кому "every-epoch" перетасовать тренировочную последовательность в начале каждой эпохи. Набор LearnRateSchedule кому "piecewise" снижать скорость обучения на заданный коэффициент (0,1) каждый раз, когда проходит определенное количество эпох (5). Набор ValidationData к предикторам и целям проверки.

В этом примере используется решатель ADAM. ADAM работает лучше с рецидивирующими нейронными сетями (RNN), такими как LSTM, чем стохастический градиентный спуск по умолчанию с решателем импульса (SGDM).

maxEpochs = 10;
miniBatchSize = 64;
options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationClean.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучение сети LSTM

Обучение сети LSTM с использованием указанных вариантов обучения и архитектуры уровня с помощью trainNetwork. Поскольку тренировочный набор большой, тренировочный процесс может занять несколько минут.

[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
end

Проверка точности сети для сигнала проверки без шума

Оцените маску KWS для сигнала проверки подлинности с использованием обученной сети.

v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');

Вычислите и постройте график матрицы путаницы проверки на основе векторов фактических и оценочных меток.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразование сетевых выходных данных из категориальных в двойные.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);

Прослушивание областей ключевых слов, определенных сетью.

sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и ожидаемые маски KWS.

baseline = double(BaselineV) - 1;
baseline = repmat(baseline,HopLength,1);
baseline = baseline(:);

t = (1/fs) * (0:length(v)-1);
fig = figure;
plot(t,[audioIn(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noise-Free Speech')

Проверка точности сети для сигнала проверки достоверности с шумом

Теперь проверьте точность сети на наличие шумного речевого сигнала. Шумный сигнал был получен путем повреждения чистого сигнала проверки достоверности аддитивным белым гауссовым шумом.

Загрузите шумный сигнал.

[audioInNoisy,fs] = audioread(fullfile(netFolder,'NoisyKeywordSpeech-16-16-mono-34secs.flac'));
sound(audioInNoisy,fs)

Визуализируйте сигнал.

figure
t = (1/fs) * (0:length(audioInNoisy)-1);
plot(t,audioInNoisy)
grid on
xlabel('Time (s)')
title('Noisy Validation Speech Signal')

Извлеките матрицу признаков из шумного сигнала.

featureMatrixV = extract(afe, audioInNoisy);
featureMatrixV(~isfinite(featureMatrixV)) = 0;
FeaturesValidationNoisy = (featureMatrixV - M)./S;

Передайте матрицу функций в сеть.

v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');

Сравните выходные данные сети с базовой линией. Обратите внимание, что точность ниже, чем та, которую вы получили для чистого сигнала.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Преобразование сетевых выходных данных из категориальных в двойные.

 v = double(v) - 1;
 v = repmat(v,HopLength,1);
 v = v(:);

Прослушивание областей ключевых слов, определенных сетью.

sound(audioIn(logical(v)),fs)

Визуализация оценочных и базовых масок.

t = (1/fs)*(0:length(v)-1);
fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - No Data Augmentation')

Выполнение увеличения объема данных

Обученная сеть плохо работала на шумном сигнале, потому что обученный набор данных содержал только предложения без шума. Вы исправите это, увеличив свой набор данных, чтобы включить шумные предложения.

Использовать audioDataAugmenter (Audio Toolbox) для расширения набора данных.

ada = audioDataAugmenter('TimeStretchProbability',0, ...
                         'PitchShiftProbability',0, ...
                         'VolumeControlProbability',0, ...
                         'TimeShiftProbability',0, ...
                         'SNRRange',[-1, 1], ...
                         'AddNoiseProbability',0.85);

С этими настройками, audioDataAugmenter объект повреждает входной аудиосигнал белым гауссовым шумом с вероятностью 85%. SNR выбирается случайным образом из диапазона [-1 1] (в дБ). Существует 15% вероятность того, что augmenter не изменит ваш входной сигнал.

В качестве примера передайте звуковой сигнал в аугментатор.

reset(ads_keyword)
x = read(ads_keyword);
data = augment(ada,x,fs)
data=1×2 table
         Audio          AugmentationInfo
    ________________    ________________

    {16000×1 double}      [1×1 struct]  

Осмотрите AugmentationInfo переменная в data для проверки того, как был изменен сигнал.

data.AugmentationInfo
ans = struct with fields:
    SNR: 0.3410

Сбросьте хранилища данных.

reset(ads_keyword)
reset(ads_other)

Инициализируйте ячейки элемента и маски.

TrainingFeatures = {};
TrainingMasks = {};

Снова выполните извлечение элемента. Каждый сигнал испорчен шумом с вероятностью 85%, поэтому расширенный набор данных содержит приблизительно 85% шумных данных и 15% безшумных данных.

tic
parfor ii = 1:numPartitions
    
    subads_keyword = partition(ads_keyword,numPartitions,ii);
    subads_other = partition(ads_other,numPartitions,ii);
    
    count = 1;
    localFeatures = cell(length(subads_keyword.Files),1);
    localMasks = cell(length(subads_keyword.Files),1);
    
    while hasdata(subads_keyword)
        
        [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength);
        
        % Corrupt with noise
        augmentedData = augment(ada,sentence,fs);
        sentence = augmentedData.Audio{1};
        
        % Compute mfcc features
        featureMatrix = extract(afe, sentence);
        featureMatrix(~isfinite(featureMatrix)) = 0;
        
        hopLength = WindowLength - OverlapLength;
        range = hopLength * (1:size(featureMatrix,1)) + hopLength;
        featureMask = zeros(size(range));
        for index = 1:numel(range)
            featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength ));
        end
    
        localFeatures{count} = featureMatrix;
        localMasks{count} = [emptyCategories,categorical(featureMask)];
        
        count = count + 1;
    end
    
    TrainingFeatures = [TrainingFeatures;localFeatures];
    TrainingMasks = [TrainingMasks;localMasks];
end
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.090923 seconds.

Вычислить среднее и стандартное отклонение для каждого коэффициента; используйте их для нормализации данных.

sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'),'KWSNet','M','S');
else
    M = mean(featuresMatrix);
    S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
    f = TrainingFeatures{index};
    f = (f - M) ./ S;
    TrainingFeatures{index} = f.'; %#ok
end

Нормализуйте функции проверки с новыми значениями среднего и стандартного отклонения.

FeaturesValidationNoisy = (featureMatrixV - M)./S;

Переподготовка сети с дополненным набором данных

Повторно создайте параметры обучения. Используйте шумные элементы базовой линии и маску для проверки.

options = trainingOptions("adam", ...
    "InitialLearnRate",1e-4, ...
    "MaxEpochs",maxEpochs, ...
    "MiniBatchSize",miniBatchSize, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ...
    "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ...
    "Plots","training-progress", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",5);

Обучение сети.

[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset
    load(fullfile(netFolder,'KWSNet.mat'));
end

Проверьте сетевую точность сигнала проверки подлинности.

v = classify(KWSNet,FeaturesValidationNoisy.');

Сравните предполагаемые и ожидаемые маски KWS.

figure
cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation");
cm.ColumnSummary = "column-normalized";
cm.RowSummary = "row-normalized";

Прослушивание определенных областей ключевых слов.

v = double(v) - 1;
v = repmat(v,HopLength,1);
v = v(:);
 
sound(audioIn(logical(v)),fs)

Визуализируйте предполагаемые и ожидаемые маски.

fig = figure;
plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline])
grid on
xlabel('Time (s)')
legend('Training Signal','Network Mask','Baseline Mask','Location','southeast')
l = findall(fig,'type','line');
l(1).LineWidth = 2;
l(2).LineWidth = 2;
title('Results for Noisy Speech - With Data Augmentation')

Ссылки

[1] Уорден П. «Речевые команды: публичный набор данных для однословного распознавания речи», 2017. Доступно в https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Набор данных речевых команд лицензируется по лицензии Creative Commons Attribution 4.0.

Приложение - Вспомогательные функции

function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength)

% Read one keyword
keyword = read(ads_keyword);
keyword = keyword ./ max(abs(keyword));

% Identify region of interest
speechIndices = detectSpeech(keyword,fs);
if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength
    speechIndices = [1,length(keyword)];
end
keyword = keyword(speechIndices(1,1):speechIndices(1,2));

% Pick a random number of other words (between 0 and 10)
numWords = randi([0 10]);
% Pick where to insert keyword
loc = randi([1 numWords+1]);
sentence = [];
mask = [];
for index = 1:numWords+1
    if index==loc
        sentence = [sentence;keyword];
        newMask = ones(size(keyword));
        mask = [mask ;newMask];
    else
        
        other = read(ads_other);      
        other = other ./ max(abs(other));
        sentence = [sentence;other];
        mask = [mask;zeros(size(other))];
    end
end
end

См. также

| | |

Связанные темы