В этом примере показано, как идентифицировать ключевое слово в шумной речи с помощью сети глубокого обучения. В частности, в примере используется сеть двунаправленной долговременной памяти (BiLSTM) и кепстральные коэффициенты частоты (MFCC).
Ключевое слово spotting (KWS) является важным компонентом технологий voice-assist, где пользователь говорит предопределенное ключевое слово для пробуждения системы, прежде чем передать полную команду или запрос устройству.
В этом примере обучается глубокая сеть KWS с характерными последовательностями кепстральных коэффициентов (MFCC). Пример также демонстрирует, как можно повысить точность сети в шумной среде с помощью увеличения объема данных.
В этом примере используются сети долговременной кратковременной памяти (LSTM), которые являются типом рецидивирующей нейронной сети (RNN), хорошо подходящей для изучения последовательности и данных временных рядов. Сеть LSTM может изучать долгосрочные зависимости между временными шагами последовательности. Уровень LSTM (lstmLayer) может просматривать временную последовательность в прямом направлении, в то время как двунаправленный уровень LSTM (bilstmLayer) может смотреть на временную последовательность как в прямом, так и в обратном направлениях. В этом примере используется двунаправленный уровень LSTM.
В примере используется набор данных речевых команд google для обучения модели глубокого обучения. Для выполнения примера необходимо сначала загрузить набор данных. Если не требуется загружать набор данных или обучать сеть, можно загрузить и использовать предварительно обученную сеть, открыв этот пример в MATLAB ® и выполняя строки 3-10 примера.
Перед подробным началом процесса обучения необходимо загрузить и использовать предварительно обученную сеть определения ключевых слов для идентификации ключевого слова.
В этом примере ключевое слово spot - YES.
Считывайте тестовый сигнал, где произносится ключевое слово.
[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)Загрузите и загрузите предварительно обученную сеть, векторы среднего (M) и стандартного отклонения (S), используемые для нормализации функций, а также 2 аудиофайла, используемые для проверки правильности сети позже в примере.
url = 'http://ssd.mathworks.com/supportfiles/audio/KeywordSpotting.zip'; downloadNetFolder = tempdir; netFolder = fullfile(downloadNetFolder,'KeywordSpotting'); if ~exist(netFolder,'dir') disp('Downloading pretrained network and audio files (4 files - 7 MB) ...') unzip(url,downloadNetFolder) end load(fullfile(netFolder,'KWSNet.mat'));
Создание audioFeatureExtractor Объект (Audio Toolbox) для извлечения элементов.
WindowLength = 512; OverlapLength = 384; afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta', true, ... 'mfccDeltaDelta',true);
Извлеките элементы из тестового сигнала и нормализуйте их.
features = extract(afe,audioIn); features = (features - M)./S;
Вычислите ключевое слово spotting binary mask. Значение маски, равное единице, соответствует сегменту, в котором было замечено ключевое слово.
mask = classify(KWSNet,features.');
Каждая выборка в маске соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).
Разверните маску до длины сигнала.
mask = repmat(mask, WindowLength-OverlapLength, 1); mask = double(mask) - 1; mask = mask(:);
Постройте график тестового сигнала и маски.
figure audioIn = audioIn(1:length(mask)); t = (0:length(audioIn)-1)/fs; plot(t, audioIn) grid on hold on plot(t, mask) legend('Speech', 'YES')

Прослушать ключевое слово «» Пятно «».
sound(audioIn(mask==1),fs)
Протестируйте предварительно обученную сеть обнаружения команд на потоковом аудио с микрофона. Попробуйте произнести случайные слова, включая ключевое слово (YES).
Звонить generateMATLABFunction (Панель звуковых инструментов) на audioFeatureExtractor объект для создания функции извлечения элемента. Эта функция используется в цикле обработки.
generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);
Определите устройство чтения аудио, которое может считывать аудио с микрофона. Задайте длину кадра, равную длине транзитного участка. Это позволяет вычислить новый набор функций для каждого нового аудиокадра с микрофона.
HopLength = WindowLength - OverlapLength; FrameLength = HopLength; adr = audioDeviceReader('SampleRate',fs, ... 'SamplesPerFrame',FrameLength);
Создайте область для визуализации речевого сигнала и предполагаемой маски.
scope = timescope('SampleRate',fs, ... 'TimeSpanSource','property', ... 'TimeSpan',5, ... 'TimeSpanOverrunAction','Scroll', ... 'BufferLength',fs*5*2, ... 'ShowLegend',true, ... 'ChannelNames',{'Speech','Keyword Mask'}, ... 'YLimits',[-1.2 1.2], ... 'Title','Keyword Spotting');
Определите скорость, с которой оценивается маска. Маска генерируется один раз NumHopsPerUpdate аудиокадры.
NumHopsPerUpdate = 16;
Инициализируйте буфер для звука.
dataBuff = dsp.AsyncBuffer(WindowLength);
Инициализируйте буфер для вычисляемых элементов.
featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);
Инициализируйте буфер для управления печатью звука и маски.
plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);
Для бесконечного запуска цикла установите timeLimit в значение Inf. Чтобы остановить моделирование, закройте область.
timeLimit = 20; tic while toc < timeLimit data = adr(); write(dataBuff, data); write(plotBuff, data); frame = read(dataBuff,WindowLength,OverlapLength); features = generateKeywordFeatures(frame,fs); write(featureBuff,features.'); if featureBuff.NumUnreadSamples == NumHopsPerUpdate featureMatrix = read(featureBuff); featureMatrix(~isfinite(featureMatrix)) = 0; featureMatrix = (featureMatrix - M)./S; [keywordNet, v] = classifyAndUpdateState(KWSNet,featureMatrix.'); v = double(v) - 1; v = repmat(v, HopLength, 1); v = v(:); v = mode(v); v = repmat(v, NumHopsPerUpdate * HopLength,1); data = read(plotBuff); scope([data, v]); if ~isVisible(scope) break; end end end hide(scope)

В остальном примере вы узнаете, как обучить ключевое слово spotting network.
Процесс обучения проходит следующие этапы:
Проверьте ключевое слово «gold standard», указывающее базовую линию на сигнале проверки.
Создание обучающих высказываний из набора данных без шума.
Обучение сети LSTM с помощью ключевого слова spotting с использованием последовательностей MFCC, извлеченных из этих высказываний.
Проверьте точность сети путем сравнения базовой линии проверки с выходом сети при применении к сигналу проверки.
Проверьте точность сети на наличие сигнала проверки подлинности, поврежденного шумом.
Увеличение обучающего набора данных путем введения шума в речевые данные с использованием audioDataAugmenter (Панель звуковых инструментов).
Переподготовьте сеть с помощью дополненного набора данных.
Убедитесь, что переученная сеть теперь дает более высокую точность при применении к шумному сигналу проверки подлинности.
Для проверки правильности сети KWS используется образец речевого сигнала. Сигнал проверки подлинности состоит из 34 секунд речи с ключевым словом YES, появляющимся с перерывами.
Загрузите сигнал проверки.
[audioIn,fs] = audioread(fullfile(netFolder,'KeywordSpeech-16-16-mono-34secs.flac'));Слушай сигнал.
sound(audioIn,fs)
Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioIn)-1); plot(t,audioIn); grid on xlabel('Time (s)') title('Validation Speech Signal')

Загрузите базовую линию KWS. Этот базовый уровень был получен с использованием speech2text: Создать ключевое слово Spotting Mask с помощью Audio Labeler (Audio Toolbox).
load('KWSBaseline.mat','KWSBaseline')
Базовая линия является логическим вектором той же длины, что и звуковой сигнал проверки подлинности. Сегменты в audioIn где произносится ключевое слово, устанавливаются на единицу в KWSBaseline.
Визуализируйте речевой сигнал вместе с базовой линией KWS.
fig = figure; plot(t,[audioIn,KWSBaseline']) grid on xlabel('Time (s)') legend('Speech','KWS Baseline','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Validation Signal")

Прослушивание речевых сегментов, определенных как ключевые слова.
sound(audioIn(KWSBaseline),fs)
Целью обучаемой сети является вывод маски KWS из нулей и единиц, подобных этой базовой линии.
Загрузите и извлеките набор данных речевых команд Google [1].
url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'google_speech'); if ~exist(datasetFolder,'dir') disp('Downloading Google speech commands data set (1.5 GB)...') unzip(url,datasetFolder) end
Создание audioDatastore (Audio Toolbox), указывающий на набор данных.
ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true); ads = shuffle(ads);
Набор данных содержит файлы фонового шума, которые не используются в данном примере. Использовать subset (Audio Toolbox) для создания нового хранилища данных, в котором отсутствуют файлы фонового шума.
isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);Набор данных содержит приблизительно 65 000 длинных слов длиной в одну секунду из 30 коротких слов (включая ключевое слово YES). Получите разбивку распределения слов в хранилище данных.
countEachLabel(ads)
ans=30×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
Разделение ads на два хранилища данных: первое хранилище данных содержит файлы, соответствующие ключевому слову. Второе хранилище данных содержит все остальные слова.
keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset кому false. Для быстрого выполнения этого примера установите reduceDataset кому true.
reduceDataset =false; if reduceDataset % Reduce the dataset by a factor of 20 ads_keyword = splitEachLabel(ads_keyword,round(numel(ads_keyword.Files) / 20)); numUniqueLabels = numel(unique(ads_other.Labels)); ads_other = splitEachLabel(ads_other,round(numel(ads_other.Files) / numUniqueLabels / 20)); end
Получите разбивку распределения слов в каждом хранилище данных. Перетасовать ads_other , чтобы последовательные чтения возвращали разные слова.
countEachLabel(ads_keyword)
ans=1×2 table
Label Count
_____ _____
yes 2377
countEachLabel(ads_other)
ans=29×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
ads_other = shuffle(ads_other);
Обучающие хранилища данных содержат односекундные речевые сигналы, в которых произносится одно слово. Вы создадите более сложные обучающие речи, которые содержат смесь ключевого слова и других слов.
Вот пример построенного высказывания. Прочитайте одно ключевое слово из хранилища данных ключевых слов и нормализуйте его, чтобы оно имело максимальное значение, равное единице.
yes = read(ads_keyword); yes = yes / max(abs(yes));
Сигнал имеет неречевые части (тишина, фоновый шум и т.д.), которые не содержат полезной речевой информации. Этот пример удаляет молчание с помощью detectSpeech (Панель звуковых инструментов).
Получите начальный и конечный индексы полезной части сигнала.
speechIndices = detectSpeech(yes,fs);
Случайным образом выберите количество слов для использования в синтезированном обучающем предложении. Используйте максимум 10 слов.
numWords = randi([0 10]);
Случайный выбор местоположения, в котором встречается ключевое слово.
keywordLocation = randi([1 numWords+1]);
Прочитайте требуемое количество высказываний без ключевых слов и создайте обучающее предложение и маску.
sentence = []; mask = []; for index = 1:numWords+1 if index == keywordLocation sentence = [sentence;yes]; %#ok newMask = zeros(size(yes)); newMask(speechIndices(1,1):speechIndices(1,2)) = 1; mask = [mask;newMask]; %#ok else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; %#ok mask = [mask;zeros(size(other))]; %#ok end end
Постройте график учебного предложения вместе с маской.
figure t = (1/fs) * (0:length(sentence)-1); fig = figure; plot(t,[sentence,mask]) grid on xlabel('Time (s)') legend('Training Signal','Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Example Utterance")

Послушай тренировочное предложение.
sound(sentence,fs)
В этом примере обучается сеть глубокого обучения с использованием 39 коэффициентов MFCC (13 коэффициентов MFCC, 13 дельта и 13 дельта-дельта).
Определите параметры, необходимые для извлечения MFCC.
WindowLength = 512; OverlapLength = 384;
Создайте объект audioFeatureExtractor для извлечения компонента.
afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta',true, ... 'mfccDeltaDelta',true);
Извлеките элементы.
featureMatrix = extract(afe,sentence); size(featureMatrix)
ans = 1×2
478 39
Обратите внимание, что MFCC вычисляется путем перемещения окна через вход, поэтому матрица признаков короче входного речевого сигнала. Каждая строка в featureMatrix соответствует 128 выборкам из речевого сигнала (WindowLength - OverlapLength).
Вычислить маску той же длины, что и featureMatrix.
HopLength = WindowLength - OverlapLength; range = HopLength * (1:size(featureMatrix,1)) + HopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end
Синтез предложений и извлечение признаков для всего набора обучающих данных может занять довольно много времени. Чтобы ускорить обработку, при наличии Toolbox™ Parallel Computing разделите хранилище данных обучения и обработайте каждый раздел на отдельном работнике.
Выберите несколько разделов хранилища данных.
numPartitions = 6;
Инициализируйте массивы ячеек для матриц и масок элементов.
TrainingFeatures = {};
TrainingMasks= {};Выполнение синтеза предложений, извлечения элементов и создания маски с помощью parfor.
emptyCategories = categorical([1 0]); emptyCategories(:) = []; tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) % Create a training sentence [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; % Create mask hopLength = WindowLength - OverlapLength; range = (hopLength) * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 33.656404 seconds.
Рекомендуется нормализовать все элементы так, чтобы они имели нулевое среднее и единичное стандартное отклонение. Вычислите среднее значение и стандартное отклонение для каждого коэффициента и используйте их для нормализации данных.
sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S');
else
M = mean(featuresMatrix);
S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
f = TrainingFeatures{index};
f = (f - M) ./ S;
TrainingFeatures{index} = f.'; %#ok
endИзвлеките функции MFCC из сигнала проверки.
featureMatrix = extract(afe, audioIn); featureMatrix(~isfinite(featureMatrix)) = 0;
Нормализуйте функции проверки.
FeaturesValidationClean = (featureMatrix - M)./S; range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;
Создайте маску KWS проверки.
featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end BaselineV = categorical(featureMask);
Сети LSTM могут изучать долгосрочные зависимости между временными шагами данных последовательности. В этом примере используется двунаправленный уровень LSTM bilstmLayer чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.
Укажите входной размер, который будет последовательностью размера numFeatures. Укажите два скрытых двунаправленных уровня LSTM с размером вывода 150 и выведите последовательность. Эта команда дает команду двунаправленному уровню LSTM отобразить входной временной ряд в 150 признаков, которые передаются на следующий уровень. Укажите два класса, включив полностью подключенный слой размера 2, а затем слой softmax и слой классификации.
layers = [ ... sequenceInputLayer(numFeatures) bilstmLayer(150,"OutputMode","sequence") bilstmLayer(150,"OutputMode","sequence") fullyConnectedLayer(2) softmaxLayer classificationLayer ];
Укажите параметры обучения для классификатора. Набор MaxEpochs 10, так что сеть 10 проходит через обучающие данные. Набор MiniBatchSize кому 64 чтобы сеть просматривала 64 обучающих сигнала одновременно. Набор Plots кому "training-progress" для создания графиков, показывающих ход обучения по мере увеличения числа итераций. Набор Verbose кому false для отключения печати выходных данных таблицы, соответствующих данным, отображаемым на графике. Набор Shuffle кому "every-epoch" перетасовать тренировочную последовательность в начале каждой эпохи. Набор LearnRateSchedule кому "piecewise" снижать скорость обучения на заданный коэффициент (0,1) каждый раз, когда проходит определенное количество эпох (5). Набор ValidationData к предикторам и целям проверки.
В этом примере используется решатель ADAM. ADAM работает лучше с рецидивирующими нейронными сетями (RNN), такими как LSTM, чем стохастический градиентный спуск по умолчанию с решателем импульса (SGDM).
maxEpochs = 10; miniBatchSize = 64; options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationClean.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучение сети LSTM с использованием указанных вариантов обучения и архитектуры уровня с помощью trainNetwork. Поскольку тренировочный набор большой, тренировочный процесс может занять несколько минут.
[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset load(fullfile(netFolder,'keywordNetNoAugmentation.mat'),'keywordNetNoAugmentation','M','S'); end
Оцените маску KWS для сигнала проверки подлинности с использованием обученной сети.
v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');
Вычислите и постройте график матрицы путаницы проверки на основе векторов фактических и оценочных меток.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";

Преобразование сетевых выходных данных из категориальных в двойные.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Прослушивание областей ключевых слов, определенных сетью.
sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и ожидаемые маски KWS.
baseline = double(BaselineV) - 1; baseline = repmat(baseline,HopLength,1); baseline = baseline(:); t = (1/fs) * (0:length(v)-1); fig = figure; plot(t,[audioIn(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noise-Free Speech')

Теперь проверьте точность сети на наличие шумного речевого сигнала. Шумный сигнал был получен путем повреждения чистого сигнала проверки достоверности аддитивным белым гауссовым шумом.
Загрузите шумный сигнал.
[audioInNoisy,fs] = audioread(fullfile(netFolder,'NoisyKeywordSpeech-16-16-mono-34secs.flac'));
sound(audioInNoisy,fs)Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioInNoisy)-1); plot(t,audioInNoisy) grid on xlabel('Time (s)') title('Noisy Validation Speech Signal')

Извлеките матрицу признаков из шумного сигнала.
featureMatrixV = extract(afe, audioInNoisy); featureMatrixV(~isfinite(featureMatrixV)) = 0; FeaturesValidationNoisy = (featureMatrixV - M)./S;
Передайте матрицу функций в сеть.
v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');
Сравните выходные данные сети с базовой линией. Обратите внимание, что точность ниже, чем та, которую вы получили для чистого сигнала.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";

Преобразование сетевых выходных данных из категориальных в двойные.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Прослушивание областей ключевых слов, определенных сетью.
sound(audioIn(logical(v)),fs)
Визуализация оценочных и базовых масок.
t = (1/fs)*(0:length(v)-1); fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - No Data Augmentation')

Обученная сеть плохо работала на шумном сигнале, потому что обученный набор данных содержал только предложения без шума. Вы исправите это, увеличив свой набор данных, чтобы включить шумные предложения.
Использовать audioDataAugmenter (Audio Toolbox) для расширения набора данных.
ada = audioDataAugmenter('TimeStretchProbability',0, ... 'PitchShiftProbability',0, ... 'VolumeControlProbability',0, ... 'TimeShiftProbability',0, ... 'SNRRange',[-1, 1], ... 'AddNoiseProbability',0.85);
С этими настройками, audioDataAugmenter объект повреждает входной аудиосигнал белым гауссовым шумом с вероятностью 85%. SNR выбирается случайным образом из диапазона [-1 1] (в дБ). Существует 15% вероятность того, что augmenter не изменит ваш входной сигнал.
В качестве примера передайте звуковой сигнал в аугментатор.
reset(ads_keyword) x = read(ads_keyword); data = augment(ada,x,fs)
data=1×2 table
Audio AugmentationInfo
________________ ________________
{16000×1 double} [1×1 struct]
Осмотрите AugmentationInfo переменная в data для проверки того, как был изменен сигнал.
data.AugmentationInfo
ans = struct with fields:
SNR: 0.3410
Сбросьте хранилища данных.
reset(ads_keyword) reset(ads_other)
Инициализируйте ячейки элемента и маски.
TrainingFeatures = {};
TrainingMasks = {};Снова выполните извлечение элемента. Каждый сигнал испорчен шумом с вероятностью 85%, поэтому расширенный набор данных содержит приблизительно 85% шумных данных и 15% безшумных данных.
tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Corrupt with noise augmentedData = augment(ada,sentence,fs); sentence = augmentedData.Audio{1}; % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; hopLength = WindowLength - OverlapLength; range = hopLength * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.090923 seconds.
Вычислить среднее и стандартное отклонение для каждого коэффициента; используйте их для нормализации данных.
sampleFeature = TrainingFeatures{1};
numFeatures = size(sampleFeature,2);
featuresMatrix = cat(1,TrainingFeatures{:});
if reduceDataset
load(fullfile(netFolder,'KWSNet.mat'),'KWSNet','M','S');
else
M = mean(featuresMatrix);
S = std(featuresMatrix);
end
for index = 1:length(TrainingFeatures)
f = TrainingFeatures{index};
f = (f - M) ./ S;
TrainingFeatures{index} = f.'; %#ok
endНормализуйте функции проверки с новыми значениями среднего и стандартного отклонения.
FeaturesValidationNoisy = (featureMatrixV - M)./S;
Повторно создайте параметры обучения. Используйте шумные элементы базовой линии и маску для проверки.
options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучение сети.
[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options);

if reduceDataset load(fullfile(netFolder,'KWSNet.mat')); end
Проверьте сетевую точность сигнала проверки подлинности.
v = classify(KWSNet,FeaturesValidationNoisy.');
Сравните предполагаемые и ожидаемые маски KWS.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";

Прослушивание определенных областей ключевых слов.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:); sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и ожидаемые маски.
fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - With Data Augmentation')

[1] Уорден П. «Речевые команды: публичный набор данных для однословного распознавания речи», 2017. Доступно в https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Набор данных речевых команд лицензируется по лицензии Creative Commons Attribution 4.0.
function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength) % Read one keyword keyword = read(ads_keyword); keyword = keyword ./ max(abs(keyword)); % Identify region of interest speechIndices = detectSpeech(keyword,fs); if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength speechIndices = [1,length(keyword)]; end keyword = keyword(speechIndices(1,1):speechIndices(1,2)); % Pick a random number of other words (between 0 and 10) numWords = randi([0 10]); % Pick where to insert keyword loc = randi([1 numWords+1]); sentence = []; mask = []; for index = 1:numWords+1 if index==loc sentence = [sentence;keyword]; newMask = ones(size(keyword)); mask = [mask ;newMask]; else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; mask = [mask;zeros(size(other))]; end end end
bilstmLayer | sequenceInputLayer | trainingOptions | trainNetwork