Классифицируйте пол Используя сети LSTM

В этом примере показано, как классифицировать пол докладчика, использующего глубокое обучение. В частности, пример использует сеть Bidirectional Long Short-Term Memory (BiLSTM) и Коэффициенты Gammatone Cepstral (gtcc), подачу, гармоническое отношение и несколько спектральных дескрипторов формы.

Введение

Классификация полов на основе речевых сигналов является важной составляющей многих аудиосистем, таких как автоматическое распознавание речи, распознавание динамика и мультимедийная индексация на основе содержимого.

Этот пример использует сети долгой краткосрочной памяти (LSTM), тип рекуррентной нейронной сети (RNN), подходящей, чтобы изучить данные timeseries и последовательность. Сеть LSTM может изучить долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer) может посмотреть в то время последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer) может посмотреть в то время последовательность и во вперед и в обратные направления. Этот пример использует двунаправленный слой LSTM.

Этот пример обучает сеть LSTM с последовательностями Коэффициентов Кепстра Gammatone (gtcc), оценок подачи, гармонического отношения и нескольких спектральных дескрипторов формы.

Пример проходит следующие шаги:

  1. Создайте audioDatastore это указывает на аудио речевые файлы, используемые, чтобы обучить сеть LSTM.

  2. Удалите тишину и неречевые сегменты из речевых файлов с помощью простого метода пороговой обработки.

  3. Извлеките последовательности функции, состоящие из коэффициентов GTCC, подачи, гармонического отношения и нескольких спектральных дескрипторов формы от речевых сигналов.

  4. Обучите сеть LSTM с помощью последовательностей функции.

  5. Измерьте и визуализируйте точность классификатора на обучающих данных.

  6. Создайте audioDatastore из речевых файлов, используемых, чтобы протестировать обучивший сеть.

  7. Удалите неречевые сегменты из этих файлов, сгенерируйте последовательности функции, передайте их через сеть и протестируйте ее точность путем сравнения предсказанного и фактического пола докладчиков.

Чтобы ускорить учебный процесс, запустите этот пример на машине с помощью графического процессора. Если ваша машина имеет графический процессор и Parallel Computing Toolbox™, то MATLAB® автоматически использует графический процессор в обучении; в противном случае это использует центральный процессор.

Исследуйте набор данных

Этот пример использует набор данных Mozilla Common Voice [1]. Набор данных содержит записи на 48 кГц предметов говорящие короткие предложения. Загрузите набор данных и untar загруженный файл. Установите datafolder к местоположению данных.

datafolder = PathToDatabase;

Используйте audioDatastore создать datastore для всех файлов в наборе данных.

ads0 = audioDatastore(fullfile(datafolder,"clips"));

Только начиная с части файлов набора данных аннотируются информацией о поле, вы будете использовать и наборы обучения и валидации, чтобы обучить сеть. Вы будете использовать набор тестов, чтобы подтвердить сеть. Используйте readtable считать метаданные, сопоставленные со звуковыми файлами от набора dev и обучения. Метаданные содержатся в train.tsv файл. Смотрите первые несколько строк метаданных.

metadataTrain = readtable(fullfile(datafolder,"train.tsv"),"FileType","text");
metadataDev = readtable(fullfile(datafolder,"dev.tsv"),"FileType","text");
metadata = [metadataTrain;metadataDev];
head(metadata)
ans =

  8×8 table

                                                                 client_id                                                                                                                                  path                                                                                                               sentence                                               up_votes    down_votes        age           gender        accent  
    ____________________________________________________________________________________________________________________________________    ____________________________________________________________________________________________________________________________________    ______________________________________________________________________________________________    ________    __________    ____________    __________    __________

    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'f480b8a93bf84b7f74c141284a71c39ff47d264a75dc905dc918286fb67f0333595206ff953a27b8049c7ec09ea895aa66d1cd4f7547535167d3d7901d12feab'}    {'Unfortunately, nobody can warrant the sanctions that will have an effect on the community.'}       3            0         {'twenties'}    {'female'}    {'canada'}
    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'7647873ce81cd81c90b9e0fe3cb6c85cc03df7c0c4fdf2a04c356d75063af4b9de296a24e3bef0ba7ef0b0105d166abf35597e9c9a4b3857fd09c57b79f65a99'}    {'Came down and picked it out himself.'                                                      }       2            0         {'twenties'}    {'female'}    {'canada'}
    {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'}    {'81a3dd920de6251cc878a940aff258e859ef13efb9a6446610ab907e08832fafdc463eda334ee74a24cc02e3652a09f5573c133e6f46886cb0ba463efc7a6b43'}    {'She crossed the finish line just in time.'                                                 }       2            0         {'twenties'}    {'female'}    {'canada'}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'5e6fc96a7bc91ec2261a51e7713bb0ed8a9f4fa9e20a38060dc4544fb0c2600c192d6e849915acaf8ea0766a9e1d481557d674363e780dbb064586352e560f2c'}    {'Please find me the Home at Last trailer.'                                                  }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'3a0929094a9aac80b961d479a3ee54311cc0d60d51fe8f97071edc2999e7747444b261d2c0c2345f86fb8161f3b73a14dc19da911d19ca8d9db39574c6199a34'}    {'Play something by Louisiana Blues'                                                         }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'}    {'82b8edf3f1420295069b5bb4543b8f349faaca28f45a3279b0cd64c39d16afb590a4cc70ed805020161f8c1f94bc63d3b69756fbc5a0462ce12d2e17c4ebaeeb'}    {'When is The Devil with Hitler playing in Bow Tie Cinemas'                                  }       2            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'}    {'8606bac841a08bcbf5ddb83c768103c467ffd1bf38b16052414210dc3ce3267561cb0368d227b6eb420dc147387cc1807032102b6248a13a40f83e5ac06d7122'}    {'Give me the list of animated movies playing at the closest movie house'                    }       3            0         {0×0 char  }    {0×0 char}    {0×0 char}
    {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'}    {'0f7e63d320cfbf6ea5d1cda674007131d804a06f5866d38f81da7def33a4ce8ee4f2cb7e47b45eee97903cad3160b3a10f715862227e8ecdc3fb3bafc6b4279d'}    {'The better part of valor is discretion'                                                    }       3            0         {0×0 char  }    {0×0 char}    {0×0 char}

Создайте обучающий набор данных

Найдите файлы в datastore, соответствующем набору обучающих данных.

csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);

Создайте набор обучающих данных подмножества из большого набора данных.

adsTrain = subset(ads0,indA);

Вы будете использовать данные, соответствующие взрослым динамикам только. Считайте пол и переменные возраста из метаданных.

gender = metadata.gender;
gender = gender(indB);
age = metadata.age;
age = age(indB);

Присвойте пол Labels свойство datastore.

adsTrain.Labels = gender;

Не все файлы в наборе данных аннотируются информацией о возрасте и полом. Создайте подмножество datastore, который только содержит файлы, где информация о поле доступна, и возраст больше 19.

maleOrfemale = categorical(adsTrain.Labels) == "male" | categorical(adsTrain.Labels) == "female";
isAdult = categorical(age) ~= "" & categorical(age) ~= "teens";
adsTrain = subset(adsTrain,maleOrfemale & isAdult);

Вы обучите нейронную сеть для глубокого обучения на подмножестве файлов. Создайте подмножество datastore, содержащее равное количество спикеров и женщин-спикеров.

ismale = find(categorical(adsTrain.Labels) == "male");
isfemale = find(categorical(adsTrain.Labels) == "female");
numFilesPerGender = numel(isfemale);
adsTrain = subset(adsTrain,[ismale(1:numFilesPerGender) isfemale(1:numFilesPerGender)]);

Используйте shuffle рандомизировать порядок файлов в datastore.

adsTrain = shuffle(adsTrain);

Используйте countEachLabel смотреть гендерный отказ набора обучающих данных.

countEachLabel(adsTrain)
ans =

  2×2 table

    Label     Count
    ______    _____

    female     925 
    male       925 

Изолированные речевые сегменты

Считайте содержимое звукового файла с помощью read.

[audio,adsInfo] = read(adsTrain);
Fs = adsInfo.SampleRate;

Постройте звуковой сигнал и затем слушайте его с помощью sound команда.

timeVector = (1/Fs) * (0:numel(audio)-1);
figure
plot(timeVector,audio)
ylabel("Amplitude")
xlabel("Time (s)")
title("Sample Audio")
grid on

sound(audio,Fs)

Речевой сигнал имеет сегменты тишины, которые не содержат полезную информацию, имеющую отношение к полу докладчика. Этот пример удаляет тишину с помощью упрощенной версии подхода пороговой обработки, описанного в [2]. Шаги алгоритма удаления тишины обрисованы в общих чертах ниже.

Во-первых, вычислите две функции по неперекрывающимся системам координат аудиоданных: энергия сигнала и спектральный центроид. Спектральный центроид является мерой "центра тяжести" спектра сигнала.

Повредите аудио в неперекрывающиеся системы координат с 50 миллисекундами.

audio = audio ./ max(abs(audio)); % Normalize amplitude
windowLength = 50e-3 * Fs;
segments = buffer(audio,windowLength);

Вычислите энергию и спектральный центроид для каждой системы координат.

win = hann(windowLength,'periodic');
signalEnergy = sum(segments.^2,1)/windowLength;
centroid = spectralCentroid(segments,Fs,'Window',win,'OverlapLength',0);

Затем установите пороги для каждой функции. Игнорируются области, где значения функции падают ниже или выше их соответствующих порогов. В этом примере энергетический порог устанавливается к половине средней энергии, и спектральный центроидный порог устанавливается к 5 000 Гц.

T_E = mean(signalEnergy)/2;
T_C = 5000;
isSpeechRegion = (signalEnergy>=T_E) & (centroid<=T_C);

Визуализируйте вычисленную энергию и спектральный центроид в зависимости от времени.

% Hold the signal energy, spectral centroid, and speech decision values for
% plotting purposes.
CC = repmat(centroid,windowLength,1);
CC = CC(:);
EE = repmat(signalEnergy,windowLength,1);
EE = EE(:);
flags2 = repmat(isSpeechRegion,windowLength,1);
flags2 = flags2(:);

figure

subplot(3,1,1)
plot(timeVector, CC(1:numel(audio)), ...
     timeVector, repmat(T_C,1,numel(timeVector)), "LineWidth",2)
xlabel("Time (s)")
ylabel("Normalized Centroid")
legend("Centroid","Threshold")
title("Spectral Centroid")
grid on

subplot(3,1,2)
plot(timeVector, EE(1:numel(audio)), ...
     timeVector, repmat(T_E,1,numel(timeVector)),"LineWidth",2)
ylabel("Normalized Energy")
legend("Energy","Threshold")
title("Window Energy")
grid on

subplot(3,1,3)
plot(timeVector, audio, ...
     timeVector,flags2(1:numel(audio)),"LineWidth",2)
ylabel("Audio")
legend("Audio","Speech Region")
title("Audio")
grid on
ylim([-1 1.1])

Извлеките сегменты речи от аудио. Примите, что речь присутствует для выборок, где энергия выше ее порога, и спектральный центроид ниже ее порога.

% Get indices of frames where a speech-to-silence or silence-to-speech
% transition occurs.
regionStartPos = find(diff([isSpeechRegion(1)-1, isSpeechRegion]));

% Get the length of the all-silence or all-speech regions.
RegionLengths = diff([regionStartPos, numel(isSpeechRegion)+1]);

% Get speech-only regions.
isSpeechRegion = isSpeechRegion(regionStartPos) == 1;
regionStartPos = regionStartPos(isSpeechRegion);
RegionLengths = RegionLengths(isSpeechRegion);

% Get start and end indices for each speech region. Extend the region by 5
% windows on each side.
startIndices = zeros(1,numel(RegionLengths));
endIndices = zeros(1,numel(RegionLengths));
for index = 1:numel(RegionLengths)
   startIndices(index) = max(1,(regionStartPos(index) - 5) * windowLength + 1);
   endIndices(index) = min(numel(audio),(regionStartPos(index) + RegionLengths(index) + 5) * windowLength);
end

Наконец, объедините пересекающиеся речевые сегменты.

activeSegment = 1;
isSegmentsActive = zeros(1,numel(startIndices));
isSegmentsActive(1) = 1;
for index = 2:numel(startIndices)
    if startIndices(index) <= endIndices(activeSegment)
        % Current segment intersects with previous segment
        if endIndices(index) > endIndices(activeSegment)
           endIndices(activeSegment) =  endIndices(index);
        end
    else
        % New speech segment detected
        activeSegment = index;
        isSegmentsActive(index) = 1;
    end
end
numSegments = sum(isSegmentsActive);
segments = cell(1,numSegments);
limits = zeros(2,numSegments);
speechSegmentsIndices  = find(isSegmentsActive);
for index = 1:length(speechSegmentsIndices)
    segments{index} = audio(startIndices(speechSegmentsIndices(index)): ...
                            endIndices(speechSegmentsIndices(index)));
    limits(:,index) = [startIndices(speechSegmentsIndices(index)); ...
                       endIndices(speechSegmentsIndices(index))];
end

Постройте исходное аудио наряду с обнаруженными речевыми сегментами.

figure

plot(timeVector,audio)
hold on
myLegend = cell(1,numel(segments) + 1);
myLegend{1} = "Original Audio";
for index = 1:numel(segments)
    plot(timeVector(limits(1,index):limits(2,index)),segments{index});
    myLegend{index+1} = sprintf("Output Audio Segment %d",index);
end
xlabel("Time (s)")
ylabel("Audio")
grid on
legend(myLegend)

Функции аудио

Речевой сигнал является динамическим по своей природе и изменяется в зависимости от времени. Это принято, что речевые сигналы стационарные по кратковременным шкалам, и их обработка часто делается в окнах 20-40 мс. Для каждого речевого сегмента этот пример извлекает функции аудио для 30 MS Windows с 75%-м перекрытием.

win = hamming(0.03*Fs,"periodic");
overlapLength = round(0.75*numel(win));
featureParams = struct("SampleRate",Fs, ...
                 "Window",win, ...
                 "OverlapLength",overlapLength);
extractor = audioFeatureExtractor('Window',win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',Fs, ...
    'SpectralDescriptorInput','melSpectrum', ...
    ...
    'gtcc',true, ...
    'gtccDelta',true, ...
    'gtccDeltaDelta',true, ...
    'spectralSlope',true, ...
    'spectralFlux',true, ...
    'spectralCentroid',true, ...
    'spectralEntropy',true, ...
    'pitch',true, ...
    'harmonicRatio',true);

Фигура предоставляет обзор извлечения признаков, используемого в этом примере.

Извлеките функции Используя длинные массивы

Чтобы ускорить обработку, извлеките последовательности функции из речевых сегментов всех звуковых файлов в datastore с помощью tall массивы. В отличие от массивов в оперативной памяти, длинные массивы обычно остаются неоцененными, пока вы не запрашиваете, чтобы вычисления были выполнены с помощью gather функция. Эта отсроченная оценка позволяет вам работать быстро с большими наборами данных. Когда вы в конечном счете запрашиваете выход с помощью gather, MATLAB® комбинирует вычисления в очереди, где возможный и берет минимальное количество проходов через данные. Если у вас есть Parallel Computing Toolbox™, можно использовать длинные массивы на локальном сеансе MATLAB®, или на локальном параллельном пуле. Можно также выполнить вычисления длинного массива на кластере, если вам установили MATLAB® Parallel Server™.

Во-первых, преобразуйте datastore в длинный массив:

T = tall(adsTrain)
T =

  M×1 tall cell array

    {280560×1 double}
    {156144×1 double}
    {167664×1 double}
    {190704×1 double}
    {401520×1 double}
    {120432×1 double}
    {354288×1 double}
    {318576×1 double}
        :        :
        :        :

Отображение указывает, что количество строк (соответствующий количеству файлов в datastore), M, еще не известно. M является заполнителем, пока вычисление не завершается.

Извлеките речевые сегменты из длинной таблицы. Это действие создает новую переменную длинного массива, чтобы использовать в последующих вычислениях. Функциональный HelperSegmentSpeech выполняет шаги, уже подсвеченные в Isolate Speech Segments раздел. cellfun команда применяет HelperSegmentSpeech к содержимому каждого звукового файла в datastore. Также определите количество сегментов на файл.

segmentsTall = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);
segmentsPerFileTall = cellfun(@numel,segmentsTall);

Извлеките последовательности функции из речевых сегментов с помощью HelperGetFeatureVectors. Функция помощника применяет извлечение признаков на сегменты с помощью audioFeatureExtractor и затем переориентирует функции так, чтобы время приехало строки, чтобы быть совместимым с sequenceInputLayer.

featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segmentsTall,"UniformOutput",false);

Используйте gather оценивать featureVectorsTall и segmentsPerFileTall. featureVectors возвращен как массив ячеек NumFiles-1, где каждый элемент массива ячеек является 1 массивом ячеек NumSegmentsPerFile. Не вложите массив ячеек.

[featureVectors,segmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall);

featureVectors = cat(2,featureVectors{:});
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 2 min 15 sec
Evaluation completed in 2 min 19 sec

Реплицируйте метки, таким образом, существует одна метка на сегмент.

myLabels = adsTrain.Labels;
myLabels = repelem(myLabels,segmentsPerFile);

В приложениях классификации это - хорошая практика, чтобы нормировать все функции, чтобы иметь нулевое среднее значение и стандартное отклонение единицы.

Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормировать данные.

allFeatures = cat(2,featureVectors{:});
allFeatures(isinf(allFeatures)) = nan;
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Пользователь HelperFeatureVector2Sequence буферизовать характеристические векторы в последовательности 20 характеристических векторов с 10 перекрытиями.

featureVectorsPerSequence = 20;
featureVectorOverlap = 10;
[featuresTrain,sequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Создайте массив ячеек, genderTrain, для ожидаемого пола, сопоставленного с каждой обучающей последовательностью.

genderTrain = repelem(myLabels,[sequencePerSegment{:}]);

Создайте набор данных валидации

Вы создадите набор данных валидации с помощью того же подхода, который вы использовали в обучающем наборе данных. Используйте readtable функционируйте, чтобы считать метаданные, сопоставленные с файлами валидации.

metadata = readtable(fullfile(datafolder,"test.tsv"),"FileType","text");

Найдите файлы валидации в datastore.

csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);

Создайте datastore валидации из большого datastore.

adsVal = subset(ads0,indA);

Подобно набору обучающих данных вы будете использовать данные, соответствующие взрослым динамикам только. Считайте пол и переменные возраста из метаданных.

gender = metadata.gender;
gender = gender(indB);
age = metadata.age;
age = age(indB);

Присвойте пол Labels свойство datastore.

adsVal.Labels = gender;

Не все файлы в наборе данных аннотируются информацией о возрасте и полом. Создайте подмножество datastore, который только содержит файлы, где информация о поле доступна, и возраст больше 19.

maleOrfemale =  categorical(adsVal.Labels) == "female" | categorical(adsVal.Labels) == "male";
isAdult = categorical(age) ~= "" & categorical(age) ~= "teens";
adsVal = subset(adsVal,maleOrfemale & isAdult);

Используйте countEachLabel смотреть гендерный отказ файлов.

countEachLabel(adsVal)
ans =

  2×2 table

    Label     Count
    ______    _____

    female      83 
    male       532 

Удалите тишину и извлеките функции из данных о валидации.

T = tall(adsVal);
segments = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);
segmentsPerFileTall = cellfun(@numel,segments);

featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segments,"UniformOutput",false);

[featureVectors,valSegmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall);

featureVectors = cat(2,featureVectors{:});

valSegmentLabels = repelem(adsVal.Labels,valSegmentsPerFile);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 45 sec
Evaluation completed in 46 sec

Нормируйте последовательность функции средними и стандартными отклонениями, вычисленными во время учебного этапа.

featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false);
for ii = 1:numel(featureVectors)
    idx = find(isnan(featureVectors{ii}));
    if ~isempty(idx)
        featureVectors{ii}(idx) = 0;
    end
end

Создайте массив ячеек, содержащий предикторы последовательности.

[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);

Создайте массив ячеек, gender, для ожидаемого пола, сопоставленного с каждой обучающей последовательностью.

genderValidation = repelem(valSegmentLabels,[valSequencePerSegment{:}]);

Задайте сетевую архитектуру LSTM

Сети LSTM могут изучить долгосрочные зависимости между временными шагами данных о последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer, чтобы посмотреть на последовательность и во вперед и в обратные направления.

Задайте входной размер, чтобы быть последовательностями размера NumFeatures. Задайте скрытый двунаправленный слой LSTM с выходным размером 50 и выведите последовательность. Затем задайте двунаправленный слой LSTM с выходным размером 50 и выведите последний элемент последовательности. Эта команда дает двунаправленному слою LSTM команду сопоставлять свой вход в 50 функций и затем готовит выход к полносвязному слою. Наконец, задайте два класса включением полносвязного слоя размера 2, сопровождаемый softmax слоем и слоем классификации.

layers = [ ...
    sequenceInputLayer(size(featuresTrain{1},1))
    bilstmLayer(50,"OutputMode","sequence")
    dropoutLayer(0.1)
    bilstmLayer(50,"OutputMode","last")
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];

Затем задайте опции обучения для классификатора. Установите MaxEpochs к 4 так, чтобы сеть сделала 4, проходит через обучающие данные. Установите MiniBatchSize из 128 так, чтобы сеть посмотрела на 128 учебных сигналов за один раз. Задайте Plots как "training-progress" сгенерировать графики, которые показывают процесс обучения количеством увеличений итераций. Установите Verbose к false отключить печать таблицы выход, который соответствует данным, показанным в графике. Задайте Shuffle как "every-epoch" переставить обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule к "piecewise" чтобы уменьшить темп обучения заданным фактором (0.1) каждый раз, определенное число эпох (2) передало.

Этот пример использует адаптивную оценку момента (ADAM) решатель. ADAM выполняет лучше с рекуррентными нейронными сетями (RNNs) как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решатель.

miniBatchSize = 128;
validationFrequency = floor(numel(genderTrain)/miniBatchSize);
options = trainingOptions("adam", ...
    "MaxEpochs",4, ...
    "MiniBatchSize",miniBatchSize, ...
    "Plots","training-progress", ...
    "Verbose",false, ...
    "Shuffle","every-epoch", ...
    "LearnRateSchedule","piecewise", ...
    "LearnRateDropFactor",0.1, ...
    "LearnRateDropPeriod",2,...
    'ValidationData',{featuresValidation,categorical(genderValidation)}, ...
    'ValidationFrequency',validationFrequency);

Обучите сеть LSTM

Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork. Поскольку набор обучающих данных является большим, учебный процесс может занять несколько минут.

net = trainNetwork(featuresTrain,categorical(genderTrain),layers,options);

Главный подграфик графика процесса обучения представляет учебную точность, которая является точностью классификации на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно увеличивается к 100%. Нижний подграфик отображает учебную потерю, которая является перекрестной энтропийной потерей на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно уменьшается по направлению к нулю.

Если обучение не сходится, графики могут колебаться между значениями, не отклоняясь в определенном восходящем или нисходящем направлении. Это колебание означает, что учебная точность не улучшается, и учебная потеря не уменьшается. Эта ситуация может произойти в начале обучения, или после некоторого предварительного улучшения учебной точности. Во многих случаях изменение опций обучения может помочь сети достигнуть сходимости. Уменьшение MiniBatchSize или уменьшение InitialLearnRate может закончиться в более длительное учебное время, но оно может помочь сети учиться лучше.

Визуализируйте учебную точность

Вычислите учебную точность, которая представляет точность классификатора на сигналах, на которых это было обучено. Во-первых, классифицируйте обучающие данные.

trainPred = classify(net,featuresTrain);

Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.

figure
cm = confusionchart(categorical(genderTrain),trainPred,'title','Training Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Визуализируйте точность валидации

Вычислите точность валидации. Во-первых, классифицируйте обучающие данные.

[valPred,valScores] = classify(net,featuresValidation);

Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.

figure
cm = confusionchart(categorical(genderValidation),valPred,'title','Validation Set Accuracy');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Пример сгенерировал несколько последовательностей из каждого учебного речевого файла. Более высокая точность может быть достигнута путем рассмотрения выходного класса всех последовательностей, соответствующих тому же файлу и применяющих решение "макс. правила", где класс с сегментом с самым высоким счетом уверенности выбран.

Определите количество последовательностей, сгенерированных на файл в наборе валидации.

sequencePerFile = zeros(size(valSegmentsPerFile));
valSequencePerSegmentMat = cell2mat(valSequencePerSegment);
idx = 1;
for ii = 1:numel(valSegmentsPerFile)
    sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1));
    idx = idx + valSegmentsPerFile(ii);
end

Предскажите пол из каждого учебного файла путем считания выходных классов всех последовательностей сгенерированными из того же файла.

numFiles = numel(adsVal.Files);
actualGender = categorical(adsVal.Labels);
predictedGender = actualGender;
scores = cell(1,numFiles);
counter = 1;
cats = unique(actualGender);
for index = 1:numFiles
    scores{index}      = valScores(counter: counter + sequencePerFile(index) - 1,:);
    m = max(mean(scores{index},1),[],1);
    if m(1) >= m(2)
        predictedGender(index) = cats(1);
    else
        predictedGender(index) = cats(2);
    end
    counter = counter + sequencePerFile(index);
end

Визуализируйте матрицу беспорядка на прогнозах принципа большинства.

figure
cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule');
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] https://voice.mozilla.org /

[2] Введение в аудио анализ: подход MATLAB, Джиэннэкопулос и Пикракис, Academic Press.