Время вейвлета, рассеиваясь с ускорением графического процессора — разговорное распознавание цифры

В этом примере показано, как ускорить расчет вейвлета, рассеивающего функции с помощью gpuArray (Parallel Computing Toolbox). У вас должны быть Parallel Computing Toolbox™ и поддерживаемый графический процессор. Смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) для деталей. Этот пример использует Титана NVIDIA, V графических процессоров с вычисляют возможность 7.0.

Этот пример воспроизводит версию ЦП рассеивающегося преобразования, найденного в Разговорном Распознавании Цифры с Рассеиванием Вейвлета и Глубоким обучением.

Данные

Клонируйте или загрузите Свободный разговорный набор данных цифры (FSDD), доступный в https://github.com/Jakobovski/free-spoken-digit-dataset. FSDD является открытым набором данных, что означает, что это может расти в зависимости от времени. Этот пример использует версию, фиксировавшую 08/20/2020, который состоит из 3 000 записей английских цифр 0 через 9 полученных от шести докладчиков. Данные производятся на уровне 8 000 Гц.

Используйте audioDatastore управлять доступом к данным и гарантировать случайное деление записей в наборы обучающих данных и наборы тестов. Установите location свойство к местоположению папки записей FSDD на вашем компьютере. В этом примере данные хранятся в папке под tempdir.

location = fullfile(tempdir,'free-spoken-digit-dataset','recordings');
ads = audioDatastore(location);

Функция помощника, helpergenLabels, заданный в конце этого примера, создает категориальный массив меток из файлов FSDD. Перечислите классы и количество примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
summary(ads.Labels)
     0      300 
     1      300 
     2      300 
     3      300 
     4      300 
     5      300 
     6      300 
     7      300 
     8      300 
     9      300 

Набор данных FSDD состоит из 10 сбалансированных классов с 300 записями каждый. Записи в FSDD не имеют равной длительности. Прочитайте файлы FSDD и создайте гистограмму из длин сигнала.

LenSig = zeros(numel(ads.Files),1);
nr = 1;
while hasdata(ads)
    digit = read(ads);
    LenSig(nr) = numel(digit);
    nr = nr+1;
end
reset(ads)
histogram(LenSig)
grid on
xlabel('Signal Length (Samples)')
ylabel('Frequency')

Гистограмма показывает, что распределение записи длин положительно скашивается. Для классификации этот пример использует общую длину сигнала 8 192 выборок. Значение 8192, консервативный выбор, гарантирует, что усечение более длительных записей не влияет (отключает) речевое содержимое. Если сигнал больше 8 192 выборок, или 1,024 секунды, в длине, запись является усеченной к 8 192 выборкам. Если сигнал меньше 8 192 выборок в длине, сигнал симметрично предварительно ожидается и добавляется с нулями к продолжительности 8 192 выборок.

Время вейвлета, рассеиваясь

Создайте время вейвлета, рассеивая сеть с помощью инвариантной шкалы 0,22 секунд. Поскольку характеристические векторы будут созданы путем усреднения рассеивающегося преобразования по всем выборкам времени, установите OversamplingFactor к 2. Устанавливание значения к 2 приведет к четырехкратному увеличению количества рассеивающихся коэффициентов для каждого пути относительно критически прореженного значения.

sn = waveletScattering('SignalLength',8192,'InvarianceScale',0.22,...
    'SamplingFrequency',8000,'OversamplingFactor',2);

Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Обучающие данные являются для обучения классификатором на основе рассеивающегося преобразования. Тестовые данные для проверки модели.

rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
summary(adsTrain.Labels)
     0      240 
     1      240 
     2      240 
     3      240 
     4      240 
     5      240 
     6      240 
     7      240 
     8      240 
     9      240 
summary(adsTest.Labels)
     0      60 
     1      60 
     2      60 
     3      60 
     4      60 
     5      60 
     6      60 
     7      60 
     8      60 
     9      60 

Сформируйтесь 8192 2400 матрица, где каждый столбец является разговорной разрядной записью. Функция помощника helperReadSPData обрезает или заполняет данные к длине 8192 и нормирует каждую запись на ее максимальное значение.

Xtrain = [];
scatds_Train = transform(adsTrain,@(x)helperReadSPData(x));
while hasdata(scatds_Train)
    smat = read(scatds_Train);
    Xtrain = cat(2,Xtrain,smat);
    
end

Повторите процесс для протянутого набора тестов. Получившаяся матрица 8192 600.

Xtest = [];
scatds_Test = transform(adsTest,@(x)helperReadSPData(x));
while hasdata(scatds_Test)
    smat = read(scatds_Test);
    Xtest = cat(2,Xtest,smat);
    
end

Применяйте рассеивающееся преобразование к наборам обучающих данных и наборам тестов. Использование gpuArray с CUDA-поддерживающим NVIDIA графический процессор обеспечивает значительное ускорение. С этой сетью рассеивания, пакетным размером и графическим процессором, реализация графического процессора вычисляет рассеивающиеся функции приблизительно в 15 раз быстрее, чем версия ЦП.

Strain = sn.featureMatrix(Xtrain);
Stest = sn.featureMatrix(Xtest);

Получите рассеивающиеся функции наборов обучающих данных и наборов тестов.

TrainFeatures = Strain(2:end,:,:);
TrainFeatures = squeeze(mean(TrainFeatures,2))';
TestFeatures = Stest(2:end,:,:);
TestFeatures = squeeze(mean(TestFeatures,2))';

Этот пример использует классификатор машины опорных векторов (SVM) с квадратичным полиномиальным ядром. Подбирайте модель SVM к рассеивающимся функциям.

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 2, ...
    'KernelScale', 'auto', ...
    'BoxConstraint', 1, ...
    'Standardize', true);
classificationSVM = fitcecoc(...
    TrainFeatures, ...
    adsTrain.Labels, ...
    'Learners', template, ...
    'Coding', 'onevsone', ...
    'ClassNames', categorical({'0'; '1'; '2'; '3'; '4'; '5'; '6'; '7'; '8'; '9'}));

Используйте перекрестную проверку k-сгиба, чтобы предсказать точность обобщения модели. Разделите набор обучающих данных в пять групп для перекрестной проверки.

partitionedModel = crossval(classificationSVM, 'KFold', 5);
[validationPredictions, validationScores] = kfoldPredict(partitionedModel);
validationAccuracy = (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = 97.4167

Предполагаемая точность обобщения составляет приблизительно 97%. Теперь используйте модель SVM, чтобы предсказать протянутый набор тестов.

predLabels = predict(classificationSVM,TestFeatures);
testAccuracy = sum(predLabels==adsTest.Labels)/numel(predLabels)*100
testAccuracy = 97.1667

Точность - также приблизительно 97% на протянутом наборе тестов.

Обобщите эффективность модели на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности для каждого класса. Таблица справа от графика беспорядка показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
ccscat = confusionchart(adsTest.Labels,predLabels);
ccscat.Title = 'Wavelet Scattering Classification';
ccscat.ColumnSummary = 'column-normalized';
ccscat.RowSummary = 'row-normalized';

Как итоговый пример, считайте первые две записи из набора данных, вычислите рассеивающиеся функции и предскажите разговорную цифру с помощью SVM, обученного с рассеиванием функций.

reset(ads);
sig1 = helperReadSPData(read(ads));
scat1 = sn.featureMatrix(sig1);
scat1 = mean(scat1(2:end,:),2)';
plab1 = predict(classificationSVM,scat1);

Считайте следующую запись и предскажите цифру.

sig2 = helperReadSPData(read(ads));
scat2 = sn.featureMatrix(sig2);
scat2 = mean(scat2(2:end,:),2)';
plab2 = predict(classificationSVM,scat2);
t = 0:1/8000:(8192*1/8000)-1/8000;
plot(t,[sig1 sig2])
grid on
axis tight
legend(char(plab1),char(plab2))
title('Spoken Digit Prediction - GPU')

Приложение

Следующие функции помощника используются в этом примере.

helpergenLabels — генерирует метки на основе имен файлов в FSDD.

function Labels = helpergenLabels(ads)
% This function is only for use in Wavelet Toolbox examples. It may be
% changed or removed in a future release.
tmp = cell(numel(ads.Files),1);
expression = "[0-9]+_";
for nf = 1:numel(ads.Files)
    idx = regexp(ads.Files{nf},expression);
    tmp{nf} = ads.Files{nf}(idx);
end
Labels = categorical(tmp);

end

helperReadSPData — Гарантирует, что каждая разговорная разрядная запись является 8 192 выборками долго.

function x = helperReadSPData(x)
% This function is only for use Wavelet Toolbox examples. It may change or
% be removed in a future release.
N = numel(x);
if N > 8192
    x = x(1:8192);
elseif N < 8192
    pad = 8192-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));

end

Смотрите также

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте