В этом примере показано, как ускорить расчет вейвлета, рассеивающего функции с помощью gpuArray
(Parallel Computing Toolbox). У вас должны быть Parallel Computing Toolbox™ и поддерживаемый графический процессор. Смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) для деталей. Этот пример использует Титана NVIDIA, V графических процессоров с вычисляют возможность 7.0.
Этот пример воспроизводит версию ЦП рассеивающегося преобразования, найденного в Разговорном Распознавании Цифры с Рассеиванием Вейвлета и Глубоким обучением.
Клонируйте или загрузите Свободный разговорный набор данных цифры (FSDD), доступный в https://github.com/Jakobovski/free-spoken-digit-dataset
. FSDD является открытым набором данных, что означает, что это может расти в зависимости от времени. Этот пример использует версию, фиксировавшую 08/20/2020, который состоит из 3 000 записей английских цифр 0 через 9 полученных от шести докладчиков. Данные производятся на уровне 8 000 Гц.
Используйте audioDatastore
управлять доступом к данным и гарантировать случайное деление записей в наборы обучающих данных и наборы тестов. Установите location
свойство к местоположению папки записей FSDD на вашем компьютере. В этом примере данные хранятся в папке под tempdir
.
location = fullfile(tempdir,'free-spoken-digit-dataset','recordings'); ads = audioDatastore(location);
Функция помощника, helpergenLabels
, заданный в конце этого примера, создает категориальный массив меток из файлов FSDD. Перечислите классы и количество примеров в каждом классе.
ads.Labels = helpergenLabels(ads); summary(ads.Labels)
0 300 1 300 2 300 3 300 4 300 5 300 6 300 7 300 8 300 9 300
Набор данных FSDD состоит из 10 сбалансированных классов с 300 записями каждый. Записи в FSDD не имеют равной длительности. Прочитайте файлы FSDD и создайте гистограмму из длин сигнала.
LenSig = zeros(numel(ads.Files),1); nr = 1; while hasdata(ads) digit = read(ads); LenSig(nr) = numel(digit); nr = nr+1; end reset(ads) histogram(LenSig) grid on xlabel('Signal Length (Samples)') ylabel('Frequency')
Гистограмма показывает, что распределение записи длин положительно скашивается. Для классификации этот пример использует общую длину сигнала 8 192 выборок. Значение 8192, консервативный выбор, гарантирует, что усечение более длительных записей не влияет (отключает) речевое содержимое. Если сигнал больше 8 192 выборок, или 1,024 секунды, в длине, запись является усеченной к 8 192 выборкам. Если сигнал меньше 8 192 выборок в длине, сигнал симметрично предварительно ожидается и добавляется с нулями к продолжительности 8 192 выборок.
Создайте время вейвлета, рассеивая сеть с помощью инвариантной шкалы 0,22 секунд. Поскольку характеристические векторы будут созданы путем усреднения рассеивающегося преобразования по всем выборкам времени, установите OversamplingFactor
к 2. Устанавливание значения к 2 приведет к четырехкратному увеличению количества рассеивающихся коэффициентов для каждого пути относительно критически прореженного значения.
sn = waveletScattering('SignalLength',8192,'InvarianceScale',0.22,... 'SamplingFrequency',8000,'OversamplingFactor',2);
Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Обучающие данные являются для обучения классификатором на основе рассеивающегося преобразования. Тестовые данные для проверки модели.
rng default;
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
summary(adsTrain.Labels)
0 240 1 240 2 240 3 240 4 240 5 240 6 240 7 240 8 240 9 240
summary(adsTest.Labels)
0 60 1 60 2 60 3 60 4 60 5 60 6 60 7 60 8 60 9 60
Сформируйтесь 8192 2400 матрица, где каждый столбец является разговорной разрядной записью. Функция помощника helperReadSPData
обрезает или заполняет данные к длине 8192 и нормирует каждую запись на ее максимальное значение.
Xtrain = []; scatds_Train = transform(adsTrain,@(x)helperReadSPData(x)); while hasdata(scatds_Train) smat = read(scatds_Train); Xtrain = cat(2,Xtrain,smat); end
Повторите процесс для протянутого набора тестов. Получившаяся матрица 8192 600.
Xtest = []; scatds_Test = transform(adsTest,@(x)helperReadSPData(x)); while hasdata(scatds_Test) smat = read(scatds_Test); Xtest = cat(2,Xtest,smat); end
Применяйте рассеивающееся преобразование к наборам обучающих данных и наборам тестов. Использование gpuArray
с CUDA-поддерживающим NVIDIA графический процессор обеспечивает значительное ускорение. С этой сетью рассеивания, пакетным размером и графическим процессором, реализация графического процессора вычисляет рассеивающиеся функции приблизительно в 15 раз быстрее, чем версия ЦП.
Strain = sn.featureMatrix(Xtrain); Stest = sn.featureMatrix(Xtest);
Получите рассеивающиеся функции наборов обучающих данных и наборов тестов.
TrainFeatures = Strain(2:end,:,:); TrainFeatures = squeeze(mean(TrainFeatures,2))'; TestFeatures = Stest(2:end,:,:); TestFeatures = squeeze(mean(TestFeatures,2))';
Этот пример использует классификатор машины опорных векторов (SVM) с квадратичным полиномиальным ядром. Подбирайте модель SVM к рассеивающимся функциям.
template = templateSVM(... 'KernelFunction', 'polynomial', ... 'PolynomialOrder', 2, ... 'KernelScale', 'auto', ... 'BoxConstraint', 1, ... 'Standardize', true); classificationSVM = fitcecoc(... TrainFeatures, ... adsTrain.Labels, ... 'Learners', template, ... 'Coding', 'onevsone', ... 'ClassNames', categorical({'0'; '1'; '2'; '3'; '4'; '5'; '6'; '7'; '8'; '9'}));
Используйте перекрестную проверку k-сгиба, чтобы предсказать точность обобщения модели. Разделите набор обучающих данных в пять групп для перекрестной проверки.
partitionedModel = crossval(classificationSVM, 'KFold', 5); [validationPredictions, validationScores] = kfoldPredict(partitionedModel); validationAccuracy = (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = 97.4167
Предполагаемая точность обобщения составляет приблизительно 97%. Теперь используйте модель SVM, чтобы предсказать протянутый набор тестов.
predLabels = predict(classificationSVM,TestFeatures); testAccuracy = sum(predLabels==adsTest.Labels)/numel(predLabels)*100
testAccuracy = 97.1667
Точность - также приблизительно 97% на протянутом наборе тестов.
Обобщите эффективность модели на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности для каждого класса. Таблица справа от графика беспорядка показывает значения отзыва.
figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]); ccscat = confusionchart(adsTest.Labels,predLabels); ccscat.Title = 'Wavelet Scattering Classification'; ccscat.ColumnSummary = 'column-normalized'; ccscat.RowSummary = 'row-normalized';
Как итоговый пример, считайте первые две записи из набора данных, вычислите рассеивающиеся функции и предскажите разговорную цифру с помощью SVM, обученного с рассеиванием функций.
reset(ads); sig1 = helperReadSPData(read(ads)); scat1 = sn.featureMatrix(sig1); scat1 = mean(scat1(2:end,:),2)'; plab1 = predict(classificationSVM,scat1);
Считайте следующую запись и предскажите цифру.
sig2 = helperReadSPData(read(ads)); scat2 = sn.featureMatrix(sig2); scat2 = mean(scat2(2:end,:),2)'; plab2 = predict(classificationSVM,scat2);
t = 0:1/8000:(8192*1/8000)-1/8000; plot(t,[sig1 sig2]) grid on axis tight legend(char(plab1),char(plab2)) title('Spoken Digit Prediction - GPU')
Следующие функции помощника используются в этом примере.
helpergenLabels — генерирует метки на основе имен файлов в FSDD.
function Labels = helpergenLabels(ads) % This function is only for use in Wavelet Toolbox examples. It may be % changed or removed in a future release. tmp = cell(numel(ads.Files),1); expression = "[0-9]+_"; for nf = 1:numel(ads.Files) idx = regexp(ads.Files{nf},expression); tmp{nf} = ads.Files{nf}(idx); end Labels = categorical(tmp); end
helperReadSPData — Гарантирует, что каждая разговорная разрядная запись является 8 192 выборками долго.
function x = helperReadSPData(x) % This function is only for use Wavelet Toolbox examples. It may change or % be removed in a future release. N = numel(x); if N > 8192 x = x(1:8192); elseif N < 8192 pad = 8192-N; prepad = floor(pad/2); postpad = ceil(pad/2); x = [zeros(prepad,1) ; x ; zeros(postpad,1)]; end x = x./max(abs(x)); end