Этот пример обучает сеть распознавания разговорных цифр на слуховых спектрограммах за пределами памяти, используя преобразованный datastore. В этом примере вы извлекаете слуховые спектрограммы из аудио с помощью audioDatastore
и audioFeatureExtractor
и вы записываете их на диск. Затем вы используете signalDatastore
для доступа к функциям во время обучения. Рабочий процесс полезен, когда функции обучения не помещаются в памяти. В этом рабочем процессе вы извлекаете функции только один раз, что ускоряет рабочий процесс, если вы итератируетесь по проекту модели глубокого обучения.
Загрузите бесплатный набор данных (FSDD). FSDD состоит из 2000 записей четырех ораторов, говорящих цифры от 0 до 9 на английском языке.
url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'FSDD'); if ~exist(datasetFolder,'dir') disp('Downloading FSDD...') unzip(url,downloadFolder) end
Downloading FSDD...
Создайте audioDatastore
это указывает на набор данных.
pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);
Функция помощника, helperGenerateLabels
создает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels
приведено в приложении. Отображение классов и количества примеров в каждом классе.
ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
summary(ads.Labels)
0 200 1 200 2 200 3 200 4 200 5 200 6 200 7 200 8 200 9 200
Разделите FSDD на обучающие и тестовые наборы. Выделите 80% данных наборов обучающих данных и сохраните 20% для тестового набора. Вы используете набор обучающих данных для обучения модели и тестовый набор для проверки обученной модели.
rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
Label Count
_____ _____
0 160
1 160
2 160
3 160
4 160
5 160
6 160
7 160
8 160
9 160
countEachLabel(adsTest)
ans=10×2 table
Label Count
_____ _____
0 40
1 40
2 40
3 40
4 40
5 40
6 40
7 40
8 40
9 40
Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset
на ложь. Чтобы запустить этот пример быстро, установите reduceDataset
к true.
reduceDataset = "false"; if reduceDataset = ="true" adsTrain = splitEachLabel (adsTrain, 2); adsTest = splitEachLabel (adsTest, 2); end
CNN принимает мел-частотные спектрограммы.
Задайте параметры, используемые для извлечения мел-частотных спектрограмм. Используйте 220 мс окна с 10 мс хмеля между окнами. Используйте ДПФ с 2048 точками и 40 полосы частот.
fs = 8000; frameDuration = 0.22; hopDuration = 0.01; params.segmentLength = 8192; segmentDuration = params.segmentLength*(1/fs); params.numHops = ceil((segmentDuration-frameDuration)/hopDuration); params.numBands = 40; frameLength = round(frameDuration*fs); hopLength = round(hopDuration*fs); fftLength = 2048;
Создайте audioFeatureExtractor
объект для вычисления спектрограмм mel-частоты из входных аудиосигналов.
afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs); afe.Window = hamming(frameLength,'periodic'); afe.OverlapLength = frameLength-hopLength; afe.FFTLength = fftLength;
Установите параметры для спектрограммы мел-частоты.
setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);
Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудио данных. Функция помощника, getSpeechSpectrogram
(см. приложение), стандартизирует длину записи и нормализует амплитуду аудиовхода. getSpeechSpectrogram
использует audioFeatureExtractor
afe объекта
для получения логарифмически основанных мел-частотных спектрограмм.
adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));
Использование writeall
запись слуховых спектрограмм на диск. Задайте UseParallel
true, чтобы выполнять запись параллельно.
writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);
Создайте signalDatastore
это указывает на функции нехватки памяти. Пользовательская функция чтения возвращает пару спектрограмма/метка.
sds = signalDatastore('recordings','ReadFcn',@myCustomRead);
Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции helper getValidationSpeechSpectrograms
(см. приложение).
XTest = getValidationSpeechSpectrograms(adsTest,afe,params);
Получите метки валидации.
YTest = adsTest.Labels;
Создайте небольшой CNN как массив слоев. Используйте сверточные и пакетные слои нормализации и понижающее отображение карт признаков с помощью максимальных слоев объединения. Чтобы уменьшить возможность запоминания сетью специфических функций обучающих данных, добавьте небольшое количество отсева на вход к последнему полностью подключенному слою.
sz = size(XTest); specSize = sz(1:2); imageSize = [specSize 1]; numClasses = numel(categories(YTest)); dropoutProb = 0.2; numF = 12; layers = [ imageInputLayer(imageSize,'Normalization','none') convolution2dLayer(5,numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2) dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer('Classes',categories(YTest)); ];
Установите гиперпараметры, которые будут использоваться при обучении сети. Используйте мини-пакет размером 50 и скоростью обучения 1e-4. Задайте 'adam
'оптимизация. Чтобы использовать параллельный пул для чтения преобразованного datastore, установите DispatchInBackground
на true
. Для получения дополнительной информации смотрите trainingOptions
(Deep Learning Toolbox).
miniBatchSize = 50; options = trainingOptions('adam', ... 'InitialLearnRate',1e-4, ... 'MaxEpochs',30, ... 'LearnRateSchedule',"piecewise",... 'LearnRateDropFactor',.1,... 'LearnRateDropPeriod',15,... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{XTest, YTest},... 'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),... 'ExecutionEnvironment','gpu',... 'DispatchInBackground', true);
Обучите сеть, передав обучающий datastore, чтобы trainNetwork
.
trainedNet = trainNetwork(sds,layers,options);
Используйте обученную сеть, чтобы предсказать метки цифр для тестового набора.
[Ypredicted,probs] = classify(trainedNet,XTest); cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.5000
Суммируйте эффективность обученной сети на тестовом наборе с графиком неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам. Таблица в нижней части графика неточностей показывает значения точности. Таблица справа от графика неточностей показывает значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]); ccDCNN = confusionchart(YTest,Ypredicted); ccDCNN.Title = 'Confusion Chart for DCNN'; ccDCNN.ColumnSummary = 'column-normalized'; ccDCNN.RowSummary = 'row-normalized';
function Labels = helpergenLabels(ads) % This function is only for use in this example. It may be changed or % removed in a future release. files = ads.Files; tmp = cell(numel(files),1); expression = "[0-9]+_"; parfor nf = 1:numel(ads.Files) idx = regexp(files{nf},expression); tmp{nf} = files{nf}(idx); end Labels = categorical(tmp); end %------------------------------------------------------------ function X = getValidationSpeechSpectrograms(ads,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for % the files in the datastore ads using the audioFeatureExtractor afe. numFiles = length(ads.Files); X = zeros([params.numBands,params.numHops,1,numFiles],'single'); for i = 1:numFiles x = read(ads); spec = getSpeechSpectrogram(x,afe,params); X(:,:,1,i) = spec; end end %-------------------------------------------------------------------------- function X = getSpeechSpectrogram(x,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal % x using the audioFeatureExtractor afe. X = zeros([params.numBands,params.numHops],'single'); x = normalizeAndResize(single(x),params); spec = extract(afe,x).'; % If the spectrogram is less wide than numHops, then put spectrogram in % the middle of X. w = size(spec,2); left = floor((params.numHops-w)/2)+1; ind = left:left+w-1; X(:,ind) = log10(spec + 1e-6); end %-------------------------------------------------------------------------- function x = normalizeAndResize(x,params) % This function is only for use in this example. It may change or be % removed in a future release. L = params.segmentLength; N = numel(x); if N > L x = x(1:L); elseif N < L pad = L-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end %-------------------------------------------------------------------------- function myCustomWriter(spec,writeInfo,~) % This function is only for use in this example. It may change or be % removed in a future release. % Define custom writing function to write auditory spectrogram/label pair % to MAT files. filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat'); label = writeInfo.ReadInfo.Label; save(filename,'label','spec'); end %-------------------------------------------------------------------------- function [data,info] = myCustomRead(filename) % This function is only for use in this example. It may change or be % removed in a future release. load(filename,'spec','label'); data = {spec,label}; info.SampleRate = 8000; end