Обучите сеть распознавания разговорных цифр с помощью Audio Данных

Этот пример обучает сеть распознавания разговорных цифр на аудио данных вне памяти с помощью преобразованного datastore. В этом примере вы применяете случайный тангаж сдвиг к аудио данных, используемым для обучения сверточной нейронной сети (CNN). Для каждой итерации обучения аудио данных дополняется с помощью объекта audioDataAugmenter, а затем функции извлекаются с помощью объекта audioFeatureExtractor. Рабочий процесс в этом примере применяется к любому случайному увеличению данных, используемому в цикле обучения. Рабочий процесс также применяется, когда базовый набор аудио данных или функции обучения не помещаются в памяти.

Данные

Загрузите бесплатный набор данных (FSDD). FSDD состоит из 2000 записей четырех ораторов, говорящих цифры от 0 до 9 на английском языке.

url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'FSDD');

if ~exist(datasetFolder,'dir')
    disp('Downloading FSDD...')
    unzip(url,downloadFolder)
end

Создайте audioDatastore это указывает на набор данных.

pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);

Функция помощника, helperGenerateLabelsсоздает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels приведено в приложении. Отображение классов и количества примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 8).
summary(ads.Labels)
     0      200 
     1      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Разделите FSDD на обучающие и тестовые наборы. Выделите 80% данных наборов обучающих данных и сохраните 20% для тестового набора. Вы используете набор обучающих данных для обучения модели и тестовый набор для проверки обученной модели.

rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Сокращение набора обучающих наборов данных

Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset на ложь. Чтобы запустить этот пример быстро, установите reduceDataset к true.

reduceDataset = "false";
if reduceDataset =  ="true"
    adsTrain = splitEachLabel (adsTrain, 2);
    adsTest = splitEachLabel (adsTest, 2);
end

Преобразованный обучающий Datastore

Увеличение количества данных

Увеличение обучающих данных путем применения сдвига тангажа с помощью audioDataAugmenter объект.

Создайте audioDataAugmenter. Усилитель применяет смещение тангажа к входу аудиосигналу с вероятностью 0,5. Augmenter выбирает случайное значение перемены тангажа в области значений [-12 12] полутонов.

augmenter = audioDataAugmenter('PitchShiftProbability',.5, ...
                               'SemitoneShiftRange',[-12 12], ...
                               'TimeShiftProbability',0, ...
                               'VolumeControlProbability',0, ...
                               'AddNoiseProbability',0, ...
                               'TimeShiftProbability',0);

Установите пользовательские параметры сдвига тангажа. Используйте тождества фазы блокировки и сохранения формантов, используя оценку спектральной огибающей с кепстральным анализом 30-го порядка.

setAugmenterParams(augmenter,'shiftPitch','LockPhase',true,'PreserveFormants',true,'CepstralOrder',30);

Создайте преобразованный datastore, который применяет увеличение данных к обучающим данным.

fs = 8000;
adsAugTrain = transform(adsTrain,@(y)deal(augment(augmenter,y,fs).Audio{1}));

Редукция данных Mel Spectrogram

CNN принимает мел-частотные спектрограммы.

Задайте параметры, используемые для извлечения мел-частотных спектрограмм. Используйте 220 мс окна с 10 мс хмеля между окнами. Используйте ДПФ с 2048 точками и 40 полосы частот.

frameDuration = 0.22;
hopDuration = 0.01;
params.segmentLength = 8192;
segmentDuration = params.segmentLength*(1/fs);
params.numHops = ceil((segmentDuration-frameDuration)/hopDuration);
params.numBands = 40;
frameLength = round(frameDuration*fs);
hopLength = round(hopDuration*fs);
fftLength = 2048;

Создайте audioFeatureExtractor объект для вычисления спектрограмм mel-частоты из входных аудиосигналов.

afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs);
afe.Window = hamming(frameLength,'periodic');
afe.OverlapLength = frameLength-hopLength;
afe.FFTLength = fftLength;

Установите параметры для спектрограммы мел-частоты.

setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);

Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудио данных со сдвигом основного тона. Функция помощника, getSpeechSpectrogram (см. приложение), стандартизирует длину записи и нормализует амплитуду аудиовхода. getSpeechSpectrogram использует audioFeatureExtractor объект (afe) для получения логарифмических мел-частотных спектрограмм.

adsSpecTrain = transform(adsAugTrain, @(x)getSpeechSpectrogram(x,afe,params));

Обучающие метки

Использование arrayDatastore для хранения обучающих меток.

labelsTrain = arrayDatastore(adsTrain.Labels);

Комбинированное обучение Datastore

Создайте комбинированный datastore, который указывает на данные спектрограммы мел-частоты и соответствующие метки.

tdsTrain = combine(adsSpecTrain,labelsTrain);

Данные валидации

Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции helper getValidationSpeechSpectrograms (см. приложение).

XTest = getValidationSpeechSpectrograms(adsTest,afe,params);

Получите метки валидации.

YTest = adsTest.Labels;

Определение архитектуры CNN

Создайте небольшой CNN как массив слоев. Используйте сверточные и пакетные слои нормализации и понижающее отображение карт признаков с помощью максимальных слоев объединения. Чтобы уменьшить возможность запоминания сетью специфических функций обучающих данных, добавьте небольшое количество отсева на вход к последнему полностью подключенному слою.

sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTest));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize,'Normalization','none')

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTest));
    ];

Установите гиперпараметры, которые будут использоваться при обучении сети. Используйте мини-пакет размером 50 и скоростью обучения 1e-4. Задайте оптимизацию 'adam'. Чтобы использовать параллельный пул для чтения преобразованного datastore, установите DispatchInBackground на true. Для получения дополнительной информации смотрите trainingOptions (Deep Learning Toolbox).

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'LearnRateSchedule',"piecewise",...
    'LearnRateDropFactor',.1,...
    'LearnRateDropPeriod',15,...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XTest, YTest},...
    'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),...
    'ExecutionEnvironment','gpu',...
    'DispatchInBackground', true);

Обучите сеть, передав преобразованный обучающий datastore в trainNetwork.

trainedNet = trainNetwork(tdsTrain,layers,options);

Используйте обученную сеть, чтобы предсказать метки цифр для тестового набора.

[Ypredicted,probs] = classify(trainedNet,XTest);
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.2500

Суммируйте эффективность обученной сети на тестовом наборе с графиком неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам. Таблица в нижней части графика неточностей показывает значения точности. Таблица справа от графика неточностей показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

Приложение: Функции помощника

function Labels = helpergenLabels(ads)
% This function is only for use in this example. It may be changed or
% removed in a future release.
files = ads.Files;
tmp = cell(numel(files),1);
expression = "[0-9]+_";
parfor nf = 1:numel(ads.Files)
    idx = regexp(files{nf},expression);
    tmp{nf} = files{nf}(idx);
end
Labels = categorical(tmp);
end

%------------------------------------------------------------
function X = getValidationSpeechSpectrograms(ads,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for
% the files in the datastore ads using the audioFeatureExtractor afe.

numFiles = length(ads.Files);
X = zeros([params.numBands,params.numHops,1,numFiles],'single');

for i = 1:numFiles
    x = read(ads);    
    spec = getSpeechSpectrogram(x,afe,params);    
    X(:,:,1,i) = spec;
    
end

end

%--------------------------------------------------------------------------
function X = getSpeechSpectrogram(x,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal
% x using the audioFeatureExtractor afe.

X = zeros([params.numBands,params.numHops],'single');

x = normalizeAndResize(single(x),params);

spec = extract(afe,x).';

% If the spectrogram is less wide than numHops, then put spectrogram in
% the middle of X.
w = size(spec,2);
left = floor((params.numHops-w)/2)+1;
ind = left:left+w-1;
X(:,ind) = log10(spec + 1e-6);

end
%--------------------------------------------------------------------------
function x = normalizeAndResize(x,params)
% This function is only for use in this example. It may change or be
% removed in a future release.

L = params.segmentLength;
N = numel(x);
if N > L
    x = x(1:L);
elseif N < L
    pad = L-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end