Этот пример обучает разговорную сеть распознавания цифры на слуховых спектрограммах из памяти с помощью преобразованного datastore. В этом примере вы извлекаете слуховые спектрограммы из аудио с помощью audioDatastore
и audioFeatureExtractor
, и вы пишете им в диск. Вы затем используете signalDatastore
получать доступ к функциям во время обучения. Рабочий процесс полезен, когда учебные функции не умещаются в памяти. В этом рабочем процессе вы только извлекаете функции однажды, который ускоряет ваш рабочий процесс, если вы выполняете итерации на проекте модели глубокого обучения.
Загрузите Свободный Разговорный Набор данных Цифры (FSDD). FSDD состоит из 2 000 записей четырех динамиков, говорящих числа 0 до 9 на английском языке.
url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'FSDD'); if ~exist(datasetFolder,'dir') disp('Downloading FSDD...') unzip(url,downloadFolder) end
Downloading FSDD...
Создайте audioDatastore
это указывает на набор данных.
pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);
Функция помощника, helperGenerateLabels
, создает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels
перечислен в приложении. Отобразите классы и количество примеров в каждом классе.
ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
summary(ads.Labels)
0 200 1 200 2 200 3 200 4 200 5 200 6 200 7 200 8 200 9 200
Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Вы используете набор обучающих данных, чтобы обучить модель и набор тестов подтверждать обученную модель.
rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
Label Count
_____ _____
0 160
1 160
2 160
3 160
4 160
5 160
6 160
7 160
8 160
9 160
countEachLabel(adsTest)
ans=10×2 table
Label Count
_____ _____
0 40
1 40
2 40
3 40
4 40
5 40
6 40
7 40
8 40
9 40
Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset
ко лжи. Чтобы запустить этот пример быстро, установите reduceDataset
к истине.
reduceDataset = "false"; if reduceDataset == "true" adsTrain = splitEachLabel (adsTrain, 2); adsTest = splitEachLabel (adsTest, 2); end
CNN принимает спектрограммы mel-частоты.
Задайте параметры, используемые, чтобы извлечь спектрограммы mel-частоты. Используйте 220 MS Windows с транзитными участками на 10 мс между окнами. Используйте ДПФ с 2048 точками и 40 диапазонов частот.
fs = 8000; frameDuration = 0.22; hopDuration = 0.01; params.segmentLength = 8192; segmentDuration = params.segmentLength*(1/fs); params.numHops = ceil((segmentDuration-frameDuration)/hopDuration); params.numBands = 40; frameLength = round(frameDuration*fs); hopLength = round(hopDuration*fs); fftLength = 2048;
Создайте audioFeatureExtractor
объект вычислить спектрограммы mel-частоты из входных звуковых сигналов.
afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs); afe.Window = hamming(frameLength,'periodic'); afe.OverlapLength = frameLength-hopLength; afe.FFTLength = fftLength;
Установите параметры для спектрограммы mel-частоты.
setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);
Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудиоданных. Функция помощника, getSpeechSpectrogram
(см. приложение), стандартизирует продолжительность записи и нормирует амплитуду аудиовхода. getSpeechSpectrogram
использует audioFeatureExtractor
объект afe
получить основанные на журнале спектрограммы mel-частоты.
adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));
Используйте writeall
записать слуховые спектрограммы в диск. Установите UseParallel
к истине, чтобы выполнить запись параллельно.
writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);
Создайте signalDatastore
это указывает на функции из памяти. Пользовательская функция чтения возвращает пару спектрограммы/метки.
sds = signalDatastore('recordings','ReadFcn',@myCustomRead);
Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции помощника getValidationSpeechSpectrograms
(см. приложение).
XTest = getValidationSpeechSpectrograms(adsTest,afe,params);
Получите метки валидации.
YTest = adsTest.Labels;
Создайте маленький CNN как массив слоев. Используйте сверточный и слои нормализации партии. и проредите карты функции с помощью макс. слоев объединения. Чтобы уменьшать возможность сети, запоминая определенные функции обучающих данных, добавьте небольшое количество уволенного к входу к последнему полносвязному слою.
sz = size(XTest); specSize = sz(1:2); imageSize = [specSize 1]; numClasses = numel(categories(YTest)); dropoutProb = 0.2; numF = 12; layers = [ imageInputLayer(imageSize,'Normalization','none') convolution2dLayer(5,numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2) dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer('Classes',categories(YTest)); ];
Установите гиперпараметры использовать в обучении сети. Используйте мини-пакетный размер 50 и скорость обучения 1e-4. Задайте 'adam
Оптимизация. Чтобы использовать параллельный пул, чтобы считать преобразованный datastore, устанавливает DispatchInBackground
к true
. Для получения дополнительной информации смотрите trainingOptions
(Deep Learning Toolbox).
miniBatchSize = 50; options = trainingOptions('adam', ... 'InitialLearnRate',1e-4, ... 'MaxEpochs',30, ... 'LearnRateSchedule',"piecewise",... 'LearnRateDropFactor',.1,... 'LearnRateDropPeriod',15,... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{XTest, YTest},... 'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),... 'ExecutionEnvironment','gpu',... 'DispatchInBackground', true);
Обучите сеть путем передачи учебного datastore trainNetwork
.
trainedNet = trainNetwork(sds,layers,options);
Используйте обучивший сеть, чтобы предсказать метки цифры для набора тестов.
[Ypredicted,probs] = classify(trainedNet,XTest); cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.5000
Обобщите эффективность обучившего сеть на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности. Таблица справа от графика беспорядка показывает значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]); ccDCNN = confusionchart(YTest,Ypredicted); ccDCNN.Title = 'Confusion Chart for DCNN'; ccDCNN.ColumnSummary = 'column-normalized'; ccDCNN.RowSummary = 'row-normalized';
function Labels = helpergenLabels(ads) % This function is only for use in this example. It may be changed or % removed in a future release. files = ads.Files; tmp = cell(numel(files),1); expression = "[0-9]+_"; parfor nf = 1:numel(ads.Files) idx = regexp(files{nf},expression); tmp{nf} = files{nf}(idx); end Labels = categorical(tmp); end %------------------------------------------------------------ function X = getValidationSpeechSpectrograms(ads,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for % the files in the datastore ads using the audioFeatureExtractor afe. numFiles = length(ads.Files); X = zeros([params.numBands,params.numHops,1,numFiles],'single'); for i = 1:numFiles x = read(ads); spec = getSpeechSpectrogram(x,afe,params); X(:,:,1,i) = spec; end end %-------------------------------------------------------------------------- function X = getSpeechSpectrogram(x,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal % x using the audioFeatureExtractor afe. X = zeros([params.numBands,params.numHops],'single'); x = normalizeAndResize(single(x),params); spec = extract(afe,x).'; % If the spectrogram is less wide than numHops, then put spectrogram in % the middle of X. w = size(spec,2); left = floor((params.numHops-w)/2)+1; ind = left:left+w-1; X(:,ind) = log10(spec + 1e-6); end %-------------------------------------------------------------------------- function x = normalizeAndResize(x,params) % This function is only for use in this example. It may change or be % removed in a future release. L = params.segmentLength; N = numel(x); if N > L x = x(1:L); elseif N < L pad = L-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end %-------------------------------------------------------------------------- function myCustomWriter(spec,writeInfo,~) % This function is only for use in this example. It may change or be % removed in a future release. % Define custom writing function to write auditory spectrogram/label pair % to MAT files. filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat'); label = writeInfo.ReadInfo.Label; save(filename,'label','spec'); end %-------------------------------------------------------------------------- function [data,info] = myCustomRead(filename) % This function is only for use in this example. It may change or be % removed in a future release. load(filename,'spec','label'); data = {spec,label}; info.SampleRate = 8000; end