В этом примере показано, как ускорить расчет вейвлета, рассеивающего функции с помощью gpuArray
и параллелизированная версия "в глубину" времени вейвлета, рассеиваясь. У вас должен быть CUDA-поддерживающий NVIDIA, графический процессор с вычисляет возможность 3.0 или выше. Смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) для деталей. Этот пример использует XP Титана NVIDIA, графический процессор с вычисляет возможность 6.1.
Этот пример воспроизводит версию ЦП рассеивающегося преобразования, найденного в Разговорном Распознавании Цифры с Рассеиванием Вейвлета и Глубоким обучением.
Клонируйте или загрузите Свободный разговорный набор данных цифры (FSDD), доступный в https://github.com/Jakobovski/free-spoken-digit-dataset
. FSDD является открытым набором данных, что означает, что это может расти в зависимости от времени. Этот пример использует версию, фиксировавшую 01/29/2019, который состоит из 2 000 записей английских цифр 0 через 9 полученных от четырех докладчиков. Два из докладчиков в этой версии являются носителями американского варианта английского языка, и два докладчика являются ненативными докладчиками английского языка с французским Бельгии и немецким диакритическим знаком соответственно. Данные производятся на уровне 8 000 Гц.
Используйте audioDatastore
управлять доступом к данным и гарантировать случайное деление записей в наборы обучающих данных и наборы тестов. Установите location
свойство к местоположению папки записей FSDD на вашем компьютере.
location = fullfile(tempdir,'free-spoken-digit-dataset','recordings'); ads = audioDatastore(location);
Функция помощника, helpergenLabels
, заданный в конце этого примера, создает категориальный массив меток из файлов FSDD. Перечислите классы и количество примеров в каждом классе.
ads.Labels = helpergenLabels(ads); summary(ads.Labels)
0 200 1 200 2 200 3 200 4 200 5 200 6 200 7 200 8 200 9 200
Набор данных FSDD состоит из 10 сбалансированных классов с 200 записями каждый. Записи в FSDD не имеют равной длительности. Прочитайте файлы FSDD и создайте гистограмму из длин сигнала.
LenSig = zeros(numel(ads.Files),1); nr = 1; while hasdata(ads) digit = read(ads); LenSig(nr) = numel(digit); nr = nr+1; end reset(ads) histogram(LenSig) grid on xlabel('Signal Length (Samples)') ylabel('Frequency')
Гистограмма показывает, что распределение записи длин положительно скашивается. Для классификации этот пример использует общую длину сигнала 8 192 выборок. Значение 8192, консервативный выбор, гарантирует, что усечение более длительных записей не влияет (отключает) речевое содержимое. Если сигнал больше 8 192 выборок, или 1,024 секунды, в длине, запись является усеченной к 8 192 выборкам. Если сигнал меньше 8 192 выборок в длине, сигнал симметрично предварительно ожидается и добавляется с нулями к продолжительности 8 192 выборок.
Создайте время вейвлета, рассеивая среду с помощью инвариантной шкалы 0,22 секунд. Поскольку характеристические векторы будут созданы путем усреднения рассеивающегося преобразования по всем выборкам времени, установите OversamplingFactor
к 2. Устанавливание значения к 2 приведет к четырехкратному увеличению количества рассеивающихся коэффициентов для каждого пути относительно критически прореженного значения.
sf = helpergpuscat1('SignalLength',8192,'InvarianceScale',0.22,... 'SamplingFrequency',8000,'OversamplingFactor',2);
Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Обучающие данные являются для обучения классификатором на основе рассеивающегося преобразования. Тестовые данные для проверки модели.
rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
Label Count
_____ _____
0 160
1 160
2 160
3 160
4 160
5 160
6 160
7 160
8 160
9 160
countEachLabel(adsTest)
ans=10×2 table
Label Count
_____ _____
0 40
1 40
2 40
3 40
4 40
5 40
6 40
7 40
8 40
9 40
Сформируйтесь 8192 1600 матрица, где каждый столбец является разговорной разрядной записью. Функция помощника, helperReadSPData,
обрезает или заполняет данные к длине 8192 и нормирует каждую запись ее максимальным значением.
Xtrain = []; scatds_Train = transform(adsTrain,@(x)helperReadSPData(x)); while hasdata(scatds_Train) smat = read(scatds_Train); Xtrain = cat(2,Xtrain,smat); end
Повторите процесс для протянутого набора тестов. Получившаяся матрица 8192 400.
Xtest = []; scatds_Test = transform(adsTest,@(x)helperReadSPData(x)); while hasdata(scatds_Test) smat = read(scatds_Test); Xtest = cat(2,Xtest,smat); end
Применяйте рассеивающееся преобразование к наборам обучающих данных и наборам тестов. Использование gpuArray
с CUDA-поддерживающим NVIDIA графический процессор обеспечивает значительное ускорение. С этой средой рассеивания, пакетным размером и графическим процессором, реализация графического процессора вычисляет рассеивающиеся функции приблизительно в 8 раз быстрее, чем версия ЦП.
Strain = sf.featureMatrix(Xtrain); Stest = sf.featureMatrix(Xtest);
Получите рассеивающиеся функции наборов обучающих данных и наборов тестов.
TrainFeatures = Strain(2:end,:,:); TrainFeatures = squeeze(mean(TrainFeatures,2))'; TestFeatures = Stest(2:end,:,:); TestFeatures = squeeze(mean(TestFeatures,2))';
Этот пример использует классификатор машины опорных векторов (SVM) с квадратичным полиномиальным ядром. Подбирайте модель SVM к рассеивающимся функциям.
template = templateSVM(... 'KernelFunction', 'polynomial', ... 'PolynomialOrder', 2, ... 'KernelScale', 'auto', ... 'BoxConstraint', 1, ... 'Standardize', true); classificationSVM = fitcecoc(... TrainFeatures, ... adsTrain.Labels, ... 'Learners', template, ... 'Coding', 'onevsone', ... 'ClassNames', categorical({'0'; '1'; '2'; '3'; '4'; '5'; '6'; '7'; '8'; '9'}));
Используйте перекрестную проверку k-сгиба, чтобы предсказать точность обобщения модели. Разделите набор обучающих данных в пять групп для перекрестной проверки.
partitionedModel = crossval(classificationSVM, 'KFold', 5); [validationPredictions, validationScores] = kfoldPredict(partitionedModel); validationAccuracy = (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = 96.8125
Предполагаемая точность обобщения составляет приблизительно 97%. Теперь используйте модель SVM, чтобы предсказать протянутый набор тестов.
predLabels = predict(classificationSVM,TestFeatures); testAccuracy = sum(predLabels==adsTest.Labels)/numel(predLabels)*100
testAccuracy = 97.7500
Тестовая точность составляет приблизительно 98% на протянутом наборе тестов.
Обобщите производительность модели на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности для каждого класса. Таблица справа от графика беспорядка показывает значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]); ccscat = confusionchart(adsTest.Labels,predLabels); ccscat.Title = 'Wavelet Scattering Classification'; ccscat.ColumnSummary = 'column-normalized'; ccscat.RowSummary = 'row-normalized';
Как итоговый пример, считайте первые две записи из набора данных, вычислите рассеивающиеся функции и предскажите разговорную цифру с помощью SVM, обученного с рассеиванием функций.
reset(ads); sig1 = helperReadSPData(read(ads)); scat1 = sf.featureMatrix(sig1); scat1 = mean(scat1(2:end,:),2)'; plab1 = predict(classificationSVM,scat1);
Считайте следующую запись и предскажите цифру.
sig2 = helperReadSPData(read(ads)); scat2 = sf.featureMatrix(sig2); scat2 = mean(scat2(2:end,:),2)'; plab2 = predict(classificationSVM,scat2);
t = 0:1/8000:(8192*1/8000)-1/8000; plot(t,[sig1 sig2]) grid on axis tight legend(char(plab1),char(plab2)) title('Spoken Digit Prediction - GPU')
Следующие функции помощника используются в этом примере.
helpergenLabels — генерирует метки на основе имен файлов в FSDD.
function Labels = helpergenLabels(ads) % This function is only for use in Wavelet Toolbox examples. It may be % changed or removed in a future release. tmp = cell(numel(ads.Files),1); expression = "[0-9]+_"; for nf = 1:numel(ads.Files) idx = regexp(ads.Files{nf},expression); tmp{nf} = ads.Files{nf}(idx); end Labels = categorical(tmp); end
helperReadSPData — Гарантирует, что каждая разговорная разрядная запись является 8 192 выборками долго.
function x = helperReadSPData(x) % This function is only for use Wavelet Toolbox examples. It may change or % be removed in a future release. N = numel(x); if N > 8192 x = x(1:8192); elseif N < 8192 pad = 8192-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end