Классифицируйте пол Используя сети LSTM

В этом примере показано, как классифицировать пол докладчика, использующего глубокое обучение. Пример использует сеть Bidirectional Long Short-Term Memory (BiLSTM) и Коэффициенты Gammatone Cepstral (gtcc), подачу, гармоническое отношение и несколько спектральных дескрипторов формы.

Введение

Классификация полов на основе речевых сигналов является важной составляющей многих аудиосистем, таких как автоматическое распознавание речи, распознавание динамика и мультимедийная индексация на основе содержимого.

Этот пример использует сети долгой краткосрочной памяти (LSTM), тип рекуррентной нейронной сети (RNN), подходящей, чтобы изучить данные timeseries и последовательность. Сеть LSTM может изучить долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer) может посмотреть в то время последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer) может посмотреть в то время последовательность и во вперед и в обратные направления. Этот пример использует двунаправленные слои LSTM.

Этот пример обучает сеть LSTM с последовательностями gammatone коэффициентов кепстра (gtcc), оценки подачи (pitch), гармоническое отношение (harmonicRatio), и несколько спектральных дескрипторов формы (Спектральные Дескрипторы (Audio Toolbox)).

Чтобы ускорить учебный процесс, запустите этот пример на машине с помощью графического процессора. Если ваша машина имеет графический процессор и Parallel Computing Toolbox™, то MATLAB© автоматически использует графический процессор в обучении; в противном случае это использует центральный процессор.

Предварительно обработайте аудиоданные

Сеть BiLSTM, используемая в этом примере, работает лучше всего при использовании последовательностей характеристических векторов. Чтобы проиллюстрировать конвейер предварительной обработки, этот пример идет через шаги для одного звукового файла.

Считайте содержимое звукового файла, содержащего речь. Пол докладчика является штекером.

[audioIn,Fs] = audioread('Counting-16-44p1-mono-15secs.wav');
labels = {'male'};

Постройте звуковой сигнал и затем слушайте его с помощью sound команда.

timeVector = (1/Fs) * (0:size(audioIn,1)-1);
figure
plot(timeVector,audioIn)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audioIn,Fs)

Речевой сигнал имеет сегменты тишины, которые не содержат полезную информацию, имеющую отношение к полу докладчика. Используйте detectSpeech определять местоположение сегментов речи в звуковом сигнале.

speechIndices = detectSpeech(audioIn,Fs);

Создайте audioFeatureExtractor извлекать функции из аудиоданных. Речевой сигнал является динамическим по своей природе и изменяется в зависимости от времени. Это принято, что речевые сигналы стационарные по кратковременным шкалам, и их обработка часто делается в окнах 20-40 мс. Задайте 30 MS Windows с перекрытием на 20 мс.

extractor = audioFeatureExtractor( ...
    "SampleRate",Fs, ...
    "Window",hamming(round(0.03*Fs),"periodic"), ...
    "OverlapLength",round(0.02*Fs), ...
    ...
    "gtcc",true, ...
    "gtccDelta",true, ...
    "gtccDeltaDelta",true, ...
    ...
    "SpectralDescriptorInput","melSpectrum", ...
    "spectralCentroid",true, ...
    "spectralEntropy",true, ...
    "spectralFlux",true, ...
    "spectralSlope",true, ...
    ...
    "pitch",true, ...
    "harmonicRatio",true);

Извлеките функции из каждого аудио сегмента. Выход от audioFeatureExtractor numFeatureVectors- numFeatures массив. sequenceInputLayer используемый в этом примере требует, чтобы время приехало второе измерение. Переставьте выходной массив так, чтобы время приехало второе измерение.

featureVectorsSegment = {};
for ii = 1:size(speechIndices,1)
    featureVectorsSegment{end+1} = ( extract(extractor,audioIn(speechIndices(ii,1):speechIndices(ii,2))) )';
end
numSegments = size(featureVectorsSegment)
numSegments = 1×2

     1    11

[numFeatures,numFeatureVectorsSegment1] = size(featureVectorsSegment{1})
numFeatures = 45
numFeatureVectorsSegment1 = 124

Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с сегментами.

labels = repelem(labels,size(speechIndices,1))
labels = 1×11 cell
    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}    {'male'}

При использовании sequenceInputLayer, часто выгодно использовать последовательности сопоставимой длины. Преобразуйте массивы характеристических векторов в последовательности характеристических векторов. Используйте 20 характеристических векторов на последовательность с 5 перекрытиями характеристического вектора.

featureVectorsPerSequence = 20;
featureVectorOverlap = 5;
hopLength = featureVectorsPerSequence - featureVectorOverlap;

idx1 = 1;
featuresTrain = {};
sequencePerSegment = zeros(numel(featureVectorsSegment),1);
for ii = 1:numel(featureVectorsSegment)
    sequencePerSegment(ii) = max(floor((size(featureVectorsSegment{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment(ii)
        featuresTrain{idx1,1} = featureVectorsSegment{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1);
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end

Для краткости функция помощника HelperFeatureVector2Sequence инкапсулирует вышеупомянутую обработку и используется в остальной части примера.

Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с набором обучающих данных.

labels = repelem(labels,sequencePerSegment);

Результатом конвейера предварительной обработки является NumSequence- 1 массив ячеек NumFeatures- FeatureVectorsPerSequence матрицы. Метками является NumSequence- 1 массив.

NumSequence = numel(featuresTrain)
NumSequence = 27
[NumFeatures,FeatureVectorsPerSequence] = size(featuresTrain{1})
NumFeatures = 45
FeatureVectorsPerSequence = 20
NumSequence = numel(labels)
NumSequence = 27

Фигура предоставляет обзор извлечения признаков, используемого на обнаруженную речевую область.

Создайте обучение и протестируйте хранилища данных

Этот пример использует набор данных Mozilla Common Voice [1]. Набор данных содержит записи на 48 кГц предметов говорящие короткие предложения. Загрузите набор данных и untar загруженный файл. Установите PathToDatabase к местоположению данных.

datafolder = PathToDatabase;

Используйте audioDatastore создать datastore для всех файлов в наборе данных.

loc = fullfile(datafolder,"clips");
ads = audioDatastore(loc);

Только начиная с части файлов набора данных аннотируются информацией о поле, используют и наборы обучения и валидации, чтобы обучить сеть. Используйте набор тестов, чтобы подтвердить сеть. Используйте readtable считать метаданные, сопоставленные со звуковыми файлами. Метаданные содержатся в train.tsv, dev.tsv, and test.tsv файлы. Смотрите первые несколько строк учебных метаданных.

metadataTrain = readtable(fullfile(datafolder,"train.tsv"),"FileType","text");
metadataDev = readtable(fullfile(datafolder,"dev.tsv"),"FileType","text");
metadataTrain = [metadataTrain;metadataDev];

head(metadataTrain)
ans=8×8 table
                                                                 client_id                                                                                path                                                            sentence                                              up_votes    down_votes        age           gender        accent  
    ____________________________________________________________________________________________________________________________________    ________________________________    ____________________________________________________________________________________________    ________    __________    ____________    __________    __________

    {'4f29be8fe932d773576dd3df5e111929f4e222422322450983695eaa8625a12659cd3e999a061a29ebe71783833bebdc2d0ec6b97e9a648bf6d28979065f85ad'}    {'common_voice_en_19664034.mp3'}    {'These data components in turn serve as the "building blocks" of data exchanges.'         }       2            0         {'thirties'}    {'male'  }    {0×0 char}
    {'4f29be8fe932d773576dd3df5e111929f4e222422322450983695eaa8625a12659cd3e999a061a29ebe71783833bebdc2d0ec6b97e9a648bf6d28979065f85ad'}    {'common_voice_en_19664035.mp3'}    {'The church is unrelated to the Jewish political movement of Zionism.'                    }       3            0         {'thirties'}    {'male'  }    {0×0 char}
    {'4f29be8fe932d773576dd3df5e111929f4e222422322450983695eaa8625a12659cd3e999a061a29ebe71783833bebdc2d0ec6b97e9a648bf6d28979065f85ad'}    {'common_voice_en_19664037.mp3'}    {'The following represents architectures which have been utilized at one point or another.'}       2            0         {'thirties'}    {'male'  }    {0×0 char}
    {'4f29be8fe932d773576dd3df5e111929f4e222422322450983695eaa8625a12659cd3e999a061a29ebe71783833bebdc2d0ec6b97e9a648bf6d28979065f85ad'}    {'common_voice_en_19664038.mp3'}    {'Additionally, the pulse output can be directed through one of three resonator banks.'    }       2            0         {'thirties'}    {'male'  }    {0×0 char}
    {'4f29be8fe932d773576dd3df5e111929f4e222422322450983695eaa8625a12659cd3e999a061a29ebe71783833bebdc2d0ec6b97e9a648bf6d28979065f85ad'}    {'common_voice_en_19664040.mp3'}    {'The two are robbed by a pickpocket who is losing in gambling.'                           }       3            0         {'thirties'}    {'male'  }    {0×0 char}
    {'4f3b69348cb65923dff20efe0eaef4fbc8797f9c2240447ae48764e36fab63867dbf6947bfb8ff623cab4f1d1e185ac79ce3975f98a0f57f90b9ce9bdbbe95fd'}    {'common_voice_en_19742944.mp3'}    {'Its county seat is Phenix City.'                                                         }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'4f3b69348cb65923dff20efe0eaef4fbc8797f9c2240447ae48764e36fab63867dbf6947bfb8ff623cab4f1d1e185ac79ce3975f98a0f57f90b9ce9bdbbe95fd'}    {'common_voice_en_19742945.mp3'}    {'Consequently, the diocese accumulated millions of dollars in debt.'                      }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'4f3b69348cb65923dff20efe0eaef4fbc8797f9c2240447ae48764e36fab63867dbf6947bfb8ff623cab4f1d1e185ac79ce3975f98a0f57f90b9ce9bdbbe95fd'}    {'common_voice_en_19742948.mp3'}    {'The song "Kodachrome" is named after the Kodak film of the same name.'                   }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}

Удалите строки метаданных, которые не содержат информацию о поле. Удалите строки из метаданных, которые не содержат информацию о возрасте, или если информация о возрасте указывает на подростка.

containsGenderInfo = contains(metadataTrain.gender,'male') | contains(metadataTrain.gender,'female');
isAdult = ~contains(metadataTrain.age,'teens') & ~isempty(metadataTrain.age);
highUpVotes = metadataTrain.up_votes >= 3;
metadataTrain(~containsGenderInfo | ~isAdult | ~highUpVotes,:) = [];
trainFiles = fullfile(loc,metadataTrain.path);

Подмножество datastore, чтобы только включать файлы, соответствующие взрослым динамикам с информацией о поле.

[~,idxA,idxB] = intersect(ads.Files,trainFiles);
adsTrain = subset(ads,idxA);
adsTrain.Labels = metadataTrain.gender(idxB);

Используйте countEachLabel смотреть гендерный отказ набора обучающих данных.

labelDistribution = countEachLabel(adsTrain)
labelDistribution=2×2 table
    Label     Count
    ______    _____

    female    1554 
    male      4491 

Используйте splitEachLabel уменьшать учебный datastore так, чтобы было равное количество спикеров и женщин-спикеров.

numFilesPerGender = min(labelDistribution.Count);
adsTrain = splitEachLabel(adsTrain,numFilesPerGender);
countEachLabel(adsTrain)
ans=2×2 table
    Label     Count
    ______    _____

    female    1554 
    male      1554 

Создайте набор валидации с помощью тех же шагов.

metadataValidation = readtable(fullfile(datafolder,"test.tsv"),"FileType","text");
containsGenderInfo = contains(metadataValidation.gender,'male') | contains(metadataValidation.gender,'female');
isAdult = ~contains(metadataValidation.age,'teens') & ~isempty(metadataValidation.age);
metadataValidation(~containsGenderInfo | ~isAdult,:) = [];
validationFiles = fullfile(loc,metadataValidation.path);
[~,idxA,idxB] = intersect(ads.Files,validationFiles);
adsValidation = subset(ads,idxA);
adsValidation.Labels = metadataValidation.gender(idxB);
countEachLabel(adsValidation)
ans=2×2 table
    Label     Count
    ______    _____

    female     312 
    male      1608 

Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset к false. Чтобы запустить этот пример быстро, установите reduceDataset к true.

reduceDataset = false;
if reduceDataset
    % Reduce the training dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files) / 2 / 20));
    adsValidation = splitEachLabel(adsValidation,20);
end

Создайте наборы обучения и валидации

Определите частоту дискретизации звуковых файлов в наборе данных, и затем обновите частоту дискретизации, окно и длину перекрытия экстрактора функции аудио.

[~,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;
extractor.SampleRate = Fs;
extractor.Window = hamming(round(0.03*Fs),"periodic");
extractor.OverlapLength = round(0.02*Fs);

Чтобы ускорить обработку, распределите расчеты по нескольким рабочим. Если у вас есть Parallel Computing Toolbox™, пример делит datastore так, чтобы извлечение признаков произошло параллельно через доступных рабочих. Определите оптимальное количество разделов для вашей системы. Если у вас нет Parallel Computing Toolbox™, пример использует одного рабочего.

if ~isempty(ver('parallel')) && ~reduceDataset
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

В цикле:

  1. Читайте из аудио datastore.

  2. Обнаружьте области речи.

  3. Извлеките характеристические векторы из областей речи.

Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с характеристическими векторами.

labelsTrain = [];
featureVectors = {};

% Loop over optimal number of partitions
parfor ii = 1:numPar
    
    % Partition datastore
    subds = partition(adsTrain,numPar,ii);
    
    % Preallocation
    featureVectorsInSubDS = {};
    segmentsPerFile = zeros(numel(subds.Files),1);
    
    % Loop over files in partitioned datastore
    for jj = 1:numel(subds.Files)
        
        % 1. Read in a single audio file
        audioIn = read(subds);
        
        % 2. Determine the regions of the audio that correspond to speech
        speechIndices = detectSpeech(audioIn,Fs);
        
        % 3. Extract features from each speech segment
        segmentsPerFile(jj) = size(speechIndices,1);
        features = cell(segmentsPerFile(jj),1);
        for kk = 1:size(speechIndices,1)
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        
    end
    featureVectors = [featureVectors;featureVectorsInSubDS];
    
    % Replicate the labels so that they are in one-to-one correspondance
    % with the feature vectors.
    repedLabels = repelem(subds.Labels,segmentsPerFile);
    labelsTrain = [labelsTrain;repedLabels(:)];
end

В приложениях классификации это - хорошая практика, чтобы нормировать все функции, чтобы иметь нулевое среднее значение и стандартное отклонение единицы.

Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормировать данные.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Буферизуйте характеристические векторы в последовательности 20 характеристических векторов с 10 перекрытиями. Если последовательность имеет меньше чем 20 характеристических векторов, пропустите ее.

[featuresTrain,trainSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с последовательностями.

labelsTrain = repelem(labelsTrain,[trainSequencePerSegment{:}]);
labelsTrain = categorical(labelsTrain);

Создайте набор валидации с помощью тех же шагов, используемых, чтобы создать набор обучающих данных.

labelsValidation = [];
featureVectors = {};
valSegmentsPerFile = [];
parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    featureVectorsInSubDS = {};
    valSegmentsPerFileInSubDS = zeros(numel(subds.Files),1);
    for jj = 1:numel(subds.Files)
        audioIn = read(subds);
        speechIndices = detectSpeech(audioIn,Fs);
        numSegments = size(speechIndices,1);
        features = cell(valSegmentsPerFileInSubDS(jj),1);
        for kk = 1:numSegments
            features{kk} = ( extract(extractor,audioIn(speechIndices(kk,1):speechIndices(kk,2))) )';
        end
        featureVectorsInSubDS = [featureVectorsInSubDS;features(:)];
        valSegmentsPerFileInSubDS(jj) = numSegments;
    end
    repedLabels = repelem(subds.Labels,valSegmentsPerFileInSubDS);
    labelsValidation = [labelsValidation;repedLabels(:)];
    featureVectors = [featureVectors;featureVectorsInSubDS];
    valSegmentsPerFile = [valSegmentsPerFile;valSegmentsPerFileInSubDS];
end

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
labelsValidation = repelem(labelsValidation,[valSequencePerSegment{:}]);
labelsValidation = categorical(labelsValidation);

Задайте сетевую архитектуру LSTM

Сети LSTM могут изучить долгосрочные зависимости между временными шагами данных о последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer смотреть на последовательность и во вперед и в обратные направления.

Задайте входной размер, чтобы быть последовательностями размера NumFeatures. Задайте скрытый двунаправленный слой LSTM с выходным размером 50 и выведите последовательность. Затем задайте двунаправленный слой LSTM с выходным размером 50 и выведите последний элемент последовательности. Эта команда дает двунаправленному слою LSTM команду сопоставлять свой вход в 50 функций и затем готовит выход к полносвязному слою. Наконец, задайте два класса включением полносвязного слоя размера 2, сопровождаемый softmax слоем и слоем классификации.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    bilstmLayer(50,"OutputMode","sequence")
    bilstmLayer(50,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Затем задайте опции обучения для классификатора. Установите MaxEpochs к 4 так, чтобы сеть сделала 4, проходит через обучающие данные. Установите MiniBatchSize из 256 так, чтобы сеть посмотрела на 128 учебных сигналов за один раз. Задайте Plots как "training-progress" сгенерировать графики, которые показывают процесс обучения количеством увеличений итераций. Установите Verbose к false отключить печать таблицы выход, который соответствует данным, показанным в графике. Задайте Shuffle как "every-epoch" переставить обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule к "piecewise" чтобы уменьшить скорость обучения заданным фактором (0.1) каждый раз, определенное число эпох (1) передало.

Этот пример использует адаптивную оценку момента (ADAM) решатель. ADAM выполняет лучше с рекуррентными нейронными сетями (RNNs) как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решатель.

miniBatchSize = 256;
validationFrequency = floor(numel(labelsTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",1, ...
    'ValidationData',{featuresValidation,labelsValidation}, ...
    'ValidationFrequency',validationFrequency);

Обучите сеть LSTM

Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork. Поскольку набор обучающих данных является большим, учебный процесс может занять несколько минут.

net = trainNetwork(featuresTrain,labelsTrain,layers,options);

Главный подграфик графика процесса обучения представляет учебную точность, которая является точностью классификации на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно увеличивается к 100%. Нижний подграфик отображает учебную потерю, которая является потерей перекрестной энтропии на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно уменьшается по направлению к нулю.

Если обучение не сходится, графики могут колебаться между значениями, не отклоняясь в определенном восходящем или нисходящем направлении. Это колебание означает, что учебная точность не улучшается, и учебная потеря не уменьшается. Эта ситуация может произойти в начале обучения, или после некоторого предварительного улучшения учебной точности. Во многих случаях изменение опций обучения может помочь сети достигнуть сходимости. Уменьшение MiniBatchSize или уменьшение InitialLearnRate может закончиться в более длительное учебное время, но оно может помочь сети учиться лучше.

Визуализируйте учебную точность

Вычислите учебную точность, которая представляет точность классификатора на сигналах, на которых это было обучено. Во-первых, классифицируйте обучающие данные.

prediction = classify(net,featuresTrain);

Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.

figure
cm = confusionchart(categorical(labelsTrain),prediction,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Визуализируйте точность валидации

Вычислите точность валидации. Во-первых, классифицируйте обучающие данные.

[prediction,probabilities] = classify(net,featuresValidation);

Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.

figure
cm = confusionchart(categorical(labelsValidation),prediction,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Пример сгенерировал несколько последовательностей из каждого учебного речевого файла. Более высокая точность может быть достигнута путем рассмотрения выходного класса всех последовательностей, соответствующих тому же файлу и применяющих решение "макс. правила", где класс с сегментом с самой высокой оценкой достоверности выбран.

Определите количество последовательностей, сгенерированных на файл в наборе валидации.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Предскажите пол из каждого учебного файла путем считания выходных классов всех последовательностей сгенерированными из того же файла.

numFiles = numel(adsValidation.Files);
actualGender = categorical(adsValidation.Labels);
predictedGender = actualGender;      
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index} = probabilities(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2); 
    end
    counter = counter + sequencePerFile(index);
end

Визуализируйте матрицу беспорядка на предсказаниях принципа большинства.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] Mozilla общий речевой набор данных

Приложение - поддерживание функций

function [sequences,sequencePerSegment] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
if featureVectorsPerSequence <= featureVectorOverlap
    error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
end

hopLength = featureVectorsPerSequence - featureVectorOverlap;
idx1 = 1;
sequences = {};
sequencePerSegment = cell(numel(features),1);
for ii = 1:numel(features)
    sequencePerSegment{ii} = max(floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1,0);
    idx2 = 1;
    for j = 1:sequencePerSegment{ii}
        sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
        idx1 = idx1 + 1;
        idx2 = idx2 + hopLength;
    end
end


end

Смотрите также

| |

Похожие темы