exponenta event banner

Классификация пола с использованием сетей ГРУ

В этом примере показано, как классифицировать пол говорящего с помощью глубокого обучения. В этом примере используется сеть Gated Receivative Unit (GRU) и кэпстральные коэффициенты гамматона (gtcc), тон, отношение гармоник и несколько дескрипторов спектральной формы.

Введение

Гендерная классификация, основанная на речевых сигналах, является важным компонентом многих аудиосистем, таких как автоматическое распознавание речи, распознавание говорящих и индексирование мультимедийных данных на основе содержания.

В этом примере используются сети GRU, тип рекуррентной нейронной сети (RNN), хорошо подходящий для изучения последовательности и данных временных рядов. Сеть GRU может изучать долгосрочные зависимости между временными шагами последовательности.

Этот пример обучает сеть GRU последовательностями коэффициентов кепстра гамматона (gtcc), оценки тангажа (pitch), гармоническое отношение (harmonicRatio) и несколько дескрипторов спектральной формы (спектральные дескрипторы).

Чтобы ускорить процесс обучения, выполните этот пример на машине с графическим процессором. Если ваш компьютер имеет графический процессор и Toolbox™ параллельных вычислений, MATLAB © автоматически использует графический процессор для обучения; в противном случае он использует ЦП.

Классифицировать пол с помощью предварительно обученной сети

Перед подробным началом процесса обучения необходимо использовать предварительно обученную сеть для классификации пола говорящего по двум тестовым сигналам.

Загрузите предварительно обученную сеть.

url = 'http://ssd.mathworks.com/supportfiles/audio/GenderClassification.zip';

downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'GenderClassification');

if ~exist(netFolder,'dir')
    unzip(url,downloadNetFolder)
end

Загрузите предварительно обученную сеть вместе с предварительно вычисленными векторами, используемыми для нормализации функций.

matFileName = fullfile(netFolder, 'genderIDNet.mat');
load(matFileName,'genderIDNet','M','S');

Загрузите тестовый сигнал с охватываемым громкоговорителем.

[audioIn,Fs] = audioread('maleSpeech.flac');
sound(audioIn,Fs)

Изолируйте область речи в сигнале.

boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

Создание audioFeatureExtractor для извлечения функций из аудиоданных. Этот же объект используется для извлечения элементов для обучения.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките элементы из сигнала и нормализуйте их.

features = extract(extractor,audioIn);
features = (features.' - M)./S;

Классифицируйте сигнал.

gender = classify(genderIDNet,features)
gender = categorical
     male 

Классифицируйте другой сигнал с говорящей женщиной.

[audioIn,Fs] = audioread('femaleSpeech.flac');
sound(audioIn,Fs)
boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

features = extract(extractor,audioIn);
features = (features.' - M)./S;

classify(genderIDNet,features)
ans = categorical
     female 

Предварительная обработка обучающих аудиоданных

Сеть GRU, используемая в этом примере, лучше всего работает при использовании последовательностей векторов признаков. Чтобы проиллюстрировать конвейер предварительной обработки, в этом примере рассматриваются шаги для одного аудиофайла.

Чтение содержимого аудиофайла, содержащего речь. Пол оратора - мужчина.

[audioIn,Fs] = audioread('Counting-16-44p1-mono-15secs.wav');
labels = {'male'};

Постройте график звукового сигнала, а затем прослушайте его с помощью sound команда.

timeVector = (1/Fs) * (0:size(audioIn,1)-1);
figure
plot(timeVector,audioIn)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audioIn,Fs)

Речевой сигнал имеет сегменты молчания, которые не содержат полезной информации, относящейся к полу говорящего. Использовать detectSpeech для определения местоположения сегментов речи в звуковом сигнале.

speechIndices = detectSpeech(audioIn,Fs);

Создание audioFeatureExtractor для извлечения функций из аудиоданных. Речевой сигнал является динамическим по своей природе и изменяется с течением времени. Предполагается, что речевые сигналы неподвижны на коротких временных шкалах и их обработка часто выполняется в окнах 20-40 мс. Укажите окна 30 мс с перекрытием 20 мс.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките функции из каждого аудиосегмента. Выходные данные audioFeatureExtractor является numFeatureVectorsоколо-numFeatures массив. sequenceInputLayer Для использования в этом примере требуется время вдоль второго размера. Переставьте выходной массив так, чтобы время находилось вдоль второго измерения.

featureVectorsSegment = {};
for ii = 1:size(speechIndices,1)
    featureVectorsSegment{end+1} = ( extract(extractor,audioIn(speechIndices(ii,1):speechIndices(ii,2))) )';
end
numSegments = size(featureVectorsSegment)
numSegments = 1×2

     1    11

[numFeatures,numFeatureVectorsSegment1] = size(featureVectorsSegment{1})
numFeatures = 45
numFeatureVectorsSegment1 = 124

Реплицируйте метки таким образом, чтобы они соответствовали сегментам один к одному.

labels = repelem(labels,size(speechIndices,1))
labels = 1×11 cell
    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}

При использовании sequenceInputLayer, часто выгодно использовать последовательности согласованной длины. Преобразуйте массивы векторов элементов в последовательности векторов элементов. Используйте 20 векторов элементов на последовательность с 5 перекрытиями векторов элементов.

featureVectorsPerSequence = 20;
featureVectorOverlap = 5;
hopLength = featureVectorsPerSequence - featureVectorOverlap;

idx1 = 1;
featuresTrain = {};
sequencePerSegment = zeros(numel(featureVectorsSegment),1);
for ii = 1:numel(featureVectorsSegment)
    sequencePerSegment(ii) = max(floor((size(featureVectorsSegment{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment(ii)
        featuresTrain{idx1,1} = featureVectorsSegment{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1);
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end

Для краткости вспомогательная функция HelperFeatureVector2Sequence инкапсулирует вышеуказанную обработку и используется в остальном примере.

Реплицируйте метки таким образом, чтобы они соответствовали друг другу с обучающим набором.

labels = repelem(labels,sequencePerSegment);

Результатом предварительной обработки трубопровода является NumSequence-by-1 массив ячеек NumFeaturesоколо-FeatureVectorsPerSequence матрицы. Метки являются NumSequenceмассив -by-1.

NumSequence = numel(featuresTrain)
NumSequence = 27
[NumFeatures,FeatureVectorsPerSequence] = size(featuresTrain{1})
NumFeatures = 45
FeatureVectorsPerSequence = 20
NumSequence = numel(labels)
NumSequence = 27

На рисунке представлен обзор извлечения признаков, используемых для обнаруженной речевой области.

Создание хранилищ данных для обучения и тестирования

В этом примере используется подмножество набора данных Mozilla Common Voice [1]. Набор данных содержит записи 48 кГц субъектов, говорящих короткие предложения. Загрузите набор данных и отмените обработку загруженного файла. Набор PathToDatabase в расположение данных.

url = 'http://ssd.mathworks.com/supportfiles/audio/commonvoice.zip';
downloadDatasetFolder = tempdir;
dataFolder = fullfile(downloadDatasetFolder,'commonvoice');

if ~exist(dataFolder,'dir')
    disp('Downloading data set (956 MB) ...')
    unzip(url,downloadDatasetFolder)
end
Downloading data set (956 MB) ...

Использовать audioDatastore для создания хранилищ данных для наборов обучения и проверки. Использовать readtable для чтения метаданных, связанных с аудиофайлами.

loc = fullfile(dataFolder);
adsTrain = audioDatastore(fullfile(loc,'train'),'IncludeSubfolders',true);
metadataTrain = readtable(fullfile(fullfile(loc,'train'),"train.tsv"),"FileType","text");
adsTrain.Labels = metadataTrain.gender;

adsValidation = audioDatastore(fullfile(loc,'validation'),'IncludeSubfolders',true);
metadataValidation = readtable(fullfile(fullfile(loc,'validation'),"validation.tsv"),"FileType","text");
adsValidation.Labels = metadataValidation.gender;

Использовать countEachLabel проверка гендерной разбивки учебных и валидационных наборов.

countEachLabel(adsTrain)
ans=2×2 table
    Label     Count
    ______    _____

    female    1000 
    male      1000 

countEachLabel(adsValidation)
ans=2×2 table
    Label     Count
    ______    _____

    female     200 
    male       200 

Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset кому false. Для быстрого выполнения этого примера установите reduceDataset кому true.

reduceDataset = false;
if reduceDataset
    % Reduce the training dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files) / 2 / 20));
    adsValidation = splitEachLabel(adsValidation,20);
end

Создание наборов обучения и проверки

Определите частоту дискретизации аудиофайлов в наборе данных, а затем обновите частоту дискретизации, окно и длину перекрытия устройства извлечения аудиофайлов.

[~,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;
extractor.SampleRate = Fs;
extractor.Window = hamming(round(0.03*Fs),"periodic");
extractor.OverlapLength = round(0.02*Fs);

Для ускорения обработки распределите вычисления между несколькими работниками. При наличии Toolbox™ Parallel Computing в примере выполняется разделение хранилища данных таким образом, что извлечение элементов происходит параллельно между доступными работниками. Определите оптимальное количество разделов для системы. Если у вас нет Toolbox™ параллельных вычислений, в примере используется один работник.

if ~isempty(ver('parallel')) && ~reduceDataset
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

В цикле:

  1. Считывание из хранилища аудиоданных.

  2. Обнаружение областей речи.

  3. Извлеките векторы признаков из областей речи.

Реплицируйте метки таким образом, чтобы они соответствовали векторам признаков один к одному.

labelsTrain = [];
featureVectors = {};

% Loop over optimal number of partitions
parfor ii = 1:numPar
    
    % Partition datastore
    subds = partition(adsTrain,numPar,ii);
    
    % Preallocation
    featureVectorsInSubDS = {};
    segmentsPerFile = zeros(numel(subds.Files),1);
    
    % Loop over files in partitioned datastore
    for jj = 1:numel(subds.Files)
        
        % 1. Read in a single audio file
        audioIn = read(subds);
        
        % 2. Determine the regions of the audio that correspond to speech
        speechIndices = detectSpeech(audioIn,Fs);
        
        % 3. Extract features from each speech segment
        segmentsPerFile(jj) = size(speechIndices,1);
        features = cell(segmentsPerFile(jj),1);
        for kk = 1:size(speechIndices,1)
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        
    end
    featureVectors = [featureVectors;featureVectorsInSubDS];
    
    % Replicate the labels so that they are in one-to-one correspondance
    % with the feature vectors.
    repedLabels = repelem(subds.Labels,segmentsPerFile);
    labelsTrain = [labelsTrain;repedLabels(:)];
end

В классификационных приложениях рекомендуется нормализовать все элементы так, чтобы они имели нулевое среднее и единичное стандартное отклонение.

Вычислите среднее значение и стандартное отклонение для каждого коэффициента и используйте их для нормализации данных.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Буферизация векторов признаков в последовательности из 20 векторов признаков с 10 перекрытиями. Если последовательность имеет менее 20 векторов признаков, удалите ее.

[featuresTrain,trainSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Реплицируйте метки таким образом, чтобы они соответствовали последовательностям один к одному.

labelsTrain = repelem(labelsTrain,[trainSequencePerSegment{:}]);
labelsTrain = categorical(labelsTrain);

Создайте набор проверки, используя те же шаги, что и при создании набора обучения.

labelsValidation = [];
featureVectors = {};
valSegmentsPerFile = [];
parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    featureVectorsInSubDS = {};
    valSegmentsPerFileInSubDS = zeros(numel(subds.Files),1);
    for jj = 1:numel(subds.Files)
        audioIn = read(subds);
        speechIndices = detectSpeech(audioIn,Fs);
        numSegments = size(speechIndices,1);
        features = cell(valSegmentsPerFileInSubDS(jj),1);
        for kk = 1:numSegments
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        valSegmentsPerFileInSubDS(jj) = numSegments;
    end
    repedLabels = repelem(subds.Labels,valSegmentsPerFileInSubDS);
    labelsValidation = [labelsValidation;repedLabels(:)];
    featureVectors = [featureVectors;featureVectorsInSubDS];
    valSegmentsPerFile = [valSegmentsPerFile;valSegmentsPerFileInSubDS];
end

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
labelsValidation = repelem(labelsValidation,[valSequencePerSegment{:}]);
labelsValidation = categorical(labelsValidation);

Определение сетевой архитектуры GRU

Сети ГРУ могут изучать долгосрочные зависимости между временными шагами данных последовательности. В этом примере используется gruLayer чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.

Укажите входной размер, который будет последовательностью размера NumFeatures. Укажите уровень GRU с размером вывода 75 и выведите последовательность. Затем укажите уровень GRU с размером вывода 75 и выведите последний элемент последовательности. Эта команда предписывает уровню GRU отобразить свой вход в 75 элементов, а затем подготавливает выход для полностью подключенного уровня. Наконец, укажите два класса, включив полностью связанный слой размера 2, за которым следуют слой softmax и слой классификации.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    gruLayer(75,"OutputMode","sequence")
    gruLayer(75,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Затем укажите параметры обучения для классификатора. Набор MaxEpochs кому 4 так, что сеть 4 проходит через обучающие данные. Набор MiniBatchSize 256, так что сеть просматривает 128 обучающих сигналов одновременно. Определить Plots как "training-progress" для создания графиков, показывающих ход обучения по мере увеличения числа итераций. Набор Verbose кому false для отключения печати выходных данных таблицы, соответствующих данным, отображаемым на графике. Определить Shuffle как "every-epoch" перетасовать тренировочную последовательность в начале каждой эпохи. Определить LearnRateSchedule кому "piecewise" снижать скорость обучения на заданный коэффициент (0,1) каждый раз, когда проходит определенное количество эпох (1).

В этом примере используется решатель ADAM. ADAM работает лучше с рецидивирующими нейронными сетями (RNN), такими как GRU, чем стохастический градиентный спуск по умолчанию с решателем импульса (SGDM).

miniBatchSize = 256;
validationFrequency = floor(numel(labelsTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",1, ...
    'ValidationData',{featuresValidation,labelsValidation}, ...
    'ValidationFrequency',validationFrequency);

Обучение сети ГРУ

Обучение сети ГРУ указанным вариантам обучения и архитектуре уровней с помощью trainNetwork. Поскольку тренировочный набор большой, тренировочный процесс может занять несколько минут.

net = trainNetwork(featuresTrain,labelsTrain,layers,options);

Верхний подграф графика training-progress представляет точность обучения, которая является точностью классификации для каждой мини-партии. Когда обучение проходит успешно, это значение обычно увеличивается до 100%. Нижняя часть графика отображает потери при обучении, которые представляют собой потери при перекрестной энтропии для каждой мини-партии. Когда обучение проходит успешно, это значение обычно уменьшается до нуля.

Если тренировка не сходится, графики могут колебаться между значениями без тренда в определенном восходящем или нисходящем направлении. Это колебание означает, что точность тренировки не улучшается и потери тренировки не уменьшаются. Такая ситуация может возникнуть в начале обучения или после некоторого предварительного улучшения точности обучения. Во многих случаях изменение вариантов обучения может помочь сети достичь сходимости. Уменьшение MiniBatchSize или уменьшение InitialLearnRate может привести к увеличению времени обучения, но это может помочь сети лучше учиться.

Визуализация точности обучения

Вычислите точность обучения, которая представляет точность классификатора по сигналам, по которым он обучался. Во-первых, классифицируйте данные обучения.

prediction = classify(net,featuresTrain);

Постройте график матрицы путаницы. Отображение точности и отзыва для двух классов с помощью сводки столбцов и строк.

figure
cm = confusionchart(categorical(labelsTrain),prediction,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Визуализация точности проверки

Рассчитайте точность проверки. Во-первых, классифицируйте данные обучения.

[prediction,probabilities] = classify(net,featuresValidation);

Постройте график матрицы путаницы. Отображение точности и отзыва для двух классов с помощью сводки столбцов и строк.

figure
cm = confusionchart(categorical(labelsValidation),prediction,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Пример создает несколько последовательностей из каждого файла обучающей речи. Более высокая точность может быть достигнута путем рассмотрения выходного класса всех последовательностей, соответствующих одному и тому же файлу, и применения решения «max-rule», где выбирается класс с сегментом с наивысшим показателем достоверности.

Определите количество последовательностей, созданных для каждого файла в наборе проверки.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Прогнозирование пола из каждого обучающего файла с учетом выходных классов всех последовательностей, сгенерированных из одного и того же файла.

numFiles = numel(adsValidation.Files);
actualGender = categorical(adsValidation.Labels);
predictedGender = actualGender;      
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index} = probabilities(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2); 
    end
    counter = counter + sequencePerFile(index);
end

Визуализация матрицы путаницы в прогнозах правила большинства.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] Набор общих голосовых данных Mozilla

Вспомогательные функции

function [sequences,sequencePerSegment] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
if featureVectorsPerSequence <= featureVectorOverlap
    error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
end

hopLength = featureVectorsPerSequence - featureVectorOverlap;
idx1 = 1;
sequences = {};
sequencePerSegment = cell(numel(features),1);
for ii = 1:numel(features)
    sequencePerSegment{ii} = max(floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment{ii}
        sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end
end