exponenta event banner

Классификация пола с использованием сетей ГРУ

В этом примере показано, как классифицировать пол говорящего с помощью глубокого обучения. В этом примере используется сеть Gated Receivative Unit (GRU) и кэпстральные коэффициенты гамматона (gtcc), тон, отношение гармоник и несколько дескрипторов спектральной формы.

Введение

Гендерная классификация, основанная на речевых сигналах, является важным компонентом многих аудиосистем, таких как автоматическое распознавание речи, распознавание говорящих и индексирование мультимедийных данных на основе содержания.

В этом примере используются сети GRU, тип рекуррентной нейронной сети (RNN), хорошо подходящий для изучения последовательности и данных временных рядов. Сеть GRU может изучать долгосрочные зависимости между временными шагами последовательности.

Этот пример обучает сеть GRU последовательностями коэффициентов кепстра гамматона (gtcc (Audio Toolbox)), оценки шага (pitch (Audio Toolbox)), отношение гармоник (harmonicRatio (Audio Toolbox)) и несколько дескрипторов спектральной формы (Spectral Descriptors (Audio Toolbox)).

Чтобы ускорить процесс обучения, выполните этот пример на машине с графическим процессором. Если ваш компьютер имеет графический процессор и Toolbox™ параллельных вычислений, MATLAB © автоматически использует графический процессор для обучения; в противном случае он использует ЦП.

Классифицировать пол с помощью предварительно обученной сети

Перед подробным началом процесса обучения необходимо использовать предварительно обученную сеть для классификации пола говорящего по двум тестовым сигналам.

Загрузите предварительно обученную сеть.

url = 'http://ssd.mathworks.com/supportfiles/audio/GenderClassification.zip';

downloadNetFolder = tempdir;
netFolder = fullfile(downloadNetFolder,'GenderClassification');

if ~exist(netFolder,'dir')
    unzip(url,downloadNetFolder)
end

Загрузите предварительно обученную сеть вместе с предварительно вычисленными векторами, используемыми для нормализации функций.

matFileName = fullfile(netFolder, 'genderIDNet.mat');
load(matFileName,'genderIDNet','M','S');

Загрузите тестовый сигнал с охватываемым громкоговорителем.

[audioIn,Fs] = audioread('maleSpeech.flac');
sound(audioIn,Fs)

Изолируйте область речи в сигнале.

boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

Создание audioFeatureExtractor (Audio Toolbox) для извлечения элементов из аудиоданных. Этот же объект используется для извлечения элементов для обучения.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките элементы из сигнала и нормализуйте их.

features = extract(extractor,audioIn);
features = (features.' - M)./S;

Классифицируйте сигнал.

gender = classify(genderIDNet,features)
gender = categorical
     male 

Классифицируйте другой сигнал с говорящей женщиной.

[audioIn,Fs] = audioread('femaleSpeech.flac');
sound(audioIn,Fs)
boundaries = detectSpeech(audioIn,Fs);
audioIn = audioIn(boundaries(1):boundaries(2));

features = extract(extractor,audioIn);
features = (features.' - M)./S;

classify(genderIDNet,features)
ans = categorical
     female 

Предварительная обработка обучающих аудиоданных

Сеть GRU, используемая в этом примере, лучше всего работает при использовании последовательностей векторов признаков. Чтобы проиллюстрировать конвейер предварительной обработки, в этом примере рассматриваются шаги для одного аудиофайла.

Чтение содержимого аудиофайла, содержащего речь. Пол оратора - мужчина.

[audioIn,Fs] = audioread('Counting-16-44p1-mono-15secs.wav');
labels = {'male'};

Постройте график звукового сигнала, а затем прослушайте его с помощью sound команда.

timeVector = (1/Fs) * (0:size(audioIn,1)-1);
figure
plot(timeVector,audioIn)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audioIn,Fs)

Речевой сигнал имеет сегменты молчания, которые не содержат полезной информации, относящейся к полу говорящего. Использовать detectSpeech (Audio Toolbox) для поиска сегментов речи в звуковом сигнале.

speechIndices = detectSpeech(audioIn,Fs);

Создание audioFeatureExtractor (Audio Toolbox) для извлечения элементов из аудиоданных. Речевой сигнал является динамическим по своей природе и изменяется с течением времени. Предполагается, что речевые сигналы неподвижны на коротких временных шкалах и их обработка часто выполняется в окнах 20-40 мс. Укажите окна 30 мс с перекрытием 20 мс.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките функции из каждого аудиосегмента. Выходные данные audioFeatureExtractor является numFeatureVectorsоколо-numFeatures массив. sequenceInputLayer Для использования в этом примере требуется время вдоль второго размера. Переставьте выходной массив так, чтобы время находилось вдоль второго измерения.

featureVectorsSegment = {};
for ii = 1:size(speechIndices,1)
    featureVectorsSegment{end+1} = ( extract(extractor,audioIn(speechIndices(ii,1):speechIndices(ii,2))) )';
end
numSegments = size(featureVectorsSegment)
numSegments = 1×2

     1    11

[numFeatures,numFeatureVectorsSegment1] = size(featureVectorsSegment{1})
numFeatures = 45
numFeatureVectorsSegment1 = 124

Реплицируйте метки таким образом, чтобы они соответствовали сегментам один к одному.

labels = repelem(labels,size(speechIndices,1))
labels = 1×11 cell
    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}

При использовании sequenceInputLayer, часто выгодно использовать последовательности согласованной длины. Преобразуйте массивы векторов элементов в последовательности векторов элементов. Используйте 20 векторов элементов на последовательность с 5 перекрытиями векторов элементов.

featureVectorsPerSequence = 20;
featureVectorOverlap = 5;
hopLength = featureVectorsPerSequence - featureVectorOverlap;

idx1 = 1;
featuresTrain = {};
sequencePerSegment = zeros(numel(featureVectorsSegment),1);
for ii = 1:numel(featureVectorsSegment)
    sequencePerSegment(ii) = max(floor((size(featureVectorsSegment{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment(ii)
        featuresTrain{idx1,1} = featureVectorsSegment{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1);
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end

Для краткости вспомогательная функция HelperFeatureVector2Sequence инкапсулирует вышеуказанную обработку и используется в остальном примере.

Реплицируйте метки таким образом, чтобы они соответствовали друг другу с обучающим набором.

labels = repelem(labels,sequencePerSegment);

Результатом предварительной обработки трубопровода является NumSequence-by-1 массив ячеек NumFeaturesоколо-FeatureVectorsPerSequence матрицы. Метки являются NumSequenceмассив -by-1.

NumSequence = numel(featuresTrain)
NumSequence = 27
[NumFeatures,FeatureVectorsPerSequence] = size(featuresTrain{1})
NumFeatures = 45
FeatureVectorsPerSequence = 20
NumSequence = numel(labels)
NumSequence = 27

На рисунке представлен обзор извлечения признаков, используемых для обнаруженной речевой области.

Создание хранилищ данных для обучения и тестирования

В этом примере используется подмножество набора данных Mozilla Common Voice [1]. Набор данных содержит записи 48 кГц субъектов, говорящих короткие предложения. Загрузите набор данных и отмените обработку загруженного файла. Набор PathToDatabase в расположение данных.

url = 'http://ssd.mathworks.com/supportfiles/audio/commonvoice.zip';
downloadDatasetFolder = tempdir;
dataFolder = fullfile(downloadDatasetFolder,'commonvoice');

if ~exist(dataFolder,'dir')
    disp('Downloading data set (956 MB) ...')
    unzip(url,downloadDatasetFolder)
end
Downloading data set (956 MB) ...

Использовать audioDatastore для создания хранилищ данных для наборов обучения и проверки. Использовать readtable для чтения метаданных, связанных с аудиофайлами.

loc = fullfile(dataFolder);
adsTrain = audioDatastore(fullfile(loc,'train'),'IncludeSubfolders',true);
metadataTrain = readtable(fullfile(fullfile(loc,'train'),"train.tsv"),"FileType","text");
adsTrain.Labels = metadataTrain.gender;

adsValidation = audioDatastore(fullfile(loc,'validation'),'IncludeSubfolders',true);
metadataValidation = readtable(fullfile(fullfile(loc,'validation'),"validation.tsv"),"FileType","text");
adsValidation.Labels = metadataValidation.gender;

Использовать countEachLabel (Audio Toolbox) для проверки гендерной разбивки учебных и проверочных наборов.

countEachLabel(adsTrain)
ans=2×2 table
    Label     Count
    ______    _____

    female    1000 
    male      1000 

countEachLabel(adsValidation)
ans=2×2 table
    Label     Count
    ______    _____

    female     200 
    male       200 

Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset кому false. Для быстрого выполнения этого примера установите reduceDataset кому true.

reduceDataset = false;
if reduceDataset
    % Reduce the training dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files) / 2 / 20));
    adsValidation = splitEachLabel(adsValidation,20);
end

Создание наборов обучения и проверки

Определите частоту дискретизации аудиофайлов в наборе данных, а затем обновите частоту дискретизации, окно и длину перекрытия устройства извлечения аудиофайлов.

[~,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;
extractor.SampleRate = Fs;
extractor.Window = hamming(round(0.03*Fs),"periodic");
extractor.OverlapLength = round(0.02*Fs);

Для ускорения обработки распределите вычисления между несколькими работниками. При наличии Toolbox™ Parallel Computing в примере выполняется разделение хранилища данных таким образом, что извлечение элементов происходит параллельно между доступными работниками. Определите оптимальное количество разделов для системы. Если у вас нет Toolbox™ параллельных вычислений, в примере используется один работник.

if ~isempty(ver('parallel')) && ~reduceDataset
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

В цикле:

  1. Считывание из хранилища аудиоданных.

  2. Обнаружение областей речи.

  3. Извлеките векторы признаков из областей речи.

Реплицируйте метки таким образом, чтобы они соответствовали векторам признаков один к одному.

labelsTrain = [];
featureVectors = {};

% Loop over optimal number of partitions
parfor ii = 1:numPar
    
    % Partition datastore
    subds = partition(adsTrain,numPar,ii);
    
    % Preallocation
    featureVectorsInSubDS = {};
    segmentsPerFile = zeros(numel(subds.Files),1);
    
    % Loop over files in partitioned datastore
    for jj = 1:numel(subds.Files)
        
        % 1. Read in a single audio file
        audioIn = read(subds);
        
        % 2. Determine the regions of the audio that correspond to speech
        speechIndices = detectSpeech(audioIn,Fs);
        
        % 3. Extract features from each speech segment
        segmentsPerFile(jj) = size(speechIndices,1);
        features = cell(segmentsPerFile(jj),1);
        for kk = 1:size(speechIndices,1)
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        
    end
    featureVectors = [featureVectors;featureVectorsInSubDS];
    
    % Replicate the labels so that they are in one-to-one correspondance
    % with the feature vectors.
    repedLabels = repelem(subds.Labels,segmentsPerFile);
    labelsTrain = [labelsTrain;repedLabels(:)];
end

В классификационных приложениях рекомендуется нормализовать все элементы так, чтобы они имели нулевое среднее и единичное стандартное отклонение.

Вычислите среднее значение и стандартное отклонение для каждого коэффициента и используйте их для нормализации данных.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Буферизация векторов признаков в последовательности из 20 векторов признаков с 10 перекрытиями. Если последовательность имеет менее 20 векторов признаков, удалите ее.

[featuresTrain,trainSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Реплицируйте метки таким образом, чтобы они соответствовали последовательностям один к одному.

labelsTrain = repelem(labelsTrain,[trainSequencePerSegment{:}]);
labelsTrain = categorical(labelsTrain);

Создайте набор проверки, используя те же шаги, что и при создании набора обучения.

labelsValidation = [];
featureVectors = {};
valSegmentsPerFile = [];
parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    featureVectorsInSubDS = {};
    valSegmentsPerFileInSubDS = zeros(numel(subds.Files),1);
    for jj = 1:numel(subds.Files)
        audioIn = read(subds);
        speechIndices = detectSpeech(audioIn,Fs);
        numSegments = size(speechIndices,1);
        features = cell(valSegmentsPerFileInSubDS(jj),1);
        for kk = 1:numSegments
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        valSegmentsPerFileInSubDS(jj) = numSegments;
    end
    repedLabels = repelem(subds.Labels,valSegmentsPerFileInSubDS);
    labelsValidation = [labelsValidation;repedLabels(:)];
    featureVectors = [featureVectors;featureVectorsInSubDS];
    valSegmentsPerFile = [valSegmentsPerFile;valSegmentsPerFileInSubDS];
end

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
labelsValidation = repelem(labelsValidation,[valSequencePerSegment{:}]);
labelsValidation = categorical(labelsValidation);

Определение сетевой архитектуры GRU

Сети ГРУ могут изучать долгосрочные зависимости между временными шагами данных последовательности. В этом примере используется gruLayer чтобы посмотреть на последовательность как в прямом, так и в обратном направлениях.

Укажите входной размер, который будет последовательностью размера NumFeatures. Укажите уровень GRU с размером вывода 75 и выведите последовательность. Затем укажите уровень GRU с размером вывода 75 и выведите последний элемент последовательности. Эта команда предписывает уровню GRU отобразить свой вход в 75 элементов, а затем подготавливает выход для полностью подключенного уровня. Наконец, укажите два класса, включив полностью связанный слой размера 2, за которым следуют слой softmax и слой классификации.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    gruLayer(75,"OutputMode","sequence")
    gruLayer(75,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Затем укажите параметры обучения для классификатора. Набор MaxEpochs кому 4 так, что сеть 4 проходит через обучающие данные. Набор MiniBatchSize 256, так что сеть просматривает 128 обучающих сигналов одновременно. Определить Plots как "training-progress" для создания графиков, показывающих ход обучения по мере увеличения числа итераций. Набор Verbose кому false для отключения печати выходных данных таблицы, соответствующих данным, отображаемым на графике. Определить Shuffle как "every-epoch" перетасовать тренировочную последовательность в начале каждой эпохи. Определить LearnRateSchedule кому "piecewise" снижать скорость обучения на заданный коэффициент (0,1) каждый раз, когда проходит определенное количество эпох (1).

В этом примере используется решатель ADAM. ADAM работает лучше с рецидивирующими нейронными сетями (RNN), такими как GRU, чем стохастический градиентный спуск по умолчанию с решателем импульса (SGDM).

miniBatchSize = 256;
validationFrequency = floor(numel(labelsTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",1, ...
    'ValidationData',{featuresValidation,labelsValidation}, ...
    'ValidationFrequency',validationFrequency);

Обучение сети ГРУ

Обучение сети ГРУ указанным вариантам обучения и архитектуре уровней с помощью trainNetwork. Поскольку тренировочный набор большой, тренировочный процесс может занять несколько минут.

net = trainNetwork(featuresTrain,labelsTrain,layers,options);

Верхний подграф графика training-progress представляет точность обучения, которая является точностью классификации для каждой мини-партии. Когда обучение проходит успешно, это значение обычно увеличивается до 100%. Нижняя часть графика отображает потери при обучении, которые представляют собой потери при перекрестной энтропии для каждой мини-партии. Когда обучение проходит успешно, это значение обычно уменьшается до нуля.

Если тренировка не сходится, графики могут колебаться между значениями без тренда в определенном восходящем или нисходящем направлении. Это колебание означает, что точность тренировки не улучшается и потери тренировки не уменьшаются. Такая ситуация может возникнуть в начале обучения или после некоторого предварительного улучшения точности обучения. Во многих случаях изменение вариантов обучения может помочь сети достичь сходимости. Уменьшение MiniBatchSize или уменьшение InitialLearnRate может привести к увеличению времени обучения, но это может помочь сети лучше учиться.

Визуализация точности обучения

Вычислите точность обучения, которая представляет точность классификатора по сигналам, по которым он обучался. Во-первых, классифицируйте данные обучения.

prediction = classify(net,featuresTrain);

Постройте график матрицы путаницы. Отображение точности и отзыва для двух классов с помощью сводки столбцов и строк.

figure
cm = confusionchart(categorical(labelsTrain),prediction,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Визуализация точности проверки

Рассчитайте точность проверки. Во-первых, классифицируйте данные обучения.

[prediction,probabilities] = classify(net,featuresValidation);

Постройте график матрицы путаницы. Отображение точности и отзыва для двух классов с помощью сводки столбцов и строк.

figure
cm = confusionchart(categorical(labelsValidation),prediction,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Пример создает несколько последовательностей из каждого файла обучающей речи. Более высокая точность может быть достигнута путем рассмотрения выходного класса всех последовательностей, соответствующих одному и тому же файлу, и применения решения «max-rule», где выбирается класс с сегментом с наивысшим показателем достоверности.

Определите количество последовательностей, созданных для каждого файла в наборе проверки.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Прогнозирование пола из каждого обучающего файла с учетом выходных классов всех последовательностей, сгенерированных из одного и того же файла.

numFiles = numel(adsValidation.Files);
actualGender = categorical(adsValidation.Labels);
predictedGender = actualGender;      
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index} = probabilities(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2); 
    end
    counter = counter + sequencePerFile(index);
end

Визуализация матрицы путаницы в прогнозах правила большинства.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] Набор общих голосовых данных Mozilla

Вспомогательные функции

function [sequences,sequencePerSegment] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
if featureVectorsPerSequence <= featureVectorOverlap
    error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
end

hopLength = featureVectorsPerSequence - featureVectorOverlap;
idx1 = 1;
sequences = {};
sequencePerSegment = cell(numel(features),1);
for ii = 1:numel(features)
    sequencePerSegment{ii} = max(floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment{ii}
        sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end
end

См. также

| |

Связанные темы