Разговорное распознавание цифры с рассеиванием вейвлета и глубоким обучением

В этом примере показано, как классифицировать разговорные цифры с помощью рассеивания времени вейвлета, соединенного с машиной опорных векторов и глубокой сверточной сетью на основе спектрограмм mel-частоты.

Данные

Клонируйте или загрузите Свободный разговорный набор данных цифры (FSDD), доступный по https://github.com/Jakobovski/free-spoken-digit-dataset. FSDD является открытым набором данных, что означает, что это может расти в зависимости от времени. Этот пример использует версию, фиксировавшую 01/29/2019, который состоит из 2 000 записей английских цифр 0 через 9 полученных от четырех докладчиков. Два из докладчиков в этой версии являются носителями американского варианта английского языка, и два докладчика являются ненативными докладчиками английского языка с французским Бельгии и немецким диакритическим знаком. Данные производятся на уровне 8 000 Гц.

Используйте audioDatastore управлять доступом к данным и гарантировать случайное деление записей в наборы обучающих данных и наборы тестов. Установите location свойство к местоположению папки записей FSDD на вашем компьютере, например:

pathToRecordingsFolder = '/home/user/free-spoken-digit-dataset/recordings';
location = pathToRecordingsFolder;

Укажите audioDatastore к тому местоположению.

ads = audioDatastore(location);

Функция помощника helpergenLabels, заданный в конце этого примера, создает категориальный массив меток из файлов FSDD. Перечислите классы и количество примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
summary(ads.Labels)
     1      200 
     0      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Набор данных FSDD состоит из 10 сбалансированных классов с 200 записями каждый. Записи в FSDD не имеют равной длительности. Поскольку FSDD не является предельно большим, прочитывает файлы FSDD и создает гистограмму из длин сигнала.

LenSig = zeros(numel(ads.Files),1);
nr = 1;
while hasdata(ads)
    digit = read(ads);
    LenSig(nr) = numel(digit);
    nr = nr+1;
end
reset(ads)
histogram(LenSig)
grid on
xlabel('Signal Length (Samples)')
ylabel('Frequency')

Гистограмма показывает, что распределение записи длин положительно скашивается. Для классификации этот пример использует общую длину сигнала 8 192 выборок. Значение 8 192 было выбрано, чтобы консервативно гарантировать, что усечение более длительных записей не влияло (отключает) речевое содержимое. Если сигнал больше 8 192 выборок, или 1,024 секунды, в длине, мы обрезаем запись до 8 192 выборок. Если сигнал меньше 8 192 выборок в длине, мы симметрично предварительно заполняем и постзаполняем сигнал нулями к продолжительности 8 192 выборок.

Время вейвлета, рассеиваясь

Используйте waveletScattering создать время вейвлета, рассеивая среду с помощью инвариантной шкалы 0,22 секунд. Поскольку мы создадим характеристические векторы путем усреднения рассеивающегося преобразования по всем выборкам времени, установите OversamplingFactor к 2, приводя к четырехкратному увеличению количества рассеивающихся коэффициентов для каждого пути относительно критически прореженного значения.

sf = waveletScattering('SignalLength',8192,'InvarianceScale',0.22,...
    'SamplingFrequency',8000,'OversamplingFactor',2);

Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Обучающие данные являются для обучения классификатором на основе рассеивающегося преобразования. Тестовые данные для проверки модели.

rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Сформируйтесь 8192 1600 матрица, где каждый столбец является разговорной разрядной записью. Функция помощника helperReadSPData обрезает или заполняет данные к длине 8192 и нормирует каждую запись ее максимальным значением.

Xtrain = [];
scatds_Train = transform(adsTrain,@(x)helperReadSPData(x));
while hasdata(scatds_Train)
    smat = read(scatds_Train);
    Xtrain = cat(2,Xtrain,smat);
    
end

Повторите процесс для протянутого набора тестов. Получившаяся матрица 8192 400.

Xtest = [];
scatds_Test = transform(adsTest,@(x)helperReadSPData(x));
while hasdata(scatds_Test)
    smat = read(scatds_Test);
    Xtest = cat(2,Xtest,smat);
    
end

Применяйте преобразование рассеивания вейвлета к наборам обучающих данных и наборам тестов.

Strain = sf.featureMatrix(Xtrain);
Stest = sf.featureMatrix(Xtest);

Получите средние функции рассеивания наборов обучающих данных и наборов тестов. Исключите 0-th коэффициенты рассеивания порядка.

TrainFeatures = Strain(2:end,:,:);
TrainFeatures = squeeze(mean(TrainFeatures,2))';
TestFeatures = Stest(2:end,:,:);
TestFeatures = squeeze(mean(TestFeatures,2))';

Этот пример использует классификатор машины опорных векторов (SVM) с рассеивающимися функциями. SVM использует квадратичное полиномиальное ядро.

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 2, ...
    'KernelScale', 'auto', ...
    'BoxConstraint', 1, ...
    'Standardize', true);
classificationSVM = fitcecoc(...
    TrainFeatures, ...
    adsTrain.Labels, ...
    'Learners', template, ...
    'Coding', 'onevsone', ...
    'ClassNames', categorical({'1'; '0'; '2'; '3'; '4'; '5'; '6'; '7'; '8'; '9'}));

Используйте перекрестную проверку k-сгиба, чтобы предсказать точность обобщения основанного на модели на обучающих данных. Разделите набор обучающих данных в пять групп.

partitionedModel = crossval(classificationSVM, 'KFold', 5);
[validationPredictions, validationScores] = kfoldPredict(partitionedModel);
validationAccuracy = (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = 96.8750

Предполагаемая точность обобщения составляет приблизительно 97%. Используйте обученный SVM, чтобы предсказать разговорные разрядные классы в протянутом наборе тестов.

predLabels = predict(classificationSVM,TestFeatures);
testAccuracy = sum(predLabels==adsTest.Labels)/numel(predLabels)*100
testAccuracy = 98

Обобщите производительность модели на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности для каждого класса. Таблица справа от графика беспорядка показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
ccscat = confusionchart(adsTest.Labels,predLabels);
ccscat.Title = 'Wavelet Scattering Classification';
ccscat.ColumnSummary = 'column-normalized';
ccscat.RowSummary = 'row-normalized';

Рассеивающееся преобразование вместе с классификатором SVM классифицирует разговорные цифры на набор тестов с процентом точности 98 или коэффициентом ошибок 2%.

Глубоко сверточная сеть Используя спектрограммы Mel-частоты

Как другой подход к задаче разговорного распознавания цифры, используйте глубокую сверточную нейронную сеть (DCNN) на основе спектрограмм mel-частоты, чтобы классифицировать набор данных FSDD. Используйте ту же процедуру усечения/дополнения сигнала в качестве в рассеивающемся преобразовании. Точно так же нормируйте каждую запись путем деления каждой выборки сигнала на максимальное абсолютное значение. Для непротиворечивости используйте те же наборы обучающих данных и наборы тестов что касается рассеивающегося преобразования.

Установите параметры для спектрограмм mel-частоты. Используйте то же окно или систему координат, длительность как в рассеивающемся преобразовании, 0,22 секунды. Установите транзитный участок между окнами к 10 мс. Используйте 40 диапазонов частот.

segmentDuration = 8192*(1/8000);
frameDuration = 0.22;
hopDuration = 0.01;
numBands = 40;

Сбросьте обучение и протестируйте хранилища данных.

reset(adsTrain);
reset(adsTest);

Функция помощника helperspeechSpectrograms, заданный в конце этого примера, melSpectrogram использования получить спектрограмму mel-частоты после стандартизации продолжительности записи и нормализации амплитуды. Используйте логарифм спектрограмм mel-частоты как входные параметры к DCNN. Чтобы постараться не взять логарифм нуля, добавьте маленький эпсилон в каждый элемент.

epsil = 1e-6;
XTrain = helperspeechSpectrograms(adsTrain,segmentDuration,frameDuration,hopDuration,numBands);
Computing speech spectrograms...
Processed 500 files out of 1600
Processed 1000 files out of 1600
Processed 1500 files out of 1600
...done
XTrain = log10(XTrain + epsil);

XTest = helperspeechSpectrograms(adsTest,segmentDuration,frameDuration,hopDuration,numBands);
Computing speech spectrograms...
...done
XTest = log10(XTest + epsil);

YTrain = adsTrain.Labels;
YTest = adsTest.Labels;

Задайте архитектуру DCNN

Создайте маленький DCNN как массив слоев. Используйте сверточные и пакетные слои нормализации и проредите карты функции с помощью макс. слоев объединения. Чтобы уменьшать возможность сети, запоминая определенные функции обучающих данных, добавьте небольшое количество уволенного к входу к последнему полносвязному слою.

sz = size(XTrain);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTrain));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize)

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTrain));
    ];

Установите гиперпараметры использовать в обучении сети. Используйте мини-пакетный размер 50 и темп обучения 1e-4. Задайте оптимизацию Адама. Поскольку объем данных в этом примере относительно мал, установите среду выполнения на 'cpu' для воспроизводимости. Можно также обучить сеть на доступном графическом процессоре путем установки среды выполнения на любой 'gpu' или 'auto'. Для получения дополнительной информации смотрите trainingOptions.

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ExecutionEnvironment','cpu');

Обучите сеть.

trainedNet = trainNetwork(XTrain,YTrain,layers,options);

Используйте обучивший сеть, чтобы предсказать метки цифры для набора тестов.

[Ypredicted,probs] = classify(trainedNet,XTest,'ExecutionEnvironment','CPU');
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 97.7500

Обобщите производительность обучившего сеть на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности. Таблица справа от графика беспорядка показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

DCNN использование спектрограмм mel-частоты как входные параметры классифицирует разговорные цифры на набор тестов со степенью точности приблизительно 98% также.

Сводные данные

В этом примере показано, как использовать два разных подхода в классификации разговорных цифр в FSDD. Цель примера состоит в том, чтобы просто продемонстрировать, как использовать инструменты MathWorks™, чтобы приблизиться к проблеме двумя существенно различными, но дополнительными способами. Оба рабочих процесса используют audioDatastore управлять потоком данных из диска и гарантировать соответствующую рандомизацию.

Один подход изучения использует рассеивание времени вейвлета, соединенное с классификатором машины опорных векторов. Другое изучение приближается к спектрограммам mel-частоты использования как к входным параметрам к DCNN. Оба подхода выполняют хорошо на наборе тестов. Этот пример не предназначается как прямое сравнение между двумя подходами. И с методами, можно попробовать различные гиперпараметры и с архитектуру, которая может значительно влиять на результаты. В случае подхода спектрограммы mel-частоты можно экспериментировать с различной параметризацией спектрограммы mel-частоты, а также изменений в слоях DCNN, включая добавляющие слои. Дополнительная стратегия, которая полезна в глубоком обучении для небольших наборов обучающих данных как эта версия FSDD, состоит в том, чтобы использовать увеличение данных. То, как манипуляции влияют на класс, не всегда известно, таким образом, увеличение данных не всегда выполнимо. Однако для речи, установленные стратегии увеличения данных доступны через audioDataAugmenter.

В случае времени вейвлета, рассеиваясь, существует также много модификаций, которые можно попробовать. Например, можно изменить инвариантную шкалу преобразования, варьироваться количество фильтров вейвлета на набор фильтров и попробовать различные классификаторы.

Приложение: Функции помощника

function Labels = helpergenLabels(ads)
% This function is only for use in the
% "Spoken Digit Recognition with Wavelet Scattering and Deep Learning"
% example. It may change or be removed in a future release.

Labels = categorical(numel(ads.Files),1);
expression = "[0-9]+_";
for nf = 1:numel(ads.Files)
    idx = regexp(ads.Files{nf},expression);
    Labels(nf) = categorical(str2double(ads.Files{nf}(idx)));
end

end

function x = helperReadSPData(x)
% This function is only for use Wavelet Toolbox examples. It may change or
% be removed in a future release.

N = numel(x);
if N > 8192
    x = x(1:8192);
elseif N < 8192
    pad = 8192-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));

end

function X = helperspeechSpectrograms(ads,segmentDuration,frameDuration,hopDuration,numBands)
% This function is only for use in the 
% "Spoken Digit Recognition with Wavelet Scattering and Deep Learning"
% example. It may change or be removed in a future release.
%
% helperspeechSpectrograms(ads,segmentDuration,frameDuration,hopDuration,numBands)
% computes speech spectrograms for the files in the datastore ads.
% segmentDuration is the total duration of the speech clips (in seconds),
% frameDuration the duration of each spectrogram frame, hopDuration the
% time shift between each spectrogram frame, and numBands the number of
% frequency bands.
disp("Computing speech spectrograms...");

numHops = ceil((segmentDuration - frameDuration)/hopDuration);
numFiles = length(ads.Files);
X = zeros([numBands,numHops,1,numFiles],'single');

for i = 1:numFiles
    
    [x,info] = read(ads);
    x = normalizeAndResize(x);
    fs = info.SampleRate;
    frameLength = round(frameDuration*fs);
    hopLength = round(hopDuration*fs);
    
    spec = melSpectrogram(x,fs, ...
        'WindowLength',frameLength, ...
        'OverlapLength',frameLength - hopLength, ...
        'FFTLength',2048, ...
        'NumBands',numBands, ...
        'FrequencyRange',[50,4000]);
    
    % If the spectrogram is less wide than numHops, then put spectrogram in
    % the middle of X.
    w = size(spec,2);
    left = floor((numHops-w)/2)+1;
    ind = left:left+w-1;
    X(:,ind,1,i) = spec;
    
    if mod(i,500) == 0
        disp("Processed " + i + " files out of " + numFiles)
    end
    
end

disp("...done");

end

%--------------------------------------------------------------------------
function x = normalizeAndResize(x)
% This function is only for use in the 
% "Spoken Digit Recognition with Wavelet Scattering and Deep Learning"
% example. It may change or be removed in a future release.

N = numel(x);
if N > 8192
    x = x(1:8192);
elseif N < 8192
    pad = 8192-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end

Copyright 2018, The MathWorks, Inc.