В этом примере обучается сеть распознавания речевых цифр на слуховых спектрограммах, не имеющих памяти, с использованием преобразованного хранилища данных. В этом примере вы извлекаете слуховые спектрограммы из звука с помощью audioDatastore и audioFeatureExtractorи вы записываете их на диск. Затем вы используете signalDatastore для доступа к функциям во время обучения. Рабочий процесс полезен, когда обучающие функции не помещаются в память. В этом рабочем процессе элементы извлекаются только один раз, что ускоряет рабочий процесс, если выполняется итерация в проекте модели глубокого обучения.
Загрузите бесплатный набор данных речевых цифр (FSDD). FSDD состоит из 2000 записей четырёх говорящих, говорящих цифры от 0 до 9 на английском языке.
url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'FSDD'); if ~exist(datasetFolder,'dir') disp('Downloading FSDD...') unzip(url,downloadFolder) end
Downloading FSDD...
Создание audioDatastore указывает на набор данных.
pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);Вспомогательная функция, helperGenerateLabelsсоздает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels перечислены в приложении. Просмотрите классы и количество примеров в каждом классе.
ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
summary(ads.Labels)
0 200
1 200
2 200
3 200
4 200
5 200
6 200
7 200
8 200
9 200
Разделение FSDD на учебные и тестовые наборы. Выделите 80% данных обучающему набору и сохраните 20% для тестового набора. Обучающий набор используется для обучения модели, а тестовый набор - для проверки обучаемой модели.
rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)ans=10×2 table
Label Count
_____ _____
0 160
1 160
2 160
3 160
4 160
5 160
6 160
7 160
8 160
9 160
countEachLabel(adsTest)
ans=10×2 table
Label Count
_____ _____
0 40
1 40
2 40
3 40
4 40
5 40
6 40
7 40
8 40
9 40
Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset в значение false. Для быстрого выполнения этого примера установите reduceDataset к true.
reduceDataset ="false"; if reduceDataset == "true" adsTrain = splitEachLabel(adsTrain,2); adsTest = splitEachLabel(adsTest,2); end
CNN принимает спектрограммы мел-частоты.
Определите параметры, используемые для извлечения мел-частотных спектрограмм. Используйте окна 220 мс с 10 мс прыжками между окнами. Используйте 2048-точечный DFT и 40 диапазонов частот.
fs = 8000; frameDuration = 0.22; hopDuration = 0.01; params.segmentLength = 8192; segmentDuration = params.segmentLength*(1/fs); params.numHops = ceil((segmentDuration-frameDuration)/hopDuration); params.numBands = 40; frameLength = round(frameDuration*fs); hopLength = round(hopDuration*fs); fftLength = 2048;
Создание audioFeatureExtractor Сущность изобретения заключается в том, что на основе входных аудиосигналов вычисляют мель-частотные спектрограммы.
afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs); afe.Window = hamming(frameLength,'periodic'); afe.OverlapLength = frameLength-hopLength; afe.FFTLength = fftLength;
Задайте параметры для спектрограммы mel-frequency.
setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);
Создайте преобразованное хранилище данных, которое вычисляет частотные спектрограммы из аудиоданных. Вспомогательная функция, getSpeechSpectrogram (см. приложение), стандартизирует длину записи и нормализует амплитуду звукового входа. getSpeechSpectrogram использует audioFeatureExtractor объект afe для получения логарифмических спектрограмм.
adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));
Использовать writeall для записи слуховых спектрограмм на диск. Набор UseParallel true для параллельного выполнения записи.
writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);
Создать signalDatastore указывает на отсутствие памяти. Пользовательская функция считывания возвращает пару спектрограмма/метка.
sds = signalDatastore('recordings','ReadFcn',@myCustomRead);
Набор данных проверки помещается в память, и функции проверки предварительно вычисляются с помощью вспомогательной функции. getValidationSpeechSpectrograms (см. приложение).
XTest = getValidationSpeechSpectrograms(adsTest,afe,params);
Получите метки проверки.
YTest = adsTest.Labels;
Создайте небольшой CNN в виде массива слоев. Используйте слои сверточной и пакетной нормализации, а затем выполните понижающую выборку карт элементов с использованием максимального количества слоев пула. Чтобы уменьшить возможность запоминания сетью специфических особенностей обучающих данных, добавьте небольшое количество отсева на вход последнего полностью подключенного уровня.
sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];
numClasses = numel(categories(YTest));
dropoutProb = 0.2;
numF = 12;
layers = [
imageInputLayer(imageSize,'Normalization','none')
convolution2dLayer(5,numF,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(3,'Stride',2,'Padding','same')
convolution2dLayer(3,2*numF,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(3,'Stride',2,'Padding','same')
convolution2dLayer(3,4*numF,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(3,'Stride',2,'Padding','same')
convolution2dLayer(3,4*numF,'Padding','same')
batchNormalizationLayer
reluLayer
convolution2dLayer(3,4*numF,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2)
dropoutLayer(dropoutProb)
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer('Classes',categories(YTest));
];Установите гиперпараметры для использования при обучении сети. Используйте размер мини-партии 50 и скорость обучения 1e-4. Укажите 'adam'оптимизация. Чтобы использовать параллельный пул для чтения преобразованного хранилища данных, установите DispatchInBackground кому true. Дополнительные сведения см. в разделе trainingOptions (инструментарий глубокого обучения).
miniBatchSize = 50; options = trainingOptions('adam', ... 'InitialLearnRate',1e-4, ... 'MaxEpochs',30, ... 'LearnRateSchedule',"piecewise",... 'LearnRateDropFactor',.1,... 'LearnRateDropPeriod',15,... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{XTest, YTest},... 'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),... 'ExecutionEnvironment','gpu',... 'DispatchInBackground', true);
Обучение сети путем передачи хранилища данных обучения в trainNetwork.
trainedNet = trainNetwork(sds,layers,options);
Используйте обученную сеть для прогнозирования цифровых меток для тестового набора.
[Ypredicted,probs] = classify(trainedNet,XTest); cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.5000
Обобщите производительность обученной сети на тестовом наборе с помощью таблицы путаницы. Отображение точности и отзыва для каждого класса с помощью сводок столбцов и строк. В таблице внизу таблицы путаницы показаны значения точности. В таблице справа от таблицы путаницы показаны значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]); ccDCNN = confusionchart(YTest,Ypredicted); ccDCNN.Title = 'Confusion Chart for DCNN'; ccDCNN.ColumnSummary = 'column-normalized'; ccDCNN.RowSummary = 'row-normalized';

function Labels = helpergenLabels(ads) % This function is only for use in this example. It may be changed or % removed in a future release. files = ads.Files; tmp = cell(numel(files),1); expression = "[0-9]+_"; parfor nf = 1:numel(ads.Files) idx = regexp(files{nf},expression); tmp{nf} = files{nf}(idx); end Labels = categorical(tmp); end %------------------------------------------------------------ function X = getValidationSpeechSpectrograms(ads,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for % the files in the datastore ads using the audioFeatureExtractor afe. numFiles = length(ads.Files); X = zeros([params.numBands,params.numHops,1,numFiles],'single'); for i = 1:numFiles x = read(ads); spec = getSpeechSpectrogram(x,afe,params); X(:,:,1,i) = spec; end end %-------------------------------------------------------------------------- function X = getSpeechSpectrogram(x,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal % x using the audioFeatureExtractor afe. X = zeros([params.numBands,params.numHops],'single'); x = normalizeAndResize(single(x),params); spec = extract(afe,x).'; % If the spectrogram is less wide than numHops, then put spectrogram in % the middle of X. w = size(spec,2); left = floor((params.numHops-w)/2)+1; ind = left:left+w-1; X(:,ind) = log10(spec + 1e-6); end %-------------------------------------------------------------------------- function x = normalizeAndResize(x,params) % This function is only for use in this example. It may change or be % removed in a future release. L = params.segmentLength; N = numel(x); if N > L x = x(1:L); elseif N < L pad = L-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end %-------------------------------------------------------------------------- function myCustomWriter(spec,writeInfo,~) % This function is only for use in this example. It may change or be % removed in a future release. % Define custom writing function to write auditory spectrogram/label pair % to MAT files. filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat'); label = writeInfo.ReadInfo.Label; save(filename,'label','spec'); end %-------------------------------------------------------------------------- function [data,info] = myCustomRead(filename) % This function is only for use in this example. It may change or be % removed in a future release. load(filename,'spec','label'); data = {spec,label}; info.SampleRate = 8000; end