Этот пример показывает, как идентифицировать ключевое слово в шумной речи с помощью нейронной сети для глубокого обучения. В частности, в примере используются двунаправленная сеть долгой краткосрочной памяти (BiLSTM) и мел-частотные кепстральные коэффициенты (MFCC).
Spotting (KWS) является важным компонентом технологий голосовой помощи, где пользователь говорит предопределенное ключевое слово, чтобы разбудить систему, прежде чем говорить полную команду или запрос к устройству.
Этот пример обучает глубокую сеть KWS с характерными последовательностями мел-частотных кепстральных коэффициентов (MFCC). Пример также демонстрирует, как точность сети в зашумленном окружении может быть улучшена с помощью увеличения данных.
Этот пример использует сети долгой краткосрочной памяти (LSTM), которые являются типом рекуррентной нейронной сети (RNN), хорошо подходящим для изучения данных последовательности и timeseries. Сеть LSTM может изучать долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer
(Deep Learning Toolbox)) может смотреть на временную последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer
(Deep Learning Toolbox)) может посмотреть на временную последовательность как в прямом, так и в обратном направлениях. Этот пример использует двунаправленный слой LSTM.
Пример использует Google Speech Commands Dataset, чтобы обучить модель глубокого обучения. Чтобы запустить пример, необходимо сначала загрузить набор данных. Если вы не хотите загружать набор данных или обучать сеть, то можно скачать и использовать предварительно обученную сеть, открывая этот пример в MATLAB ® и выполняя линии 3-10 примера.
Прежде чем подробно войти в процесс обучения, вы будете загружать и использовать предварительно обученную сеть споттинга ключевых слов для идентификации ключевого слова.
В этом примере ключевым словом для определения является YES.
Считайте тестовый сигнал, где произнесено ключевое слово.
[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)
Загрузите и загрузите предварительно обученную сеть, средние векторы (M) и стандартное отклонение (S), используемые для нормализации функции, а также 2 аудио файлов, используемые для валидации сети позже в примере.
url = 'http://ssd.mathworks.com/supportfiles/audio/KeywordSpotting.zip'; downloadNetFolder = tempdir; netFolder = fullfile(downloadNetFolder,'KeywordSpotting'); if ~exist(netFolder,'dir') disp('Downloading pretrained network and audio files (4 files - 7 MB) ...') unzip(url,downloadNetFolder) end load(fullfile(netFolder,'KWSNet.mat'));
Создайте audioFeatureExtractor
объект для выполнения редукции данных.
WindowLength = 512; OverlapLength = 384; afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta', true, ... 'mfccDeltaDelta',true);
Извлеките функции из тестового сигнала и нормализуйте их.
features = extract(afe,audioIn); features = (features - M)./S;
Вычислите ключевое слово spotting двоичную маску. Значение маски единицы соответствует сегменту, в котором было замечено ключевое слово.
mask = classify(KWSNet,features.');
Каждая выборка в маске соответствует 128 выборкам из речевого сигнала (WindowLength
- OverlapLength
).
Разверните маску до длины сигнала.
mask = repmat(mask, WindowLength-OverlapLength, 1); mask = double(mask) - 1; mask = mask(:);
Постройте график тестового сигнала и маски.
figure audioIn = audioIn(1:length(mask)); t = (0:length(audioIn)-1)/fs; plot(t, audioIn) grid on hold on plot(t, mask) legend('Speech', 'YES')
Послушайте заметное ключевое слово.
sound(audioIn(mask==1),fs)
Протестируйте предварительно обученную сеть обнаружения команд на передаче потокового аудио с микрофона. Попробуйте произнести случайные слова, включая ключевое слово (YES).
Функции generateMATLABFunction
на audioFeatureExtractor
объект, чтобы создать функцию редукции данных. Вы будете использовать эту функцию в цикле обработки.
generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);
Задайте устройство чтения аудио устройства, которое может считывать аудио с микрофона. Установите длину системы координат равную длине скачка. Это позволяет вам вычислять новый набор функций для каждого нового аудио системы координат из микрофона.
HopLength = WindowLength - OverlapLength; FrameLength = HopLength; adr = audioDeviceReader('SampleRate',fs, ... 'SamplesPerFrame',FrameLength);
Создайте возможности для визуализации речевого сигнала и предполагаемой маски.
scope = timescope('SampleRate',fs, ... 'TimeSpanSource','property', ... 'TimeSpan',5, ... 'TimeSpanOverrunAction','Scroll', ... 'BufferLength',fs*5*2, ... 'ShowLegend',true, ... 'ChannelNames',{'Speech','Keyword Mask'}, ... 'YLimits',[-1.2 1.2], ... 'Title','Keyword Spotting');
Определите скорость, с которой вы оцениваете маску. Вы будете генерировать маску один раз в NumHopsPerUpdate
аудио систем координат.
NumHopsPerUpdate = 16;
Инициализируйте буфер для аудио.
dataBuff = dsp.AsyncBuffer(WindowLength);
Инициализируйте буфер для вычисляемых функций.
featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);
Инициализируйте буфер, чтобы управлять графическим изображением аудио и маски.
plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);
Чтобы запустить цикл бессрочно, установите timeLimit на Inf
. Чтобы остановить симуляцию, закройте возможности.
timeLimit = 20; tic while toc < timeLimit data = adr(); write(dataBuff, data); write(plotBuff, data); frame = read(dataBuff,WindowLength,OverlapLength); features = generateKeywordFeatures(frame,fs); write(featureBuff,features.'); if featureBuff.NumUnreadSamples == NumHopsPerUpdate featureMatrix = read(featureBuff); featureMatrix(~isfinite(featureMatrix)) = 0; featureMatrix = (featureMatrix - M)./S; [keywordNet, v] = classifyAndUpdateState(KWSNet,featureMatrix.'); v = double(v) - 1; v = repmat(v, HopLength, 1); v = v(:); v = mode(v); v = repmat(v, NumHopsPerUpdate * HopLength,1); data = read(plotBuff); scope([data, v]); if ~isVisible(scope) break; end end end hide(scope)
В остальном примере вы научитесь обучать сеть споттинга по ключевым словам.
Процесс обучения проходит следующие шаги:
Смотрите ключевое слово «золотой стандарт», определяющее базовый уровень по сигналу валидации.
Создайте обучающие высказывания из набора данных без шума.
Обучите ключевое слово spotting LSTM network с помощью последовательностей MFCC, извлеченных из этих высказываний.
Проверьте точность сети путем сравнения базового уровня валидации с выходом сети при применении к сигналу валидации.
Проверьте точность сети на наличие сигнала валидации, поврежденного шумом.
Увеличение обучающего набора данных путем введения шума в речевые данные с помощью audioDataAugmenter
.
Переобучите сеть с помощью дополненного набора данных.
Проверьте, что переобученная сеть теперь приводит к более высокой точности при применении к шумному сигналу валидации.
Вы используете пример речевого сигнала, чтобы подтвердить сеть KWS. Сигнал валидации состоит из 34 секунд речи с ключевым словом YES, появляющимся периодически.
Загрузите сигнал валидации.
[audioIn,fs] = audioread(fullfile(netFolder,'KeywordSpeech-16-16-mono-34secs.flac'));
Слушайте сигнал.
sound(audioIn,fs)
Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioIn)-1); plot(t,audioIn); grid on xlabel('Time (s)') title('Validation Speech Signal')
Загрузите базовый уровень KWS. Эта базовая линия была получена с использованием speech2text
: Создайте маску споттинга ключевых слов с помощью Audio Labeler.
load('KWSBaseline.mat','KWSBaseline')
Базовая линия является логическим вектором той же длины, что и аудиосигнал валидации. Сегменты в audioIn
где произнесено ключевое слово, задается единица KWSBaseline
.
Визуализируйте речевой сигнал вместе с базовой линией KWS.
fig = figure; plot(t,[audioIn,KWSBaseline']) grid on xlabel('Time (s)') legend('Speech','KWS Baseline','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Validation Signal")
Прослушайте сегменты речи, идентифицированные как ключевые слова.
sound(audioIn(KWSBaseline),fs)
Цель сети, которую вы обучаете, состоит в том, чтобы вывести KWS-маску нулей и таковых, подобных этой базовой линии.
Загрузите и извлеките набор данных речевых команд Google [1].
url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'google_speech'); if ~exist(datasetFolder,'dir') disp('Downloading Google speech commands data set (1.5 GB)...') unzip(url,datasetFolder) end
Создайте audioDatastore
это указывает на набор данных.
ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true); ads = shuffle(ads);
Набор данных содержит файлы фонового шума, которые не используются в этом примере. Использование subset
чтобы создать новый datastore, который не имеет файлов фонового шума.
isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);
Набор данных имеет приблизительно 65 000 односекундных длинных высказываний из 30 коротких слов (включая ключевое слово YES). Получите разбивку распределения слов в datastore.
countEachLabel(ads)
ans=30×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
Разделение ads
в два хранилища данных: Первый datastore содержит файлы, соответствующие ключевому слову. Второй datastore содержит все другие слова.
keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);
Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset
на false
. Чтобы запустить этот пример быстро, установите reduceDataset
на true
.
reduceDataset = false; if reduceDataset % Reduce the dataset by a factor of 20 ads_keyword = splitEachLabel (ads_keyword,round (numel (ads_keyword.Files )/20 )); numUniqueLabels = numel (уникальный (ads_other.Labels)); ads_other = splitEachLabel (ads_other,round (numel (ads_other.Files )/numUniqueLabels/20)); end
Получите разбивку распределения слов в каждом datastore. Тасуйте ads_other
datastore, чтобы последовательные чтения возвращали различные слова.
countEachLabel(ads_keyword)
ans=1×2 table
Label Count
_____ _____
yes 2377
countEachLabel(ads_other)
ans=29×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
ads_other = shuffle(ads_other);
Обучающие хранилища данных содержат речевые сигналы на одну секунду, где произносится одно слово. Вы создадите более сложные обучающие речевые высказывания, которые содержат смесь ключевого слова наряду с другими словами.
Вот пример построенного высказывания. Прочитайте одно ключевое слово из ключевого слова datastore и нормализуйте его, чтобы иметь максимальное значение единицы.
yes = read(ads_keyword); yes = yes / max(abs(yes));
Сигнал имеет неречевые фрагменты (тишина, фоновый шум и т.д.), которые не содержат полезной речевой информации. Этот пример удаляет молчание с помощью detectSpeech
.
Получите начало и конец индексы полезного фрагмента сигнала.
speechIndices = detectSpeech(yes,fs);
Случайным образом выберите количество слов для использования в синтезированном предложении обучения. Используйте не более 10 слов.
numWords = randi([0 10]);
Случайным образом выберите место, в котором находится ключевое слово.
keywordLocation = randi([1 numWords+1]);
Прочитайте желаемое количество высказываний, не являющихся ключевыми словами, и создайте обучающее предложение и маску.
sentence = []; mask = []; for index = 1:numWords+1 if index == keywordLocation sentence = [sentence;yes]; %#ok newMask = zeros(size(yes)); newMask(speechIndices(1,1):speechIndices(1,2)) = 1; mask = [mask;newMask]; %#ok else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; %#ok mask = [mask;zeros(size(other))]; %#ok end end
Постройте график обучающего предложения вместе с маской.
figure t = (1/fs) * (0:length(sentence)-1); fig = figure; plot(t,[sentence,mask]) grid on xlabel('Time (s)') legend('Training Signal','Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Example Utterance")
Послушайте обучающее предложение.
sound(sentence,fs)
Этот пример обучает нейронную сеть для глубокого обучения, используя 39 коэффициентов MFCC (13 коэффициентов MFCC, 13 дельта и 13 коэффициентов дельта-дельта).
Определите параметры, необходимые для извлечения MFCC.
WindowLength = 512; OverlapLength = 384;
Создайте объект audioFeatureExtractor, чтобы выполнить редукцию данных.
afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta',true, ... 'mfccDeltaDelta',true);
Извлеките функции.
featureMatrix = extract(afe,sentence); size(featureMatrix)
ans = 1×2
478 39
Обратите внимание, что вы вычисляете MFCC путем скольжения окна через вход, поэтому матрица функций короче, чем входной речевой сигнал. Каждая строка в featureMatrix
соответствует 128 выборкам из речевого сигнала (WindowLength
- OverlapLength
).
Вычислите маску той же длины, что и featureMatrix
.
HopLength = WindowLength - OverlapLength; range = HopLength * (1:size(featureMatrix,1)) + HopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end
Синтез предложения и редукция данных для всего обучающего набора данных могут быть довольно длительными. Чтобы ускорить обработку, если у вас есть Parallel Computing Toolbox™, разделите обучающий datastore и обработайте каждый раздел на отдельном рабочем месте.
Выберите несколько разделов datastore.
numPartitions = 6;
Инициализируйте массивы ячеек для матриц функций и масок.
TrainingFeatures = {}; TrainingMasks= {};
Выполните синтез предложения, редукцию данных и создание маски с помощью parfor
.
emptyCategories = categorical([1 0]); emptyCategories(:) = []; tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) % Create a training sentence [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; % Create mask hopLength = WindowLength - OverlapLength; range = (hopLength) * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 33.656404 seconds.
Это хорошая практика, чтобы нормализовать все функции, чтобы иметь нулевое среднее и единичное стандартное отклонение. Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормализовать данные.
sampleFeature = TrainingFeatures{1}; numFeatures = size(sampleFeature,2); featuresMatrix = cat(1,TrainingFeatures{:}); if reduceDataset load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S'); else M = mean(featuresMatrix); S = std(featuresMatrix); end for index = 1:length(TrainingFeatures) f = TrainingFeatures{index}; f = (f - M) ./ S; TrainingFeatures{index} = f.'; %#ok end
Извлеките функции MFCC из сигнала валидации.
featureMatrix = extract(afe, audioIn); featureMatrix(~isfinite(featureMatrix)) = 0;
Нормализуйте функции валидации.
FeaturesValidationClean = (featureMatrix - M)./S; range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;
Создайте маску KWS валидации.
featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end BaselineV = categorical(featureMask);
Сети LSTM могут изучать долгосрочные зависимости между временными шагами данных последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer
(Deep Learning Toolbox), чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.
Задайте размер входа, чтобы быть последовательностями размера numFeatures
. Задайте два скрытых двунаправленных слоя LSTM с размером выходом 150 и выведите последовательность. Эта команда предписывает двунаправленному слою LSTM преобразовать входные временные ряды в 150 функции, которые передаются следующему слою. Задайте два класса путем включения полносвязного слоя размера 2, а затем слоя softmax и слоя классификации.
layers = [ ... sequenceInputLayer(numFeatures) bilstmLayer(150,"OutputMode","sequence") bilstmLayer(150,"OutputMode","sequence") fullyConnectedLayer(2) softmaxLayer classificationLayer ];
Задайте опции обучения для классификатора. Задайте MaxEpochs
10 так, чтобы сеть 10 прошла через обучающие данные. Задайте MiniBatchSize
на 64
чтобы сеть рассматривала 64 обучающих сигнала за раз. Задайте Plots
на "training-progress"
чтобы сгенерировать графики, которые показывают процесс обучения с увеличениями количества итераций. Задайте Verbose
на false
чтобы отключить печать выхода таблицы, который соответствует данным, показанным на графике. Задайте Shuffle
на "every-epoch"
тасовать обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule
на "piecewise"
уменьшить скорость обучения на заданный коэффициент (0,1) каждый раз, когда прошло определенное количество эпох (5). Задайте ValidationData
к предикторам и целям валидации.
Этот пример использует решатель адаптивной оценки момента (ADAM). ADAM работает лучше с рекуррентными нейронными сетями (RNNs), такими как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решателя.
maxEpochs = 10; miniBatchSize = 64; options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationClean.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork
(Deep Learning Toolbox). Поскольку набор обучающих данных большая, процесс обучения может занять несколько минут.
[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);
if reduceDataset load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S'); end
Оцените маску KWS для сигнала валидации с помощью обученной сети.
v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');
Вычислите и постройте матрицу неточностей валидации из векторов фактических и оцененных меток.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Преобразуйте выход сети из категориального в двойной.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Прослушайте области ключевых слов, определенные сетью.
sound(audioIn(logical(v)),fs)
Визуализация предполагаемых и ожидаемых масок KWS.
baseline = double(BaselineV) - 1; baseline = repmat(baseline,HopLength,1); baseline = baseline(:); t = (1/fs) * (0:length(v)-1); fig = figure; plot(t,[audioIn(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noise-Free Speech')
Теперь вы проверите точность сети на наличие шумного речевого сигнала. Сигналы с шумом получали путем повреждения сигнала чистой валидации аддитивным белым Гауссовым шумом.
Загрузите сигнал с шумом.
[audioInNoisy,fs] = audioread(fullfile(netFolder,'NoisyKeywordSpeech-16-16-mono-34secs.flac'));
sound(audioInNoisy,fs)
Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioInNoisy)-1); plot(t,audioInNoisy) grid on xlabel('Time (s)') title('Noisy Validation Speech Signal')
Извлеките матрицу функции из сигнала с шумом.
featureMatrixV = extract(afe, audioInNoisy); featureMatrixV(~isfinite(featureMatrixV)) = 0; FeaturesValidationNoisy = (featureMatrixV - M)./S;
Передайте матрицу функций в сеть.
v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');
Сравните выходы сети с базовым уровнем. Обратите внимание, что точность ниже, чем та, что вы получили для чистого сигнала.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Преобразуйте выход сети из категориального в двойной.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Прослушайте области ключевых слов, определенные сетью.
sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и базовые маски.
t = (1/fs)*(0:length(v)-1); fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - No Data Augmentation')
Обученная сеть не показала хороших результатов на сигнал с шумом, потому что обученный набор данных содержал только предложения без шума. Вы исправите это, увеличив свой набор данных, чтобы включить шумные предложения.
Использование audioDataAugmenter
чтобы увеличить ваш набор данных.
ada = audioDataAugmenter('TimeStretchProbability',0, ... 'PitchShiftProbability',0, ... 'VolumeControlProbability',0, ... 'TimeShiftProbability',0, ... 'SNRRange',[-1, 1], ... 'AddNoiseProbability',0.85);
С этими настройками audioDataAugmenter
объект повреждает вход аудиосигнал с белым Гауссовым шумом с вероятностью 85%. ОСШ выбирается случайным образом из области значений [-1 1] (в дБ). Существует 15% вероятность того, что augmenter не изменит ваш входной сигнал.
В качестве примера передайте аудиосигнал в augmenter.
reset(ads_keyword) x = read(ads_keyword); data = augment(ada,x,fs)
data=1×2 table
Audio AugmentationInfo
________________ ________________
{16000×1 double} [1×1 struct]
Осмотрите AugmentationInfo
переменная в data
чтобы проверить, как был изменен сигнал.
data.AugmentationInfo
ans = struct with fields:
SNR: 0.3410
Сбросьте хранилища данных.
reset(ads_keyword) reset(ads_other)
Инициализируйте функцию и маскируйте ячейки.
TrainingFeatures = {}; TrainingMasks = {};
Еще раз выполните редукцию данных. Каждый сигнал повреждается шумом с вероятностью 85%, поэтому ваш дополненный набор данных имеет приблизительно 85% зашумленных данных и 15% бесшумных данных.
tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Corrupt with noise augmentedData = augment(ada,sentence,fs); sentence = augmentedData.Audio{1}; % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; hopLength = WindowLength - OverlapLength; range = hopLength * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.090923 seconds.
Вычислите среднее и стандартное отклонение для каждого коэффициента; используйте их для нормализации данных.
sampleFeature = TrainingFeatures{1}; numFeatures = size(sampleFeature,2); featuresMatrix = cat(1,TrainingFeatures{:}); if reduceDataset load(fullfile(netFolder,'KWSNet.mat'),'KWSNet','M','S'); else M = mean(featuresMatrix); S = std(featuresMatrix); end for index = 1:length(TrainingFeatures) f = TrainingFeatures{index}; f = (f - M) ./ S; TrainingFeatures{index} = f.'; %#ok end
Нормализуйте функции валидации с помощью новых средних и стандартных значений отклонения.
FeaturesValidationNoisy = (featureMatrixV - M)./S;
Создайте параметры обучения заново. Используйте шумные базовые функции и маску для валидации.
options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучите сеть.
[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);
if reduceDataset load(fullfile(netFolder,'KWSNet.mat')); end
Проверьте точность сети по сигналу валидации.
v = classify(KWSNet,FeaturesValidationNoisy.');
Сравните предполагаемые и ожидаемые маски KWS.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Прослушайте идентифицированные области ключевых слов.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:); sound(audioIn(logical(v)),fs)
Визуализация предполагаемых и ожидаемых масок.
fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - With Data Augmentation')
[1] Warden P. «Speech Commands: A public dataset for single-word speech recognition», 2017. Доступно из https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Копирайт Google 2017. Набор данных Speech Commands лицензируется лицензией Creative Commons Attribution 4.0.
function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength) % Read one keyword keyword = read(ads_keyword); keyword = keyword ./ max(abs(keyword)); % Identify region of interest speechIndices = detectSpeech(keyword,fs); if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength speechIndices = [1,length(keyword)]; end keyword = keyword(speechIndices(1,1):speechIndices(1,2)); % Pick a random number of other words (between 0 and 10) numWords = randi([0 10]); % Pick where to insert keyword loc = randi([1 numWords+1]); sentence = []; mask = []; for index = 1:numWords+1 if index==loc sentence = [sentence;keyword]; newMask = ones(size(keyword)); mask = [mask ;newMask]; else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; mask = [mask;zeros(size(other))]; end end end