В этом примере показано, как идентифицировать ключевое слово в шумной речи с помощью нейронной сети для глубокого обучения. В частности, пример использует сеть Bidirectional Long Short-Term Memory (BiLSTM) и частоту mel cepstral коэффициенты (MFCC).
Ключевое слово, определяющее (KWS), является важной составляющей речи - помогают технологиям, где пользователь говорит предопределенное ключевое слово с пробуждением система прежде, чем говорить полную команду или запрос к устройству.
Этот пример обучает KWS глубокая сеть с последовательностями функции mel-частоты cepstral коэффициентов (MFCC). Пример также демонстрирует, как сетевая точность в шумной среде может быть улучшена с помощью увеличения данных.
Этот пример использует сети долгой краткосрочной памяти (LSTM), которые являются типом рекуррентной нейронной сети (RNN), подходящей, чтобы изучить данные timeseries и последовательность. Сеть LSTM может изучить долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer
) может посмотреть в то время последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer
) может посмотреть в то время последовательность и во вперед и в обратные направления. Этот пример использует двунаправленный слой LSTM.
Пример использует google Speech Commands Dataset, чтобы обучить модель глубокого обучения. Чтобы запустить пример, необходимо сначала загрузить набор данных. Если вы не хотите загружать набор данных или обучать сеть, то можно загрузить предварительно обученную сеть путем открытия этого примера в MATLAB® и ввода load("KWSNet.mat")
в командной строке.
Перед входом в учебный процесс подробно, вы будете использовать предварительно обученную сеть определения ключевого слова, чтобы идентифицировать ключевое слово.
В этом примере ключевым словом, чтобы определить является YES.
Считайте тестовый сигнал, где ключевое слово произнесено.
[audioIn, fs] = audioread('keywordTestSignal.wav');
sound(audioIn,fs)
Загрузите предварительно обученную сеть, а также вектора средних значений и векторы среднеквадратического отклонения, используемые для нормализации функции.
load('keywordNetNoAugmentation.mat',"keywordNetNoAugmentation","M","S");
Создайте audioFeatureExtractor
Объект (Audio Toolbox) выполнить извлечение признаков.
WindowLength = 512; OverlapLength = 384; afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta', true, ... 'mfccDeltaDelta',true); setExtractorParams(afe,'mfcc','NumCoeffs',13);
Извлеките функции из тестового сигнала и нормируйте их.
features = extract(afe,audioIn); features = (features - M)./S;
Вычислите ключевое слово, определяющее бинарную маску. Значение маски каждый соответствует сегменту, где ключевое слово было определено.
mask = classify(keywordNetNoAugmentation,features.');
Каждая выборка в маске соответствует 128 выборкам от речевого сигнала (WindowLength
- OverlapLength
).
Расширьте маску до длины сигнала.
mask = repmat(mask, WindowLength-OverlapLength, 1); mask = double(mask) - 1; mask = mask(:);
Постройте тестовый сигнал и маску.
figure audioIn = audioIn(1:length(mask)); t = (0:length(audioIn)-1)/fs; plot(t, audioIn) grid on hold on plot(t, mask) legend('Speech', 'YES')
Слушайте пятнистое ключевое слово.
sound(audioIn(mask==1),fs)
Протестируйте свою предварительно обученную сеть обнаружения команды на передаче потокового аудио от вашего микрофона. Попытайтесь говорить случайные слова, включая ключевое слово (YES).
Вызовите generateMATLABFunction
(Audio Toolbox) на audioFeatureExtractor
объект создать функцию извлечения признаков. Вы будете использовать эту функцию в цикле обработки.
generateMATLABFunction(afe,'generateKeywordFeatures','IsStreaming',true);
Задайте читателя аудио устройства, который может считать аудио из вашего микрофона. Установите длину системы координат на длину транзитного участка. Это позволяет вам вычислить новый набор функций каждой новой аудио системы координат от микрофона.
HopLength = WindowLength - OverlapLength; FrameLength = HopLength; adr = audioDeviceReader('SampleRate',fs, ... 'SamplesPerFrame',FrameLength);
Создайте осциллограф для визуализации речевого сигнала и предполагаемой маски.
scope = timescope('SampleRate',fs, ... 'TimeSpanSource','property', ... 'TimeSpan',5, ... 'TimeSpanOverrunAction','Scroll', ... 'BufferLength',fs*5*2, ... 'ShowLegend',true, ... 'ChannelNames',{'Speech','Keyword Mask'}, ... 'YLimits',[-1.2 1.2], ... 'Title','Keyword Spotting');
Задайте уровень, на котором вы оцениваете маску. Вы сгенерируете маску один раз в NumHopsPerUpdate
аудио системы координат.
NumHopsPerUpdate = 16;
Инициализируйте буфер для аудио.
dataBuff = dsp.AsyncBuffer(WindowLength);
Инициализируйте буфер для вычисленных функций.
featureBuff = dsp.AsyncBuffer(NumHopsPerUpdate);
Инициализируйте буфер, чтобы справиться с графическим выводом аудио и маски.
plotBuff = dsp.AsyncBuffer(NumHopsPerUpdate*WindowLength);
Чтобы запустить цикл неопределенно, установите timeLimit на Inf
. Чтобы остановить симуляцию, закройте осциллограф.
timeLimit = 20; tic while toc < timeLimit data = adr(); write(dataBuff, data); write(plotBuff, data); frame = read(dataBuff,WindowLength,OverlapLength); features = generateKeywordFeatures(frame,fs); write(featureBuff,features.'); if featureBuff.NumUnreadSamples == NumHopsPerUpdate featureMatrix = read(featureBuff); featureMatrix(~isfinite(featureMatrix)) = 0; featureMatrix = (featureMatrix - M)./S; [keywordNet, v] = classifyAndUpdateState(keywordNetNoAugmentation,featureMatrix.'); v = double(v) - 1; v = repmat(v, HopLength, 1); v = v(:); v = mode(v); v = repmat(v, NumHopsPerUpdate * HopLength,1); data = read(plotBuff); scope([data, v]); if ~isVisible(scope) break; end end end hide(scope)
В остальной части примера вы изучите, как обучить сеть определения ключевого слова.
Учебный процесс проходит следующие шаги:
Смотрите базовую линию определения ключевого слова "золотого стандарта" на сигнале валидации.
Создайте учебное произнесение из бесшумного набора данных.
Обучите ключевое слово, определяющее сеть LSTM с помощью последовательностей MFCC, извлеченных из того произнесения.
Проверяйте сетевую точность путем сравнения базовой линии валидации с выходом сети, когда применено сигнал валидации.
Проверяйте сетевую точность на сигнал валидации, поврежденный шумом.
Увеличьте обучающий набор данных путем введения шума к речевым данным с помощью audioDataAugmenter
(Audio Toolbox).
Переобучите сеть с увеличенным набором данных.
Проверьте, что переобученная сеть теперь дает к более высокой точности, когда применено шумный сигнал валидации.
Вы используете демонстрационный речевой сигнал проверить сеть KWS. Сигнал валидации состоит 34 секунды речи с ключевым словом YES, появляющийся периодически.
Загрузите сигнал валидации.
[audioIn,fs] = audioread('KeywordSpeech-16-16-mono-34secs.flac');
Слушайте сигнал.
sound(audioIn,fs)
Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioIn)-1); plot(t,audioIn); grid on xlabel('Time (s)') title('Validation Speech Signal')
Загрузите базовую линию KWS. Эта базовая линия была получена с помощью speech2text
: Создайте ключевое слово, определяющее маску Используя Audio Labeler (Audio Toolbox).
load('KWSBaseline.mat','KWSBaseline')
Базовая линия является логическим вектором из той же длины как звуковой сигнал валидации. Сегменты в audioIn
где ключевое слово произнесено, установлены в одного в KWSBaseline
.
Визуализируйте речевой сигнал наряду с базовой линией KWS.
fig = figure; plot(t,[audioIn,KWSBaseline']) grid on xlabel('Time (s)') legend('Speech','KWS Baseline','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Validation Signal")
Слушайте речевые сегменты, идентифицированные как ключевые слова.
sound(audioIn(KWSBaseline),fs)
Цель сети, которую вы обучаете, состоит в том, чтобы вывести маску KWS нулей и единиц как эта базовая линия.
Загрузите и извлеките Google Speech Commands Dataset [1].
url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'google_speech'); if ~exist(datasetFolder,'dir') disp('Downloading Google speech commands data set (1.5 GB)...') unzip(url,datasetFolder) end
Создайте audioDatastore
(Audio Toolbox), который указывает на набор данных.
ads = audioDatastore(datasetFolder,'LabelSource','foldername','Includesubfolders',true); ads = shuffle(ads);
Набор данных содержит файлы фонового шума, которые не используются в этом примере. Используйте subset
(Audio Toolbox), чтобы создать новый datastore, который не имеет файлов фонового шума.
isBackNoise = ismember(ads.Labels,"background");
ads = subset(ads,~isBackNoise);
Набор данных имеет приблизительно 65 000 одного второго длинного произнесения 30 коротких слов (включая ключевое слово YES). Получите отказ распределения слова в datastore.
countEachLabel(ads)
ans=30×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
Разделите ads
в два хранилища данных: первый datastore содержит файлы, соответствующие ключевому слову. Второй datastore содержит все другие слова.
keyword = 'yes';
isKeyword = ismember(ads.Labels,keyword);
ads_keyword = subset(ads,isKeyword);
ads_other = subset(ads,~isKeyword);
Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset
к false
. Чтобы запустить этот пример быстро, установите reduceDataset
к true
.
reduceDataset = false; if reduceDataset % Reduce the dataset by a factor of 20 ads_keyword = splitEachLabel (ads_keyword, вокруг (numel (ads_keyword.Files) / 20)); numUniqueLabels = numel (уникальный (ads_other.Labels)); ads_other = splitEachLabel (ads_other, вокруг (numel (ads_other.Files) / numUniqueLabels / 20)); end
Получите отказ распределения слова в каждом datastore. Переставьте ads_other
datastore так, чтобы последовательные чтения возвратили различные слова.
countEachLabel(ads_keyword)
ans=1×2 table
Label Count
_____ _____
yes 2377
countEachLabel(ads_other)
ans=29×2 table
Label Count
______ _____
bed 1713
bird 1731
cat 1733
dog 1746
down 2359
eight 2352
five 2357
four 2372
go 2372
happy 1742
house 1750
left 2353
marvin 1746
nine 2364
no 2375
off 2357
⋮
ads_other = shuffle(ads_other);
Учебные хранилища данных содержат вторые речевые сигналы, где одно слово произнесено. Вы создадите более комплексное учебное речевое произнесение, которое содержит смесь ключевого слова наряду с другими словами.
Вот пример созданного произнесения. Считайте одно ключевое слово из datastore ключевого слова и нормируйте его, чтобы иметь максимальное значение одного.
yes = read(ads_keyword); yes = yes / max(abs(yes));
Сигнал имеет неречевые фрагменты (тишина, фоновый шум, и т.д.), которые не содержат полезную информацию о речи. Этот пример удаляет тишину с помощью detectSpeech
(Audio Toolbox).
Получите индексы начала и конца полезного фрагмента сигнала.
speechIndices = detectSpeech(yes,fs);
Случайным образом выберите количество слов, чтобы использовать в синтезируемом учебном предложении. Используйте максимум 10 слов.
numWords = randi([0 10]);
Случайным образом выберите местоположение, в котором происходит ключевое слово.
keywordLocation = randi([1 numWords+1]);
Считайте желаемое количество произнесения неключевого слова и создайте учебное предложение и маску.
sentence = []; mask = []; for index = 1:numWords+1 if index == keywordLocation sentence = [sentence;yes]; %#ok newMask = zeros(size(yes)); newMask(speechIndices(1,1):speechIndices(1,2)) = 1; mask = [mask;newMask]; %#ok else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; %#ok mask = [mask;zeros(size(other))]; %#ok end end
Постройте учебное предложение наряду с маской.
figure t = (1/fs) * (0:length(sentence)-1); fig = figure; plot(t,[sentence,mask]) grid on xlabel('Time (s)') legend('Training Signal','Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; title("Example Utterance")
Слушайте учебное предложение.
sound(sentence,fs)
Этот пример обучает нейронную сеть для глубокого обучения с помощью 39 коэффициентов MFCC (13 MFCC, 13 дельт и 13 коэффициентов дельты дельты).
Задайте параметры, требуемые для экстракции MFCC.
WindowLength = 512; OverlapLength = 384;
Создайте объект audioFeatureExtractor выполнить извлечение признаков.
afe = audioFeatureExtractor('SampleRate',fs, ... 'Window',hann(WindowLength,'periodic'), ... 'OverlapLength',OverlapLength, ... 'mfcc',true, ... 'mfccDelta',true, ... 'mfccDeltaDelta',true);
Извлеките функции.
featureMatrix = extract(afe,sentence); size(featureMatrix)
ans = 1×2
497 39
Обратите внимание на то, что вы вычисляете MFCC путем скольжения окна через вход, таким образом, матрица функции короче, чем входной речевой сигнал. Каждая строка в featureMatrix
соответствует 128 выборкам от речевого сигнала (WindowLength
- OverlapLength
).
Вычислите маску той же длины как featureMatrix
.
HopLength = WindowLength - OverlapLength; range = HopLength * (1:size(featureMatrix,1)) + HopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end
Синтез предложения и извлечение признаков для целого обучающего набора данных могут быть довольно длительными. Чтобы ускорить обработку, если у вас есть Parallel Computing Toolbox™, делят учебный datastore и процесс каждый раздел на отдельном рабочем.
Выберите много разделов datastore.
numPartitions = 6;
Инициализируйте массивы ячеек для матриц функции и масок.
TrainingFeatures = {}; TrainingMasks= {};
Выполните синтез предложения, извлечение признаков и создание маски с помощью parfor
.
emptyCategories = categorical([1 0]); emptyCategories(:) = []; tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) % Create a training sentence [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; % Create mask hopLength = WindowLength - OverlapLength; range = (hopLength) * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 83.790340 seconds.
Это - хорошая практика, чтобы нормировать все функции, чтобы иметь нулевое среднее значение и стандартное отклонение единицы. Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормировать данные.
sampleFeature = TrainingFeatures{1}; numFeatures = size(sampleFeature,2); featuresMatrix = cat(1,TrainingFeatures{:}); M = mean(featuresMatrix); S = std(featuresMatrix); for index = 1:length(TrainingFeatures) f = TrainingFeatures{index}; f = (f - M) ./ S; TrainingFeatures{index} = f.'; %#ok end
Извлеките функции MFCC из сигнала валидации.
featureMatrix = extract(afe, audioIn); featureMatrix(~isfinite(featureMatrix)) = 0;
Нормируйте функции валидации.
FeaturesValidationClean = (featureMatrix - M)./S; range = HopLength * (1:size(FeaturesValidationClean,1)) + HopLength;
Создайте валидацию маска KWS.
featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(KWSBaseline( (index-1)*HopLength+1:(index-1)*HopLength+WindowLength )); end BaselineV = categorical(featureMask);
Сети LSTM могут изучить долгосрочные зависимости между временными шагами данных о последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer
смотреть на последовательность и во вперед и в обратные направления.
Задайте входной размер, чтобы быть последовательностями размера numFeatures
. Задайте два скрытых двунаправленных слоя LSTM с выходным размером 150 и выведите последовательность. Эта команда дает двунаправленному слою LSTM команду сопоставлять входные временные ряды в 150 функций, которые передаются следующему слою. Задайте два класса включением полносвязного слоя размера 2, сопровождаемый softmax слоем и слоем классификации.
layers = [ ... sequenceInputLayer(numFeatures) bilstmLayer(150,"OutputMode","sequence") bilstmLayer(150,"OutputMode","sequence") fullyConnectedLayer(2) softmaxLayer classificationLayer ];
Задайте опции обучения для классификатора. Установите MaxEpochs
к 10 так, чтобы сеть сделала 10, проходит через обучающие данные. Установите MiniBatchSize
к 64
так, чтобы сеть посмотрела на 64 учебных сигнала за один раз. Установите Plots
к "training-progress"
сгенерировать графики, которые показывают процесс обучения количеством увеличений итераций. Установите Verbose
к false
отключить печать таблицы выход, который соответствует данным, показанным в графике. Установите Shuffle
к "every-epoch"
переставить обучающую последовательность в начале каждой эпохи. Установите LearnRateSchedule
к "piecewise"
чтобы уменьшить скорость обучения заданным фактором (0.1) каждый раз, определенное число эпох (5) передало. Установите ValidationData
к предикторам валидации и целям.
Этот пример использует адаптивную оценку момента (ADAM) решатель. ADAM выполняет лучше с рекуррентными нейронными сетями (RNNs) как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решатель.
maxEpochs = 10; miniBatchSize = 64; options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationClean.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork
. Поскольку набор обучающих данных является большим, учебный процесс может занять несколько минут.
[keywordNetNoAugmentation,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options); if reduceDataset load('keywordNetNoAugmentation.mat','keywordNetNoAugmentation','M','S'); end
Оцените маску KWS для сигнала валидации использование обучившего сеть.
v = classify(keywordNetNoAugmentation,FeaturesValidationClean.');
Вычислите и постройте матрицу беспорядка валидации от векторов из фактических и предполагаемых меток.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Преобразуйте сетевой выход от категориального, чтобы удвоиться.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Слушайте области ключевого слова, идентифицированные сетью.
sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и ожидаемые маски KWS.
baseline = double(BaselineV) - 1; baseline = repmat(baseline,HopLength,1); baseline = baseline(:); t = (1/fs) * (0:length(v)-1); fig = figure; plot(t,[audioIn(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noise-Free Speech')
Вы будете теперь проверять сетевую точность на шумный речевой сигнал. Сигнал с шумом был получен путем повреждения чистого сигнала валидации аддитивным белым Гауссовым шумом.
Загрузите сигнал с шумом.
[audioInNoisy,fs] = audioread('NoisyKeywordSpeech-16-16-mono-34secs.flac');
sound(audioInNoisy,fs)
Визуализируйте сигнал.
figure t = (1/fs) * (0:length(audioInNoisy)-1); plot(t,audioInNoisy) grid on xlabel('Time (s)') title('Noisy Validation Speech Signal')
Извлеките матрицу функции из сигнала с шумом.
featureMatrixV = extract(afe, audioInNoisy); featureMatrixV(~isfinite(featureMatrixV)) = 0; FeaturesValidationNoisy = (featureMatrixV - M)./S;
Передайте матрицу функции сети.
v = classify(keywordNetNoAugmentation,FeaturesValidationNoisy.');
Сравните сетевой выход с базовой линией. Обратите внимание на то, что точность ниже, чем та, которую вы получили для чистого сигнала.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy - Noisy Speech"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Преобразуйте сетевой выход от категориального, чтобы удвоиться.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:);
Слушайте области ключевого слова, идентифицированные сетью.
sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и базовые маски.
t = (1/fs)*(0:length(v)-1); fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - No Data Augmentation')
Обучивший сеть не выполнял хорошо на сигнале с шумом, потому что обученный набор данных содержал только бесшумные предложения. Вы исправите это путем увеличения набора данных, чтобы включать шумные предложения.
Используйте audioDataAugmenter
(Audio Toolbox), чтобы увеличить ваш набор данных.
ada = audioDataAugmenter('TimeStretchProbability',0, ... 'PitchShiftProbability',0, ... 'VolumeControlProbability',0, ... 'TimeShiftProbability',0, ... 'SNRRange',[-1, 1], ... 'AddNoiseProbability',0.85);
С этими настройками, audioDataAugmenter
возразите повреждает входной звуковой сигнал с белым Гауссовым шумом с вероятностью 85%. ОСШ случайным образом выбран из области значений [-1 1] (в дБ). Существует 15%-я вероятность, что увеличение не изменяет ваш входной сигнал.
Как пример, передайте звуковой сигнал увеличению.
reset(ads_keyword) x = read(ads_keyword); data = augment(ada,x,fs)
data=1×2 table
Audio AugmentationInfo
________________ ________________
{16000×1 double} [1×1 struct]
Смотрите AugmentationInfo
переменная в data
проверять, как сигнал был изменен.
data.AugmentationInfo
ans = struct with fields:
SNR: 0.1031
Сбросьте хранилища данных.
reset(ads_keyword) reset(ads_other)
Инициализируйте ячейки маски и функция.
TrainingFeatures = {}; TrainingMasks = {};
Выполните извлечение признаков снова. Каждый сигнал повреждается шумом с вероятностью 85%, таким образом, ваш увеличенный набор данных имеет приблизительно 85%-е зашумленные данные и 15%-е бесшумные данные.
tic parfor ii = 1:numPartitions subads_keyword = partition(ads_keyword,numPartitions,ii); subads_other = partition(ads_other,numPartitions,ii); count = 1; localFeatures = cell(length(subads_keyword.Files),1); localMasks = cell(length(subads_keyword.Files),1); while hasdata(subads_keyword) [sentence,mask] = HelperSynthesizeSentence(subads_keyword,subads_other,fs,WindowLength); % Corrupt with noise augmentedData = augment(ada,sentence,fs); sentence = augmentedData.Audio{1}; % Compute mfcc features featureMatrix = extract(afe, sentence); featureMatrix(~isfinite(featureMatrix)) = 0; hopLength = WindowLength - OverlapLength; range = hopLength * (1:size(featureMatrix,1)) + hopLength; featureMask = zeros(size(range)); for index = 1:numel(range) featureMask(index) = mode(mask( (index-1)*hopLength+1:(index-1)*hopLength+WindowLength )); end localFeatures{count} = featureMatrix; localMasks{count} = [emptyCategories,categorical(featureMask)]; count = count + 1; end TrainingFeatures = [TrainingFeatures;localFeatures]; TrainingMasks = [TrainingMasks;localMasks]; end fprintf('Training feature extraction took %f seconds.\n',toc)
Training feature extraction took 36.809452 seconds.
Вычислите среднее и стандартное отклонение для каждого коэффициента; используйте их, чтобы нормировать данные.
sampleFeature = TrainingFeatures{1}; numFeatures = size(sampleFeature,2); featuresMatrix = cat(1,TrainingFeatures{:}); M = mean(featuresMatrix); S = std(featuresMatrix); for index = 1:length(TrainingFeatures) f = TrainingFeatures{index}; f = (f - M) ./ S; TrainingFeatures{index} = f.'; %#ok end
Нормируйте функции валидации с новыми значениями среднего и стандартного отклонения.
FeaturesValidationNoisy = (featureMatrixV - M)./S;
Воссоздайте опции обучения. Используйте шумные базовые функции и маску для валидации.
options = trainingOptions("adam", ... "InitialLearnRate",1e-4, ... "MaxEpochs",maxEpochs, ... "MiniBatchSize",miniBatchSize, ... "Shuffle","every-epoch", ... "Verbose",false, ... "ValidationFrequency",floor(numel(TrainingFeatures)/miniBatchSize), ... "ValidationData",{FeaturesValidationNoisy.',BaselineV}, ... "Plots","training-progress", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",5);
Обучите сеть.
[KWSNet,netInfo] = trainNetwork(TrainingFeatures,TrainingMasks,layers,options); if reduceDataset load('KWSNet.mat','KWSNet'); end
Проверьте сетевую точность на сигнале валидации.
v = classify(KWSNet,FeaturesValidationNoisy.');
Сравните предполагаемые и ожидаемые маски KWS.
figure cm = confusionchart(BaselineV,v,"title","Validation Accuracy with Data Augmentation"); cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Слушайте идентифицированные области ключевого слова.
v = double(v) - 1; v = repmat(v,HopLength,1); v = v(:); sound(audioIn(logical(v)),fs)
Визуализируйте предполагаемые и ожидаемые маски.
fig = figure; plot(t,[audioInNoisy(1:length(v)),v,0.8*baseline]) grid on xlabel('Time (s)') legend('Training Signal','Network Mask','Baseline Mask','Location','southeast') l = findall(fig,'type','line'); l(1).LineWidth = 2; l(2).LineWidth = 2; title('Results for Noisy Speech - With Data Augmentation')
[1] Начальник П. "Речевые Команды: общедоступный набор данных для распознавания речи однословного", 2017. Доступный от https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Речевой Набор данных Команд лицензируется при Приписывании Creative Commons 4,0 лицензии.
function [sentence,mask] = HelperSynthesizeSentence(ads_keyword,ads_other,fs,minlength) % Read one keyword keyword = read(ads_keyword); keyword = keyword ./ max(abs(keyword)); % Identify region of interest speechIndices = detectSpeech(keyword,fs); if isempty(speechIndices) || diff(speechIndices(1,:)) <= minlength speechIndices = [1,length(keyword)]; end keyword = keyword(speechIndices(1,1):speechIndices(1,2)); % Pick a random number of other words (between 0 and 10) numWords = randi([0 10]); % Pick where to insert keyword loc = randi([1 numWords+1]); sentence = []; mask = []; for index = 1:numWords+1 if index==loc sentence = [sentence;keyword]; newMask = ones(size(keyword)); mask = [mask ;newMask]; else other = read(ads_other); other = other ./ max(abs(other)); sentence = [sentence;other]; mask = [mask;zeros(size(other))]; end end end
bilstmLayer
| sequenceInputLayer
| trainingOptions
| trainNetwork