Обучите разговорную сеть распознавания цифры Используя аудиоданные из памяти

Этот пример обучает разговорную сеть распознавания цифры на аудиоданных из памяти с помощью преобразованного datastore. В этом примере вы применяетесь, случайный сдвиг тангажа на аудиоданные раньше обучал сверточную нейронную сеть (CNN). Для каждой учебной итерации аудиоданные увеличиваются с помощью объекта audioDataAugmenter и затем показывают, извлечены с помощью объекта audioFeatureExtractor. Рабочий процесс в этом примере применяется к любому случайному увеличению данных, используемому в учебном цикле. Рабочий процесс также применяется, когда базовый набор аудиоданных или учебные функции не умещаются в памяти.

Данные

Загрузите Свободный Разговорный Набор данных Цифры (FSDD). FSDD состоит из 2 000 записей четырех динамиков, говорящих числа 0 до 9 на английском языке.

url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'FSDD');

if ~exist(datasetFolder,'dir')
    disp('Downloading FSDD...')
    unzip(url,downloadFolder)
end

Создайте audioDatastore это указывает на набор данных.

pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);

Функция помощника, helperGenerateLabels, создает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels перечислен в приложении. Отобразите классы и количество примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 8).
summary(ads.Labels)
     0      200 
     1      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Вы используете набор обучающих данных, чтобы обучить модель и набор тестов подтверждать обученную модель.

rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Уменьшайте обучающий набор данных

Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset ко лжи. Чтобы запустить этот пример быстро, установите reduceDataset к истине.

reduceDataset = "false";
if reduceDataset == "true"
    adsTrain = splitEachLabel (adsTrain, 2);
    adsTest = splitEachLabel (adsTest, 2);
end

Преобразованный учебный Datastore

Увеличение данных

Увеличьте обучающие данные путем применения перемены тангажа с audioDataAugmenter объект.

Создайте audioDataAugmenter. Увеличение применяет перемену тангажа на входной звуковой сигнал с 0,5 вероятностями. Увеличение выбирает случайное значение сдвига тангажа в области значений [-12 12] полутоны.

augmenter = audioDataAugmenter('PitchShiftProbability',.5, ...
                               'SemitoneShiftRange',[-12 12], ...
                               'TimeShiftProbability',0, ...
                               'VolumeControlProbability',0, ...
                               'AddNoiseProbability',0, ...
                               'TimeShiftProbability',0);

Установите пользовательские переключающие тангаж параметры. Используйте единичную блокировку фазы и сохраните форманты с помощью спектральной оценки конверта с 30-м кепстральным анализом порядка.

setAugmenterParams(augmenter,'shiftPitch','LockPhase',true,'PreserveFormants',true,'CepstralOrder',30);

Создайте преобразованный datastore, который применяет увеличение данных к обучающим данным.

fs = 8000;
adsAugTrain = transform(adsTrain,@(y)deal(augment(augmenter,y,fs).Audio{1}));

Извлечение признаков спектрограммы Мэла

CNN принимает спектрограммы mel-частоты.

Задайте параметры, используемые, чтобы извлечь спектрограммы mel-частоты. Используйте 220 MS Windows с транзитными участками на 10 мс между окнами. Используйте ДПФ с 2048 точками и 40 диапазонов частот.

frameDuration = 0.22;
hopDuration = 0.01;
params.segmentLength = 8192;
segmentDuration = params.segmentLength*(1/fs);
params.numHops = ceil((segmentDuration-frameDuration)/hopDuration);
params.numBands = 40;
frameLength = round(frameDuration*fs);
hopLength = round(hopDuration*fs);
fftLength = 2048;

Создайте audioFeatureExtractor объект вычислить спектрограммы mel-частоты из входных звуковых сигналов.

afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs);
afe.Window = hamming(frameLength,'periodic');
afe.OverlapLength = frameLength-hopLength;
afe.FFTLength = fftLength;

Установите параметры для спектрограммы mel-частоты.

setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);

Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из переключенных тангажом аудиоданных. Функция помощника, getSpeechSpectrogram (см. приложение), стандартизирует продолжительность записи и нормирует амплитуду аудиовхода. getSpeechSpectrogram использует audioFeatureExtractor объект (afe), чтобы получить основанные на журнале спектрограммы mel-частоты.

adsSpecTrain = transform(adsAugTrain, @(x)getSpeechSpectrogram(x,afe,params));

Учебные метки

Используйте arrayDatastore содержать учебные метки.

labelsTrain = arrayDatastore(adsTrain.Labels);

Объединенный учебный Datastore

Создайте объединенный datastore, который указывает на данные о спектрограмме mel-частоты и соответствующие метки.

tdsTrain = combine(adsSpecTrain,labelsTrain);

Данные о валидации

Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции помощника getValidationSpeechSpectrograms (см. приложение).

XTest = getValidationSpeechSpectrograms(adsTest,afe,params);

Получите метки валидации.

YTest = adsTest.Labels;

Задайте архитектуру CNN

Создайте маленький CNN как массив слоев. Используйте сверточный и слои нормализации партии. и проредите карты функции с помощью макс. слоев объединения. Чтобы уменьшать возможность сети, запоминая определенные функции обучающих данных, добавьте небольшое количество уволенного к входу к последнему полносвязному слою.

sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTest));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize,'Normalization','none')

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTest));
    ];

Установите гиперпараметры использовать в обучении сети. Используйте мини-пакетный размер 50 и скорость обучения 1e-4. Задайте 'adam' оптимизацию. Чтобы использовать параллельный пул, чтобы считать преобразованный datastore, устанавливает DispatchInBackground к true. Для получения дополнительной информации смотрите trainingOptions (Deep Learning Toolbox).

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'LearnRateSchedule',"piecewise",...
    'LearnRateDropFactor',.1,...
    'LearnRateDropPeriod',15,...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XTest, YTest},...
    'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),...
    'ExecutionEnvironment','gpu',...
    'DispatchInBackground', true);

Обучите сеть путем передачи преобразованного учебного datastore trainNetwork.

trainedNet = trainNetwork(tdsTrain,layers,options);

Используйте обучивший сеть, чтобы предсказать метки цифры для набора тестов.

[Ypredicted,probs] = classify(trainedNet,XTest);
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.2500

Обобщите эффективность обучившего сеть на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности. Таблица справа от графика беспорядка показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

Приложение: Функции помощника

function Labels = helpergenLabels(ads)
% This function is only for use in this example. It may be changed or
% removed in a future release.
files = ads.Files;
tmp = cell(numel(files),1);
expression = "[0-9]+_";
parfor nf = 1:numel(ads.Files)
    idx = regexp(files{nf},expression);
    tmp{nf} = files{nf}(idx);
end
Labels = categorical(tmp);
end

%------------------------------------------------------------
function X = getValidationSpeechSpectrograms(ads,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for
% the files in the datastore ads using the audioFeatureExtractor afe.

numFiles = length(ads.Files);
X = zeros([params.numBands,params.numHops,1,numFiles],'single');

for i = 1:numFiles
    x = read(ads);    
    spec = getSpeechSpectrogram(x,afe,params);    
    X(:,:,1,i) = spec;
    
end

end

%--------------------------------------------------------------------------
function X = getSpeechSpectrogram(x,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal
% x using the audioFeatureExtractor afe.

X = zeros([params.numBands,params.numHops],'single');

x = normalizeAndResize(single(x),params);

spec = extract(afe,x).';

% If the spectrogram is less wide than numHops, then put spectrogram in
% the middle of X.
w = size(spec,2);
left = floor((params.numHops-w)/2)+1;
ind = left:left+w-1;
X(:,ind) = log10(spec + 1e-6);

end
%--------------------------------------------------------------------------
function x = normalizeAndResize(x,params)
% This function is only for use in this example. It may change or be
% removed in a future release.

L = params.segmentLength;
N = numel(x);
if N > L
    x = x(1:L);
elseif N < L
    pad = L-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end