В этом примере показано, как использовать сверточную глубокую сеть, чтобы узнать, что предварительный акцент фильтрует для распознавания речи. Пример использует learnable слой кратковременного преобразования Фурье (STFT), чтобы получить представление частоты времени, подходящее для использования с 2D сверточными слоями. Использование learnable STFT включает основанную на градиенте оптимизацию весов фильтра перед акцентом.
Клонируйте или загрузите Свободный разговорный набор данных цифры (FSDD), доступный в https://github.com/Jakobovski/free-spoken-digit-dataset. FSDD является открытым набором данных, что означает, что это может расти в зависимости от времени. Этот пример использует версию, фиксировавшую 08/20/2020, который состоит из 3 000 записей английских цифр 0 через 9 полученных от шести докладчиков. Данные производятся на уровне 8 000 Гц.
Этот пример принимает, что вы загрузили данные в папку, соответствующую значению tempdir
в MATLAB. Если вы используете другую папку, заменяете тем именем папки tempdir
в следующем коде. Используйте audioDatastore, чтобы управлять доступом к данным и гарантировать случайное деление данных в наборы обучающих данных и наборы тестов.
tempdir = '/mathworks/devel/sandbox/wking'; pathToRecordingsFolder = fullfile(tempdir,'free-spoken-digit-dataset','recordings'); ads = audioDatastore(pathToRecordingsFolder);
Функция помощника helpergenLabels
создает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels
перечислен в приложении. Перечислите классы и количество примеров в каждом классе. Может потребоваться несколько минут, чтобы сгенерировать все метки для этого набора данных.
ads.Labels = helpergenLabels(ads); summary(ads.Labels)
0 300 1 300 2 300 3 300 4 300 5 300 6 300 7 300 8 300 9 300
Разделите FSDD в наборы обучающих данных и наборы тестов, обеспечивающие равные пропорции класса в каждом подмножестве. Для восстанавливаемых результатов, набор генератор случайных чисел к его значению по умолчанию. Восемьдесят процентов, или 2 400 записей, используются для обучения. Остающиеся 600 записей, 20% общего количества, протянуты для тестирования. Переставьте файлы в datastore, однажды создающем наборы обучающих данных и наборы тестов.
rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8,0.2);
Записи в FSDD не равны в длине. Используйте преобразование так, чтобы каждое чтение от datastore было дополнено или усеченное к 8 192 выборкам. Данные дополнительно брошены к с одинарной точностью, и нормализация z-счета применяется.
transTrain = transform(adsTrain,@(x,info)helperReadData(x,info),'IncludeInfo',true); transTest = transform(adsTest,@(x,info)helperReadData(x,info),'IncludeInfo',true);
Этот пример использует пользовательский учебный цикл со следующей глубокой сверточной сетью.
numF = 12; dropoutProb = 0.2; layers = [ sequenceInputLayer(1,'Name','input','MinLength',8192,... 'Normalization',"none") convolution1dLayer(5,1,"name","pre-emphasis-filter",... "WeightsInitializer",@(sz)kronDelta(sz),"BiasLearnRateFactor",0) stftLayer('Window',hamming(1280),'OverlapLength',900,... 'OutputMode','spatial','Name','STFT') convolution2dLayer(5,numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2) dropoutLayer(dropoutProb) fullyConnectedLayer(numel(categories(ads.Labels))) softmaxLayer ]; dlnet = dlnetwork(layers);
Входной слой последовательности сопровождается 1D слоем свертки, состоящим из одного фильтра с 5 коэффициентами. Это - конечный фильтр импульсной характеристики. Сверточные слои в нейронных сетях для глубокого обучения значением по умолчанию реализуют аффинную операцию на входных функциях. Чтобы получить строго линейное (фильтрация) операция, используйте 'BiasInitializer'
по умолчанию который является
'zeros'
и набор смещение изучает фактор уровня слоя к 0. Это означает, что смещение инициализируется к 0 и никогда не изменяется во время обучения. Сеть использует пользовательскую инициализацию весов фильтра, чтобы быть масштабированной Кронекеровой последовательностью дельты. Это - фильтр allpass, который не выполняет фильтрации входа.
stftLayer
берет отфильтрованный пакет входных сигналов и получает их величину STFTs. STFT величины является 2D представлением сигнала, который подсуден, чтобы использовать в 2D сверточных сетях.
В то время как веса STFT не изменяются здесь во время обучения, слой поддерживает обратную связь, которая позволяет коэффициентам фильтра в слое "пред фильтр акцента" быть изученными.
Установите опции обучения для пользовательского учебного цикла. Используйте 25 эпох с minbatch размером 128. Установите начальную букву, изучают уровень 0,001.
NumEpochs = 25; miniBatchSize = 128; learnRate = 0.001;
В пользовательском учебном цикле используйте minibatchqueue
объект. processSpeechMB
функционируйте чтения в мини-пакете, и применяет схему прямого кодирования к меткам.
mbqTrain = minibatchqueue(transTrain, 2,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFormat', {'CBT', 'CB'}, ... 'MiniBatchFcn', @processSpeechMB);
Обучите сеть и постройте потерю для каждой итерации. Используйте оптимизатор Адама, чтобы обновить сетевые настраиваемые параметры. Чтобы построить потерю как процесс обучения, установите значение progress
в следующем коде к "процессу обучения".
progress = "final-loss"; if progress == "training-progress" figure lineLossTrain = animatedline; ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end % Initialize some training loop variables. trailingAvg = []; trailingAvgSq = []; iteration = 0; lossByIteration = 0; % Loop over epochs. Time the epochs. start = tic; for epoch = 1:NumEpochs reset(mbqTrain) shuffle(mbqTrain) % Loop over mini-batches. while hasdata(mbqTrain) iteration = iteration + 1; % Get the next minibatch and one-hot coded targets [dlX,Y] = next(mbqTrain); % Evaluate the model gradients and loss [gradients, loss, state] = dlfeval(@modelGradSTFT,dlnet,dlX,Y); if progress == "final-loss" lossByIteration(iteration) = loss; end % Update the network state dlnet.State = state; % Update the network parameters using an Adam optimizer. [dlnet,trailingAvg,trailingAvgSq] = adamupdate(... dlnet, gradients, trailingAvg, trailingAvgSq, iteration, learnRate); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); if progress == "training-progress" addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) end end disp("Training loss after epoch " + epoch + ": " + loss); end
Training loss after epoch 1: 0.78569 Training loss after epoch 2: 0.33833 Training loss after epoch 3: 0.27921 Training loss after epoch 4: 0.11701 Training loss after epoch 5: 0.15688 Training loss after epoch 6: 0.032381 Training loss after epoch 7: 0.021219 Training loss after epoch 8: 0.048071 Training loss after epoch 9: 0.019537 Training loss after epoch 10: 0.055428 Training loss after epoch 11: 0.029689 Training loss after epoch 12: 0.021452 Training loss after epoch 13: 0.023566 Training loss after epoch 14: 0.010125 Training loss after epoch 15: 0.0027084 Training loss after epoch 16: 0.0074854 Training loss after epoch 17: 0.0053942 Training loss after epoch 18: 0.029233 Training loss after epoch 19: 0.016945 Training loss after epoch 20: 0.0096544 Training loss after epoch 21: 0.0023757 Training loss after epoch 22: 0.0028348 Training loss after epoch 23: 0.0041876 Training loss after epoch 24: 0.0017663 Training loss after epoch 25: 0.000395
if progress == "final-loss" plot(1:iteration,lossByIteration) grid on title('Training Loss by Iteration') xlabel("Iteration") ylabel("Loss") end
Протестируйте обучивший сеть на протянутом наборе тестов. Используйте minibatchqueue
объект с мини-пакетным размером 32.
miniBatchSize = 32; mbqTest = minibatchqueue(transTest, 2,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFormat', {'CBT', 'CB'}, ... 'MiniBatchFcn', @processSpeechMB);
Цикл по набору тестов и предсказывает метки класса для каждого мини-пакета.
numObservations = numel(adsTest.Files); classes = string(unique(adsTest.Labels)); predictions = []; % Loop over mini-batches. while hasdata(mbqTest) % Read mini-batch of data. dlX = next(mbqTest); % Make predictions on the minibatch dlYPred = predict(dlnet,dlX,'Acceleration','none'); % Determine corresponding classes. predBatch = onehotdecode(dlYPred,classes,1); predictions = [predictions predBatch]; end
Оцените точность классификации на этих 600 примерах в протянутом наборе тестов.
accuracy = mean(predictions' == categorical(adsTest.Labels))
accuracy = 0.9833
Проведение испытаний составляет приблизительно 98%. Можно закомментировать 1D слой свертки и переобучить сеть без фильтра перед акцентом. Проведение испытаний без фильтра перед акцентом также превосходно приблизительно в 96%, но использование фильтра перед акцентом делает маленькое улучшение. Это примечательно, что, в то время как использование изученного фильтра перед акцентом только улучшило тестовую точность немного, это было достигнуто путем добавления только 5 настраиваемых параметров в сеть.
Чтобы исследовать изученный фильтр перед акцентом, извлеките веса 1D сверточного слоя. Постройте частотную характеристику. Вспомните, что частота дискретизации данных составляет 8 кГц. Поскольку мы инициализировали фильтр к масштабированной Кронекеровой последовательности дельты (allpass фильтр), мы можем легко сравнить частотную характеристику инициализированного фильтра с изученным ответом.
FIRFilter = dlnet.Layers(2).Weights; [H,W] = freqz(FIRFilter,1,[],8000); delta = kronDelta([5 1 1]); Hinit = freqz(delta,1,[],4000); plot(W,20*log10(abs([H Hinit])),'linewidth',2) grid on xlabel('Hz') ylabel('dB') legend('Learned Filter','Initial Filter','Location','SouthEast') title('Learned Pre-emphasis Filter')
Этот пример показал, как изучить фильтр перед акцентом как шаг предварительной обработки в 2D сверточной сети на основе кратковременных преобразований Фурье сигналов. Способность stftLayer
поддерживать обратную связь включило основанную на градиенте оптимизацию весов фильтра в глубокой сети. В то время как это привело только к маленькому улучшению эффективности сети на наборе тестов, это достигло этого улучшения с тривиальным увеличением количества настраиваемых параметров.
function Labels = helpergenLabels(ads) % This function is only for use in the "Learn Pre-Emphasis Filter using % Deep Learning" example. It may change or be removed in a % future release. tmp = cell(numel(ads.Files),1); expression = "[0-9]+_"; for nf = 1:numel(ads.Files) idx = regexp(ads.Files{nf},expression); tmp{nf} = ads.Files{nf}(idx); end Labels = categorical(tmp); end function [out,info] = helperReadData(x,info) % This function is only for use in the "Learn Pre-Emphasis Filter using % Deep Learning" example. It may change or be removed in a % future release. N = numel(x); x = single(x); if N > 8192 x = x(1:8192); elseif N < 8192 pad = 8192-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = (x-mean(x))./std(x); x = x(:)'; out = {x,info.Label}; end function [dlX,dlY] = processSpeechMB(Xcell,Ycell) % This function is only for use in the "Learn Pre-Emphasis Filter using % Deep Learning" example. It may change or be removed in a % future release. Xcell = cellfun(@(x)reshape(x,1,1,[]),Xcell,'uni',false); dlX = cat(2,Xcell{:}); dlY = cat(2,Ycell{:}); dlY = onehotencode(dlY,1); end function [grads, loss, state] = modelGradSTFT(net, X, T) % This function is only for use in the "Learn Pre-Emphasis Filter using % Deep Learning" example. It may change or be removed in a % future release. [y, state] = net.forward(X); loss = crossentropy(y, T); grads = dlgradient(loss,net.Learnables); loss = double(gather(extractdata(loss))); end function delta = kronDelta(sz) % This function is only for use in the "Learn Pre-Emphasis Filter using % Deep Learning" example. It may change or be removed in a % future release. L = sz(1); delta = zeros(L,sz(2),sz(3),'single'); delta(1) = 1/sqrt(L); end
dlstft
(Signal Processing Toolbox) | stft
(Signal Processing Toolbox) | istft
(Signal Processing Toolbox) | stftmag2sig
(Signal Processing Toolbox)