Этот пример обучает сеть распознавания разговорных цифр на аудио данных вне памяти с помощью преобразованного datastore. В этом примере вы применяете случайный тангаж сдвиг к аудио данных, используемым для обучения сверточной нейронной сети (CNN). Для каждой итерации обучения аудио данных дополняется с помощью объекта audioDataAugmenter, а затем функции извлекаются с помощью объекта audioFeatureExtractor. Рабочий процесс в этом примере применяется к любому случайному увеличению данных, используемому в цикле обучения. Рабочий процесс также применяется, когда базовый набор аудио данных или функции обучения не помещаются в памяти.
Загрузите бесплатный набор данных (FSDD). FSDD состоит из 2000 записей четырех ораторов, говорящих цифры от 0 до 9 на английском языке.
url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip'; downloadFolder = tempdir; datasetFolder = fullfile(downloadFolder,'FSDD'); if ~exist(datasetFolder,'dir') disp('Downloading FSDD...') unzip(url,downloadFolder) end
Создайте audioDatastore
это указывает на набор данных.
pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);
Функция помощника, helperGenerateLabels
создает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels
приведено в приложении. Отображение классов и количества примеров в каждом классе.
ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 8).
summary(ads.Labels)
0 200 1 200 2 200 3 200 4 200 5 200 6 200 7 200 8 200 9 200
Разделите FSDD на обучающие и тестовые наборы. Выделите 80% данных наборов обучающих данных и сохраните 20% для тестового набора. Вы используете набор обучающих данных для обучения модели и тестовый набор для проверки обученной модели.
rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
Label Count
_____ _____
0 160
1 160
2 160
3 160
4 160
5 160
6 160
7 160
8 160
9 160
countEachLabel(adsTest)
ans=10×2 table
Label Count
_____ _____
0 40
1 40
2 40
3 40
4 40
5 40
6 40
7 40
8 40
9 40
Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset
на ложь. Чтобы запустить этот пример быстро, установите reduceDataset
к true.
reduceDataset = "false"; if reduceDataset = ="true" adsTrain = splitEachLabel (adsTrain, 2); adsTest = splitEachLabel (adsTest, 2); end
Увеличение обучающих данных путем применения сдвига тангажа с помощью audioDataAugmenter
объект.
Создайте audioDataAugmenter
.
Усилитель применяет смещение тангажа к входу аудиосигналу с вероятностью 0,5. Augmenter выбирает случайное значение перемены тангажа в области значений [-12 12] полутонов.
augmenter = audioDataAugmenter('PitchShiftProbability',.5, ... 'SemitoneShiftRange',[-12 12], ... 'TimeShiftProbability',0, ... 'VolumeControlProbability',0, ... 'AddNoiseProbability',0, ... 'TimeShiftProbability',0);
Установите пользовательские параметры сдвига тангажа. Используйте тождества фазы блокировки и сохранения формантов, используя оценку спектральной огибающей с кепстральным анализом 30-го порядка.
setAugmenterParams(augmenter,'shiftPitch','LockPhase',true,'PreserveFormants',true,'CepstralOrder',30);
Создайте преобразованный datastore, который применяет увеличение данных к обучающим данным.
fs = 8000; adsAugTrain = transform(adsTrain,@(y)deal(augment(augmenter,y,fs).Audio{1}));
CNN принимает мел-частотные спектрограммы.
Задайте параметры, используемые для извлечения мел-частотных спектрограмм. Используйте 220 мс окна с 10 мс хмеля между окнами. Используйте ДПФ с 2048 точками и 40 полосы частот.
frameDuration = 0.22; hopDuration = 0.01; params.segmentLength = 8192; segmentDuration = params.segmentLength*(1/fs); params.numHops = ceil((segmentDuration-frameDuration)/hopDuration); params.numBands = 40; frameLength = round(frameDuration*fs); hopLength = round(hopDuration*fs); fftLength = 2048;
Создайте audioFeatureExtractor
объект для вычисления спектрограмм mel-частоты из входных аудиосигналов.
afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs); afe.Window = hamming(frameLength,'periodic'); afe.OverlapLength = frameLength-hopLength; afe.FFTLength = fftLength;
Установите параметры для спектрограммы мел-частоты.
setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);
Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудио данных со сдвигом основного тона. Функция помощника, getSpeechSpectrogram
(см. приложение), стандартизирует длину записи и нормализует амплитуду аудиовхода. getSpeechSpectrogram
использует audioFeatureExtractor
объект (afe) для получения логарифмических мел-частотных спектрограмм.
adsSpecTrain = transform(adsAugTrain, @(x)getSpeechSpectrogram(x,afe,params));
Использование arrayDatastore
для хранения обучающих меток.
labelsTrain = arrayDatastore(adsTrain.Labels);
Создайте комбинированный datastore, который указывает на данные спектрограммы мел-частоты и соответствующие метки.
tdsTrain = combine(adsSpecTrain,labelsTrain);
Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции helper getValidationSpeechSpectrograms
(см. приложение).
XTest = getValidationSpeechSpectrograms(adsTest,afe,params);
Получите метки валидации.
YTest = adsTest.Labels;
Создайте небольшой CNN как массив слоев. Используйте сверточные и пакетные слои нормализации и понижающее отображение карт признаков с помощью максимальных слоев объединения. Чтобы уменьшить возможность запоминания сетью специфических функций обучающих данных, добавьте небольшое количество отсева на вход к последнему полностью подключенному слою.
sz = size(XTest); specSize = sz(1:2); imageSize = [specSize 1]; numClasses = numel(categories(YTest)); dropoutProb = 0.2; numF = 12; layers = [ imageInputLayer(imageSize,'Normalization','none') convolution2dLayer(5,numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,4*numF,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2) dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer('Classes',categories(YTest)); ];
Установите гиперпараметры, которые будут использоваться при обучении сети. Используйте мини-пакет размером 50 и скоростью обучения 1e-4. Задайте оптимизацию 'adam'. Чтобы использовать параллельный пул для чтения преобразованного datastore, установите DispatchInBackground
на true
. Для получения дополнительной информации смотрите trainingOptions
(Deep Learning Toolbox).
miniBatchSize = 50; options = trainingOptions('adam', ... 'InitialLearnRate',1e-4, ... 'MaxEpochs',30, ... 'LearnRateSchedule',"piecewise",... 'LearnRateDropFactor',.1,... 'LearnRateDropPeriod',15,... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{XTest, YTest},... 'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),... 'ExecutionEnvironment','gpu',... 'DispatchInBackground', true);
Обучите сеть, передав преобразованный обучающий datastore в trainNetwork
.
trainedNet = trainNetwork(tdsTrain,layers,options);
Используйте обученную сеть, чтобы предсказать метки цифр для тестового набора.
[Ypredicted,probs] = classify(trainedNet,XTest); cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.2500
Суммируйте эффективность обученной сети на тестовом наборе с графиком неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам. Таблица в нижней части графика неточностей показывает значения точности. Таблица справа от графика неточностей показывает значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]); ccDCNN = confusionchart(YTest,Ypredicted); ccDCNN.Title = 'Confusion Chart for DCNN'; ccDCNN.ColumnSummary = 'column-normalized'; ccDCNN.RowSummary = 'row-normalized';
function Labels = helpergenLabels(ads) % This function is only for use in this example. It may be changed or % removed in a future release. files = ads.Files; tmp = cell(numel(files),1); expression = "[0-9]+_"; parfor nf = 1:numel(ads.Files) idx = regexp(files{nf},expression); tmp{nf} = files{nf}(idx); end Labels = categorical(tmp); end %------------------------------------------------------------ function X = getValidationSpeechSpectrograms(ads,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for % the files in the datastore ads using the audioFeatureExtractor afe. numFiles = length(ads.Files); X = zeros([params.numBands,params.numHops,1,numFiles],'single'); for i = 1:numFiles x = read(ads); spec = getSpeechSpectrogram(x,afe,params); X(:,:,1,i) = spec; end end %-------------------------------------------------------------------------- function X = getSpeechSpectrogram(x,afe,params) % This function is only for use in this example. It may changed or be % removed in a future release. % % getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal % x using the audioFeatureExtractor afe. X = zeros([params.numBands,params.numHops],'single'); x = normalizeAndResize(single(x),params); spec = extract(afe,x).'; % If the spectrogram is less wide than numHops, then put spectrogram in % the middle of X. w = size(spec,2); left = floor((params.numHops-w)/2)+1; ind = left:left+w-1; X(:,ind) = log10(spec + 1e-6); end %-------------------------------------------------------------------------- function x = normalizeAndResize(x,params) % This function is only for use in this example. It may change or be % removed in a future release. L = params.segmentLength; N = numel(x); if N > L x = x(1:L); elseif N < L pad = L-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end