Классификация полов с помощью сетей GRU

В этом примере показано, как классифицировать пол диктора с помощью глубокого обучения. В примере используются сеть Gated Recurrent Модуля (GRU) и коэффициенты Гамматона Cepstral (gtcc), тангаж, гармоническое отношение и несколько спектральных дескрипторов формы.

Введение

Гендерная классификация, основанная на речевых сигналах, является важным компонентом многих аудиосистем, таких как автоматическое распознавание речи, распознавание диктора и мультимедийная индексация на основе контента.

Этот пример использует сети GRU, тип рекуррентной нейронной сети (RNN), хорошо подходящий для изучения данных последовательности и timeseries. Сеть GRU может изучать долгосрочные зависимости между временными шагами последовательности.

Этот пример обучает сеть GRU с последовательностями коэффициентов гамматона cepstrum (gtcc), оценки тангажа (pitch), коэффициент гармоники (harmonicRatio) и несколько спектральных дескрипторов формы (спектральные дескрипторы).

Чтобы ускорить процесс обучения, запустите этот пример на машине с графическим процессором. Если ваша машина имеет графический процессор и Parallel Computing Toolbox™, то MATLAB © автоматически использует для обучения графический процессор; в противном случае используется центральный процессор.

Классификация полов с помощью предварительно обученной сети

Прежде чем подробно войти в процесс обучения, вы будете использовать предварительно обученную сеть, чтобы классифицировать пол динамика в двух тестовых сигналах.

Загрузите предварительно обученную сеть.

url = 'http://ssd.mathworks.com/supportfiles/audio/GenderClassification.zip';

downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'GenderClassification');

if ~exist(netFolder,'dir')
    unzip(url,downloadNetFolder)
end

Загрузите предварительно обученную сеть вместе с предварительно вычисленными векторами, используемыми для нормализации функции.

matFileName = fullfile(netFolder, 'genderIDNet.mat');
load(matFileName,'genderIDNet','M','S');

Загрузите тестовый сигнал с мужским динамиком.

[audioIn,Fs] = audioread('maleSpeech.flac');
sound(audioIn,Fs)

Изолируйте область речи в сигнале.

boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

Создайте audioFeatureExtractor для извлечения функций из аудио данных. Этот же объект будет использоваться для извлечения функций для обучения.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките функции из сигнала и нормализуйте их.

features = extract(extractor,audioIn);
features = (features.' - M)./S;

Классифицируйте сигнал.

gender = classify(genderIDNet,features)
gender = categorical
     male 

Классифицируйте другой сигнал с женским динамиком.

[audioIn,Fs] = audioread('femaleSpeech.flac');
sound(audioIn,Fs)
boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

features = extract(extractor,audioIn);
features = (features.' - M)./S;

classify(genderIDNet,features)
ans = categorical
     female 

Предварительная обработка обучающих аудио Данных

Сеть GRU, используемая в этом примере, лучше всего работает при использовании последовательностей векторов функций. Чтобы проиллюстрировать трубопровод предварительной обработки, этот пример проходит через шаги для одного аудио файла.

Чтение содержимого аудио файла, содержащего речь. Носитель пола - мужчина.

[audioIn,Fs] = audioread('Counting-16-44p1-mono-15secs.wav');
labels = {'male'};

Постройте график аудиосигнала и затем прослушайте его с помощью sound команда.

timeVector = (1/Fs) * (0:size(audioIn,1)-1);
figure
plot(timeVector,audioIn)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audioIn,Fs)

Речевой сигнал имеет сегменты молчания, которые не содержат полезной информации, относящейся к полу динамика. Использование detectSpeech для определения местоположения сегментов речи в аудиосигнале.

speechIndices = detectSpeech(audioIn,Fs);

Создайте audioFeatureExtractor для извлечения функций из аудио данных. Речевой сигнал имеет динамический характер и изменяется с течением времени. Принято, что речевые сигналы являются стационарными в коротких временных шкалах и их обработка часто выполняется в окнах 20-40 мс. Задайте 30 мс окна с 20 мс перекрытием.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлечение функций из каждого аудиосегмента. Выходы audioFeatureExtractor является numFeatureVectors-by- numFeatures массив. The sequenceInputLayer используемый в этом примере, требует времени, чтобы быть вдоль второго измерения. Переместите массив выхода так, чтобы время было вдоль второго измерения.

featureVectorsSegment = {};
for ii = 1:size(speechIndices,1)
    featureVectorsSegment{end+1} = ( extract(extractor,audioIn(speechIndices(ii,1):speechIndices(ii,2))) )';
end
numSegments = size(featureVectorsSegment)
numSegments = 1×2

     1    11

[numFeatures,numFeatureVectorsSegment1] = size(featureVectorsSegment{1})
numFeatures = 45
numFeatureVectorsSegment1 = 124

Реплицируйте метки так, чтобы они находились в соответствии с сегментами.

labels = repelem(labels,size(speechIndices,1))
labels = 1×11 cell
    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}

При использовании sequenceInputLayerчасто выгодно использовать последовательности постоянной длины. Преобразуйте массивы векторов функций в последовательности векторов функций. Используйте 20 векторов функций на последовательность с 5 перекрытиями векторов функций.

featureVectorsPerSequence = 20;
featureVectorOverlap = 5;
hopLength = featureVectorsPerSequence - featureVectorOverlap;

idx1 = 1;
featuresTrain = {};
sequencePerSegment = zeros(numel(featureVectorsSegment),1);
for ii = 1:numel(featureVectorsSegment)
    sequencePerSegment(ii) = max(floor((size(featureVectorsSegment{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment(ii)
        featuresTrain{idx1,1} = featureVectorsSegment{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1);
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end

Для краткости, функция helper HelperFeatureVector2Sequence инкапсулирует вышеописанную обработку и используется во всем остальном примере.

Реплицируйте метки так, чтобы они находились в соответствии «один в один» с набором обучающих данных.

labels = repelem(labels,sequencePerSegment);

Результатом предварительной обработки трубопровода является NumSequence-by-1 массив ячеек NumFeatures-by- FeatureVectorsPerSequence матрицы. Метки являются NumSequenceмассив -by-1.

NumSequence = numel(featuresTrain)
NumSequence = 27
[NumFeatures,FeatureVectorsPerSequence] = size(featuresTrain{1})
NumFeatures = 45
FeatureVectorsPerSequence = 20
NumSequence = numel(labels)
NumSequence = 27

Рисунок предоставляет обзор редукции данных, используемых для каждой обнаруженной речевой области.

Создайте обучающие и тестовые хранилища данных

В этом примере используется подмножество набора данных Mozilla Common Voice [1]. Набор данных содержит записи 48 кГц субъектов, говорящих короткие предложения. Загрузите набор данных и распакуйте загруженный файл. Задайте PathToDatabase к местоположению данных.

url = 'http://ssd.mathworks.com/supportfiles/audio/commonvoice.zip';
downloadDatasetFolder = tempdir;
dataFolder = fullfile(downloadDatasetFolder,'commonvoice');

if ~exist(dataFolder,'dir')
    disp('Downloading data set (956 MB) ...')
    unzip(url,downloadDatasetFolder)
end
Downloading data set (956 MB) ...

Использование audioDatastore создание хранилищ данных для наборов обучения и валидации. Использование readtable для чтения метаданных, связанных с аудио файлов.

loc = fullfile(dataFolder);
adsTrain = audioDatastore(fullfile(loc,'train'),'IncludeSubfolders',true);
metadataTrain = readtable(fullfile(fullfile(loc,'train'),"train.tsv"),"FileType","text");
adsTrain.Labels = metadataTrain.gender;

adsValidation = audioDatastore(fullfile(loc,'validation'),'IncludeSubfolders',true);
metadataValidation = readtable(fullfile(fullfile(loc,'validation'),"validation.tsv"),"FileType","text");
adsValidation.Labels = metadataValidation.gender;

Использование countEachLabel проверить гендерную разбивку наборов обучения и валидации.

countEachLabel(adsTrain)
ans=2×2 table
    Label     Count
    ______    _____

    female    1000 
    male      1000 

countEachLabel(adsValidation)
ans=2×2 table
    Label     Count
    ______    _____

    female     200 
    male       200 

Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset на false. Чтобы запустить этот пример быстро, установите reduceDataset на true.

reduceDataset = false;
if reduceDataset
    % Reduce the training dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files) / 2 / 20));
    adsValidation = splitEachLabel(adsValidation,20);
end

Создайте наборы обучения и валидации

Определите частоту дискретизации аудио файлов в наборе данных, а затем обновите частоту дискретизации, окно и длину перекрытия извлечения аудио функции.

[~,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;
extractor.SampleRate = Fs;
extractor.Window = hamming(round(0.03*Fs),"periodic");
extractor.OverlapLength = round(0.02*Fs);

Чтобы ускорить обработку, распределите расчеты по нескольким работникам. Если у вас есть Parallel Computing Toolbox™, то пример разделяет datastore так, что редукция данных происходит параллельно между доступными работниками. Определите оптимальное количество разделов для вашей системы. Если у вас нет Parallel Computing Toolbox™, в примере используется один рабочий процесс.

if ~isempty(ver('parallel')) && ~reduceDataset
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

В цикле:

  1. Чтение из audio datastore.

  2. Обнаружение областей речи.

  3. Извлеките векторы функций из областей речи.

Реплицируйте метки так, чтобы они находились в индивидуальном соответствии с векторами функций.

labelsTrain = [];
featureVectors = {};

% Loop over optimal number of partitions
parfor ii = 1:numPar
    
    % Partition datastore
    subds = partition(adsTrain,numPar,ii);
    
    % Preallocation
    featureVectorsInSubDS = {};
    segmentsPerFile = zeros(numel(subds.Files),1);
    
    % Loop over files in partitioned datastore
    for jj = 1:numel(subds.Files)
        
        % 1. Read in a single audio file
        audioIn = read(subds);
        
        % 2. Determine the regions of the audio that correspond to speech
        speechIndices = detectSpeech(audioIn,Fs);
        
        % 3. Extract features from each speech segment
        segmentsPerFile(jj) = size(speechIndices,1);
        features = cell(segmentsPerFile(jj),1);
        for kk = 1:size(speechIndices,1)
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        
    end
    featureVectors = [featureVectors;featureVectorsInSubDS];
    
    % Replicate the labels so that they are in one-to-one correspondance
    % with the feature vectors.
    repedLabels = repelem(subds.Labels,segmentsPerFile);
    labelsTrain = [labelsTrain;repedLabels(:)];
end

В приложениях классификации рекомендуется нормализовать все функции, чтобы иметь нулевое среднее и стандартное отклонение единицы.

Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их для нормализации данных.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Буферизуйте векторы функций в последовательности из 20 векторов функций с 10 перекрытиями. Если последовательность имеет менее 20 векторов функций, перетащите ее.

[featuresTrain,trainSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Тиражируйте метки так, чтобы они находились в взаимном соответствии с последовательностями.

labelsTrain = repelem(labelsTrain,[trainSequencePerSegment{:}]);
labelsTrain = categorical(labelsTrain);

Создайте набор валидации с помощью тех же шагов, которые используются для создания набора обучающих данных.

labelsValidation = [];
featureVectors = {};
valSegmentsPerFile = [];
parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    featureVectorsInSubDS = {};
    valSegmentsPerFileInSubDS = zeros(numel(subds.Files),1);
    for jj = 1:numel(subds.Files)
        audioIn = read(subds);
        speechIndices = detectSpeech(audioIn,Fs);
        numSegments = size(speechIndices,1);
        features = cell(valSegmentsPerFileInSubDS(jj),1);
        for kk = 1:numSegments
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        valSegmentsPerFileInSubDS(jj) = numSegments;
    end
    repedLabels = repelem(subds.Labels,valSegmentsPerFileInSubDS);
    labelsValidation = [labelsValidation;repedLabels(:)];
    featureVectors = [featureVectors;featureVectorsInSubDS];
    valSegmentsPerFile = [valSegmentsPerFile;valSegmentsPerFileInSubDS];
end

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
labelsValidation = repelem(labelsValidation,[valSequencePerSegment{:}]);
labelsValidation = categorical(labelsValidation);

Определение сетевой архитектуры GRU

Сети GRU могут изучать долгосрочные зависимости между временными шагами данных последовательности. Этот пример использует gruLayer чтобы просмотреть последовательность как в прямом, так и в обратном направлениях.

Задайте размер входа, чтобы быть последовательностями размера NumFeatures. Задайте слой GRU с размером выходом 75 и выведите последовательность. Затем задайте слой GRU с размером выходом 75 и выведите последний элемент последовательности. Эта команда предписывает слою GRU сопоставить его вход с 75 функциями, а затем подготавливает выход для полностью подключенного слоя. Наконец, задайте два класса путем включения полносвязного слоя размера 2, за которым следуют слой softmax и слой классификации.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    gruLayer(75,"OutputMode","sequence")
    gruLayer(75,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Затем задайте опции обучения для классификатора. Задайте MaxEpochs на 4 так, что сеть делает 4 прохода через обучающие данные. Задайте MiniBatchSize 256 так, чтобы сеть рассматривала 128 обучающих сигналов за раз. Задайте Plots как "training-progress" чтобы сгенерировать графики, которые показывают процесс обучения с увеличениями количества итераций. Задайте Verbose на false чтобы отключить печать выхода таблицы, который соответствует данным, показанным на графике. Задайте Shuffle как "every-epoch" тасовать обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule на "piecewise" уменьшить скорость обучения на заданный коэффициент (0,1) каждый раз, когда прошло определенное количество эпох (1).

Этот пример использует решатель адаптивной оценки момента (ADAM). ADAM работает лучше с рекуррентными нейронными сетями (RNNs), такими как GRUs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решателя.

miniBatchSize = 256;
validationFrequency = floor(numel(labelsTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",1, ...
    'ValidationData',{featuresValidation,labelsValidation}, ...
    'ValidationFrequency',validationFrequency);

Обучение сети ГРУ

Обучите сеть GRU с заданными опциями обучения и архитектурой слоя с помощью trainNetwork. Поскольку набор обучающих данных большая, процесс обучения может занять несколько минут.

net = trainNetwork(featuresTrain,labelsTrain,layers,options);

Верхний подграфик графика процесса обучения представляет точность обучения, которая является точностью классификации для каждого мини-пакета. Когда обучение успешно прогрессирует, это значение обычно увеличивается до 100%. В нижней подграфике отображаются потери обучения, которые являются потерями перекрестной энтропии на каждом мини-пакете. Когда обучение успешно прогрессирует, это значение обычно уменьшается к нулю.

Если обучение не сходится, графики могут колебаться между значениями, не стремясь в определенном направлении вверх или вниз. Это колебание означает, что точность обучения не улучшается и потеря обучения не уменьшается. Такая ситуация может возникнуть в начале обучения или после некоторого предварительного повышения точности обучения. Во многих случаях изменение опций обучения может помочь сети достичь сходимости. Уменьшение MiniBatchSize или уменьшение InitialLearnRate это может привести к увеличению времени обучения, но может помочь сети лучше учиться.

Визуализация точности обучения

Вычислите точность обучения, которая представляет точность классификатора на сигналах, на которых он был обучен. Во-первых, классифицируйте обучающие данные.

prediction = classify(net,featuresTrain);

Постройте график матрицы неточностей. Отображение точности и отзыва для двух классов с помощью сводных данных по столбцам и строкам.

figure
cm = confusionchart(categorical(labelsTrain),prediction,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Визуализация точности валидации

Вычислите точность валидации. Во-первых, классифицируйте обучающие данные.

[prediction,probabilities] = classify(net,featuresValidation);

Постройте график матрицы неточностей. Отображение точности и отзыва для двух классов с помощью сводных данных по столбцам и строкам.

figure
cm = confusionchart(categorical(labelsValidation),prediction,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Пример сгенерировал несколько последовательностей из каждого обучающего речевого файла. Более высокая точность может быть достигнута путем принятия выходного класса всех последовательностей, соответствующих одному и тому же файлу, и применения решения «max-rule», где выбран класс с сегментом с наивысшей оценкой достоверности.

Определите количество последовательностей, сгенерированных в каждом файле в наборе валидации.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Спрогнозируйте пол из каждого обучающего файла, принимая во внимание выходные классы всех последовательностей, сгенерированных из одного и того же файла.

numFiles = numel(adsValidation.Files);
actualGender = categorical(adsValidation.Labels);
predictedGender = actualGender;      
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index} = probabilities(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2); 
    end
    counter = counter + sequencePerFile(index);
end

Визуализируйте матрицу неточностей на предсказаниях с правилами большинства.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] Общий набор голосовых наборов данных Mozilla

Вспомогательные функции

function [sequences,sequencePerSegment] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
if featureVectorsPerSequence <= featureVectorOverlap
    error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
end

hopLength = featureVectorsPerSequence - featureVectorOverlap;
idx1 = 1;
sequences = {};
sequencePerSegment = cell(numel(features),1);
for ii = 1:numel(features)
    sequencePerSegment{ii} = max(floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment{ii}
        sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end
end