В этом примере показано, как классифицировать пол докладчика, использующего глубокое обучение. В частности, пример использует сеть Bidirectional Long Short-Term Memory (BiLSTM) и Коэффициенты Gammatone Cepstral (gtcc), подачу, гармоническое отношение и несколько спектральных дескрипторов формы.
Классификация полов на основе речевых сигналов является важной составляющей многих аудиосистем, таких как автоматическое распознавание речи, распознавание динамика и мультимедийная индексация на основе содержимого.
Этот пример использует сети долгой краткосрочной памяти (LSTM), тип рекуррентной нейронной сети (RNN), подходящей, чтобы изучить данные timeseries и последовательность. Сеть LSTM может изучить долгосрочные зависимости между временными шагами последовательности. Слой LSTM (lstmLayer
) может посмотреть в то время последовательность в прямом направлении, в то время как двунаправленный слой LSTM (bilstmLayer
) может посмотреть в то время последовательность и во вперед и в обратные направления. Этот пример использует двунаправленный слой LSTM.
Этот пример обучает сеть LSTM с последовательностями Коэффициентов Кепстра Gammatone (gtcc), оценок подачи, гармонического отношения и нескольких спектральных дескрипторов формы.
Пример проходит следующие шаги:
Создайте audioDatastore
это указывает на аудио речевые файлы, используемые, чтобы обучить сеть LSTM.
Удалите тишину и неречевые сегменты из речевых файлов с помощью простого метода пороговой обработки.
Извлеките последовательности функции, состоящие из коэффициентов GTCC, подачи, гармонического отношения и нескольких спектральных дескрипторов формы от речевых сигналов.
Обучите сеть LSTM с помощью последовательностей функции.
Измерьте и визуализируйте точность классификатора на обучающих данных.
Создайте audioDatastore
из речевых файлов, используемых, чтобы протестировать обучивший сеть.
Удалите неречевые сегменты из этих файлов, сгенерируйте последовательности функции, передайте их через сеть и протестируйте ее точность путем сравнения предсказанного и фактического пола докладчиков.
Чтобы ускорить учебный процесс, запустите этот пример на машине с помощью графического процессора. Если ваша машина имеет графический процессор и Parallel Computing Toolbox™, то MATLAB® автоматически использует графический процессор в обучении; в противном случае это использует центральный процессор.
Этот пример использует набор данных Mozilla Common Voice [1]. Набор данных содержит записи на 48 кГц предметов говорящие короткие предложения. Загрузите набор данных и untar загруженный файл. Установите datafolder
к местоположению данных.
datafolder = PathToDatabase;
Используйте audioDatastore
создать datastore для всех файлов в наборе данных.
ads0 = audioDatastore(fullfile(datafolder,"clips"));
Только начиная с части файлов набора данных аннотируются информацией о поле, вы будете использовать и наборы обучения и валидации, чтобы обучить сеть. Вы будете использовать набор тестов, чтобы подтвердить сеть. Используйте readtable
считать метаданные, сопоставленные со звуковыми файлами от набора dev и обучения. Метаданные содержатся в train.tsv
файл. Смотрите первые несколько строк метаданных.
metadataTrain = readtable(fullfile(datafolder,"train.tsv"),"FileType","text"); metadataDev = readtable(fullfile(datafolder,"dev.tsv"),"FileType","text"); metadata = [metadataTrain;metadataDev]; head(metadata)
ans = 8×8 table client_id path sentence up_votes down_votes age gender accent ____________________________________________________________________________________________________________________________________ ____________________________________________________________________________________________________________________________________ ______________________________________________________________________________________________ ________ __________ ____________ __________ __________ {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'} {'f480b8a93bf84b7f74c141284a71c39ff47d264a75dc905dc918286fb67f0333595206ff953a27b8049c7ec09ea895aa66d1cd4f7547535167d3d7901d12feab'} {'Unfortunately, nobody can warrant the sanctions that will have an effect on the community.'} 3 0 {'twenties'} {'female'} {'canada'} {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'} {'7647873ce81cd81c90b9e0fe3cb6c85cc03df7c0c4fdf2a04c356d75063af4b9de296a24e3bef0ba7ef0b0105d166abf35597e9c9a4b3857fd09c57b79f65a99'} {'Came down and picked it out himself.' } 2 0 {'twenties'} {'female'} {'canada'} {'55451a804635a88160a09b9b8122e3dddba46c2e6df2d6d9ec9d3445c38180fd18516d76acc9035978f27ee1f798f480dcb55dcbd31a142374c3af566c9be3c4'} {'81a3dd920de6251cc878a940aff258e859ef13efb9a6446610ab907e08832fafdc463eda334ee74a24cc02e3652a09f5573c133e6f46886cb0ba463efc7a6b43'} {'She crossed the finish line just in time.' } 2 0 {'twenties'} {'female'} {'canada'} {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'} {'5e6fc96a7bc91ec2261a51e7713bb0ed8a9f4fa9e20a38060dc4544fb0c2600c192d6e849915acaf8ea0766a9e1d481557d674363e780dbb064586352e560f2c'} {'Please find me the Home at Last trailer.' } 2 0 {0×0 char } {0×0 char} {0×0 char} {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'} {'3a0929094a9aac80b961d479a3ee54311cc0d60d51fe8f97071edc2999e7747444b261d2c0c2345f86fb8161f3b73a14dc19da911d19ca8d9db39574c6199a34'} {'Play something by Louisiana Blues' } 2 0 {0×0 char } {0×0 char} {0×0 char} {'5b8c0f566c1201a94e684a334cf8a2cbced8a009a5a346fc24f1d51446c6b8610fc7bd78f69e559b29d138ab92652a45408ef87c3ec0e426d2fc5f1b2b44935b'} {'82b8edf3f1420295069b5bb4543b8f349faaca28f45a3279b0cd64c39d16afb590a4cc70ed805020161f8c1f94bc63d3b69756fbc5a0462ce12d2e17c4ebaeeb'} {'When is The Devil with Hitler playing in Bow Tie Cinemas' } 2 0 {0×0 char } {0×0 char} {0×0 char} {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'} {'8606bac841a08bcbf5ddb83c768103c467ffd1bf38b16052414210dc3ce3267561cb0368d227b6eb420dc147387cc1807032102b6248a13a40f83e5ac06d7122'} {'Give me the list of animated movies playing at the closest movie house' } 3 0 {0×0 char } {0×0 char} {0×0 char} {'60013a707ac8cdd2b44427418064915f7810b2d58d52d8f81ad3c6406b8922d61c134259747f3c73d2e64c885fc6141761d29f7e6ada7d6007c48577123e4af0'} {'0f7e63d320cfbf6ea5d1cda674007131d804a06f5866d38f81da7def33a4ce8ee4f2cb7e47b45eee97903cad3160b3a10f715862227e8ecdc3fb3bafc6b4279d'} {'The better part of valor is discretion' } 3 0 {0×0 char } {0×0 char} {0×0 char}
Найдите файлы в datastore, соответствующем набору обучающих данных.
csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);
Создайте набор обучающих данных подмножества из большого набора данных.
adsTrain = subset(ads0,indA);
Вы будете использовать данные, соответствующие взрослым динамикам только. Считайте пол и переменные возраста из метаданных.
gender = metadata.gender; gender = gender(indB); age = metadata.age; age = age(indB);
Присвойте пол Labels
свойство datastore.
adsTrain.Labels = gender;
Не все файлы в наборе данных аннотируются информацией о возрасте и полом. Создайте подмножество datastore, который только содержит файлы, где информация о поле доступна, и возраст больше 19.
maleOrfemale = categorical(adsTrain.Labels) == "male" | categorical(adsTrain.Labels) == "female"; isAdult = categorical(age) ~= "" & categorical(age) ~= "teens"; adsTrain = subset(adsTrain,maleOrfemale & isAdult);
Вы обучите нейронную сеть для глубокого обучения на подмножестве файлов. Создайте подмножество datastore, содержащее равное количество спикеров и женщин-спикеров.
ismale = find(categorical(adsTrain.Labels) == "male"); isfemale = find(categorical(adsTrain.Labels) == "female"); numFilesPerGender = numel(isfemale); adsTrain = subset(adsTrain,[ismale(1:numFilesPerGender) isfemale(1:numFilesPerGender)]);
Используйте shuffle
рандомизировать порядок файлов в datastore.
adsTrain = shuffle(adsTrain);
Используйте countEachLabel
смотреть гендерный отказ набора обучающих данных.
countEachLabel(adsTrain)
ans = 2×2 table Label Count ______ _____ female 925 male 925
Считайте содержимое звукового файла с помощью read
.
[audio,adsInfo] = read(adsTrain); Fs = adsInfo.SampleRate;
Постройте звуковой сигнал и затем слушайте его с помощью sound
команда.
timeVector = (1/Fs) * (0:numel(audio)-1); figure plot(timeVector,audio) ylabel("Amplitude") xlabel("Time (s)") title("Sample Audio") grid on sound(audio,Fs)
Речевой сигнал имеет сегменты тишины, которые не содержат полезную информацию, имеющую отношение к полу докладчика. Этот пример удаляет тишину с помощью упрощенной версии подхода пороговой обработки, описанного в [2]. Шаги алгоритма удаления тишины обрисованы в общих чертах ниже.
Во-первых, вычислите две функции по неперекрывающимся системам координат аудиоданных: энергия сигнала и спектральный центроид. Спектральный центроид является мерой "центра тяжести" спектра сигнала.
Повредите аудио в неперекрывающиеся системы координат с 50 миллисекундами.
audio = audio ./ max(abs(audio)); % Normalize amplitude
windowLength = 50e-3 * Fs;
segments = buffer(audio,windowLength);
Вычислите энергию и спектральный центроид для каждой системы координат.
win = hann(windowLength,'periodic'); signalEnergy = sum(segments.^2,1)/windowLength; centroid = spectralCentroid(segments,Fs,'Window',win,'OverlapLength',0);
Затем установите пороги для каждой функции. Игнорируются области, где значения функции падают ниже или выше их соответствующих порогов. В этом примере энергетический порог устанавливается к половине средней энергии, и спектральный центроидный порог устанавливается к 5 000 Гц.
T_E = mean(signalEnergy)/2; T_C = 5000; isSpeechRegion = (signalEnergy>=T_E) & (centroid<=T_C);
Визуализируйте вычисленную энергию и спектральный центроид в зависимости от времени.
% Hold the signal energy, spectral centroid, and speech decision values for % plotting purposes. CC = repmat(centroid,windowLength,1); CC = CC(:); EE = repmat(signalEnergy,windowLength,1); EE = EE(:); flags2 = repmat(isSpeechRegion,windowLength,1); flags2 = flags2(:); figure subplot(3,1,1) plot(timeVector, CC(1:numel(audio)), ... timeVector, repmat(T_C,1,numel(timeVector)), "LineWidth",2) xlabel("Time (s)") ylabel("Normalized Centroid") legend("Centroid","Threshold") title("Spectral Centroid") grid on subplot(3,1,2) plot(timeVector, EE(1:numel(audio)), ... timeVector, repmat(T_E,1,numel(timeVector)),"LineWidth",2) ylabel("Normalized Energy") legend("Energy","Threshold") title("Window Energy") grid on subplot(3,1,3) plot(timeVector, audio, ... timeVector,flags2(1:numel(audio)),"LineWidth",2) ylabel("Audio") legend("Audio","Speech Region") title("Audio") grid on ylim([-1 1.1])
Извлеките сегменты речи от аудио. Примите, что речь присутствует для выборок, где энергия выше ее порога, и спектральный центроид ниже ее порога.
% Get indices of frames where a speech-to-silence or silence-to-speech % transition occurs. regionStartPos = find(diff([isSpeechRegion(1)-1, isSpeechRegion])); % Get the length of the all-silence or all-speech regions. RegionLengths = diff([regionStartPos, numel(isSpeechRegion)+1]); % Get speech-only regions. isSpeechRegion = isSpeechRegion(regionStartPos) == 1; regionStartPos = regionStartPos(isSpeechRegion); RegionLengths = RegionLengths(isSpeechRegion); % Get start and end indices for each speech region. Extend the region by 5 % windows on each side. startIndices = zeros(1,numel(RegionLengths)); endIndices = zeros(1,numel(RegionLengths)); for index = 1:numel(RegionLengths) startIndices(index) = max(1,(regionStartPos(index) - 5) * windowLength + 1); endIndices(index) = min(numel(audio),(regionStartPos(index) + RegionLengths(index) + 5) * windowLength); end
Наконец, объедините пересекающиеся речевые сегменты.
activeSegment = 1; isSegmentsActive = zeros(1,numel(startIndices)); isSegmentsActive(1) = 1; for index = 2:numel(startIndices) if startIndices(index) <= endIndices(activeSegment) % Current segment intersects with previous segment if endIndices(index) > endIndices(activeSegment) endIndices(activeSegment) = endIndices(index); end else % New speech segment detected activeSegment = index; isSegmentsActive(index) = 1; end end numSegments = sum(isSegmentsActive); segments = cell(1,numSegments); limits = zeros(2,numSegments); speechSegmentsIndices = find(isSegmentsActive); for index = 1:length(speechSegmentsIndices) segments{index} = audio(startIndices(speechSegmentsIndices(index)): ... endIndices(speechSegmentsIndices(index))); limits(:,index) = [startIndices(speechSegmentsIndices(index)); ... endIndices(speechSegmentsIndices(index))]; end
Постройте исходное аудио наряду с обнаруженными речевыми сегментами.
figure plot(timeVector,audio) hold on myLegend = cell(1,numel(segments) + 1); myLegend{1} = "Original Audio"; for index = 1:numel(segments) plot(timeVector(limits(1,index):limits(2,index)),segments{index}); myLegend{index+1} = sprintf("Output Audio Segment %d",index); end xlabel("Time (s)") ylabel("Audio") grid on legend(myLegend)
Речевой сигнал является динамическим по своей природе и изменяется в зависимости от времени. Это принято, что речевые сигналы стационарные по кратковременным шкалам, и их обработка часто делается в окнах 20-40 мс. Для каждого речевого сегмента этот пример извлекает функции аудио для 30 MS Windows с 75%-м перекрытием.
win = hamming(0.03*Fs,"periodic"); overlapLength = round(0.75*numel(win)); featureParams = struct("SampleRate",Fs, ... "Window",win, ... "OverlapLength",overlapLength); extractor = audioFeatureExtractor('Window',win, ... 'OverlapLength',overlapLength, ... 'SampleRate',Fs, ... 'SpectralDescriptorInput','melSpectrum', ... ... 'gtcc',true, ... 'gtccDelta',true, ... 'gtccDeltaDelta',true, ... 'spectralSlope',true, ... 'spectralFlux',true, ... 'spectralCentroid',true, ... 'spectralEntropy',true, ... 'pitch',true, ... 'harmonicRatio',true);
Фигура предоставляет обзор извлечения признаков, используемого в этом примере.
Чтобы ускорить обработку, извлеките последовательности функции из речевых сегментов всех звуковых файлов в datastore с помощью tall
массивы. В отличие от массивов в оперативной памяти, длинные массивы обычно остаются неоцененными, пока вы не запрашиваете, чтобы вычисления были выполнены с помощью gather
функция. Эта отсроченная оценка позволяет вам работать быстро с большими наборами данных. Когда вы в конечном счете запрашиваете выход с помощью gather
, MATLAB® комбинирует вычисления в очереди, где возможный и берет минимальное количество проходов через данные. Если у вас есть Parallel Computing Toolbox™, можно использовать длинные массивы на локальном сеансе MATLAB®, или на локальном параллельном пуле. Можно также выполнить вычисления длинного массива на кластере, если вам установили MATLAB® Parallel Server™.
Во-первых, преобразуйте datastore в длинный массив:
T = tall(adsTrain)
T = M×1 tall cell array {280560×1 double} {156144×1 double} {167664×1 double} {190704×1 double} {401520×1 double} {120432×1 double} {354288×1 double} {318576×1 double} : : : :
Отображение указывает, что количество строк (соответствующий количеству файлов в datastore), M, еще не известно. M является заполнителем, пока вычисление не завершается.
Извлеките речевые сегменты из длинной таблицы. Это действие создает новую переменную длинного массива, чтобы использовать в последующих вычислениях. Функциональный HelperSegmentSpeech
выполняет шаги, уже подсвеченные в Isolate Speech Segments
раздел. cellfun
команда применяет HelperSegmentSpeech
к содержимому каждого звукового файла в datastore. Также определите количество сегментов на файл.
segmentsTall = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false);
segmentsPerFileTall = cellfun(@numel,segmentsTall);
Извлеките последовательности функции из речевых сегментов с помощью HelperGetFeatureVectors
. Функция помощника применяет извлечение признаков на сегменты с помощью audioFeatureExtractor
и затем переориентирует функции так, чтобы время приехало строки, чтобы быть совместимым с sequenceInputLayer
.
featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segmentsTall,"UniformOutput",false);
Используйте gather
оценивать featureVectorsTall
и segmentsPerFileTall
. featureVectors
возвращен как массив ячеек NumFiles-1, где каждый элемент массива ячеек является 1 массивом ячеек NumSegmentsPerFile. Не вложите массив ячеек.
[featureVectors,segmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall); featureVectors = cat(2,featureVectors{:});
Evaluating tall expression using the Parallel Pool 'local': - Pass 1 of 1: Completed in 2 min 15 sec Evaluation completed in 2 min 19 sec
Реплицируйте метки, таким образом, существует одна метка на сегмент.
myLabels = adsTrain.Labels; myLabels = repelem(myLabels,segmentsPerFile);
В приложениях классификации это - хорошая практика, чтобы нормировать все функции, чтобы иметь нулевое среднее значение и стандартное отклонение единицы.
Вычислите среднее и стандартное отклонение для каждого коэффициента и используйте их, чтобы нормировать данные.
allFeatures = cat(2,featureVectors{:}); allFeatures(isinf(allFeatures)) = nan; M = mean(allFeatures,2,'omitnan'); S = std(allFeatures,0,2,'omitnan'); featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false); for ii = 1:numel(featureVectors) idx = find(isnan(featureVectors{ii})); if ~isempty(idx) featureVectors{ii}(idx) = 0; end end
Пользователь HelperFeatureVector2Sequence
буферизовать характеристические векторы в последовательности 20 характеристических векторов с 10 перекрытиями.
featureVectorsPerSequence = 20; featureVectorOverlap = 10; [featuresTrain,sequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
Создайте массив ячеек, genderTrain
, для ожидаемого пола, сопоставленного с каждой обучающей последовательностью.
genderTrain = repelem(myLabels,[sequencePerSegment{:}]);
Вы создадите набор данных валидации с помощью того же подхода, который вы использовали в обучающем наборе данных. Используйте readtable
функционируйте, чтобы считать метаданные, сопоставленные с файлами валидации.
metadata = readtable(fullfile(datafolder,"test.tsv"),"FileType","text");
Найдите файлы валидации в datastore.
csvFiles = metadata.path;
adsFiles = ads0.Files;
adsFiles = cellfun(@HelperGetFilePart,adsFiles,'UniformOutput',false);
[~,indA,indB] = intersect(adsFiles,csvFiles);
Создайте datastore валидации из большого datastore.
adsVal = subset(ads0,indA);
Подобно набору обучающих данных вы будете использовать данные, соответствующие взрослым динамикам только. Считайте пол и переменные возраста из метаданных.
gender = metadata.gender; gender = gender(indB); age = metadata.age; age = age(indB);
Присвойте пол Labels
свойство datastore.
adsVal.Labels = gender;
Не все файлы в наборе данных аннотируются информацией о возрасте и полом. Создайте подмножество datastore, который только содержит файлы, где информация о поле доступна, и возраст больше 19.
maleOrfemale = categorical(adsVal.Labels) == "female" | categorical(adsVal.Labels) == "male"; isAdult = categorical(age) ~= "" & categorical(age) ~= "teens"; adsVal = subset(adsVal,maleOrfemale & isAdult);
Используйте countEachLabel
смотреть гендерный отказ файлов.
countEachLabel(adsVal)
ans = 2×2 table Label Count ______ _____ female 83 male 532
Удалите тишину и извлеките функции из данных о валидации.
T = tall(adsVal); segments = cellfun(@(x)HelperSegmentSpeech(x,Fs),T,"UniformOutput",false); segmentsPerFileTall = cellfun(@numel,segments); featureVectorsTall = cellfun(@(x)HelperGetFeatureVectors(x,extractor),segments,"UniformOutput",false); [featureVectors,valSegmentsPerFile] = gather(featureVectorsTall,segmentsPerFileTall); featureVectors = cat(2,featureVectors{:}); valSegmentLabels = repelem(adsVal.Labels,valSegmentsPerFile);
Evaluating tall expression using the Parallel Pool 'local': - Pass 1 of 1: Completed in 45 sec Evaluation completed in 46 sec
Нормируйте последовательность функции средними и стандартными отклонениями, вычисленными во время учебного этапа.
featureVectors = cellfun(@(x)(x-M)./S,featureVectors,'UniformOutput',false); for ii = 1:numel(featureVectors) idx = find(isnan(featureVectors{ii})); if ~isempty(idx) featureVectors{ii}(idx) = 0; end end
Создайте массив ячеек, содержащий предикторы последовательности.
[featuresValidation,valSequencePerSegment] = HelperFeatureVector2Sequence(featureVectors,featureVectorsPerSequence,featureVectorOverlap);
Создайте массив ячеек, gender
, для ожидаемого пола, сопоставленного с каждой обучающей последовательностью.
genderValidation = repelem(valSegmentLabels,[valSequencePerSegment{:}]);
Сети LSTM могут изучить долгосрочные зависимости между временными шагами данных о последовательности. Этот пример использует двунаправленный слой LSTM bilstmLayer, чтобы посмотреть на последовательность и во вперед и в обратные направления.
Задайте входной размер, чтобы быть последовательностями размера NumFeatures
. Задайте скрытый двунаправленный слой LSTM с выходным размером 50 и выведите последовательность. Затем задайте двунаправленный слой LSTM с выходным размером 50 и выведите последний элемент последовательности. Эта команда дает двунаправленному слою LSTM команду сопоставлять свой вход в 50 функций и затем готовит выход к полносвязному слою. Наконец, задайте два класса включением полносвязного слоя размера 2, сопровождаемый softmax слоем и слоем классификации.
layers = [ ... sequenceInputLayer(size(featuresTrain{1},1)) bilstmLayer(50,"OutputMode","sequence") dropoutLayer(0.1) bilstmLayer(50,"OutputMode","last") fullyConnectedLayer(2) softmaxLayer classificationLayer];
Затем задайте опции обучения для классификатора. Установите MaxEpochs
к 4
так, чтобы сеть сделала 4, проходит через обучающие данные. Установите MiniBatchSize
из 128
так, чтобы сеть посмотрела на 128 учебных сигналов за один раз. Задайте Plots
как "training-progress"
сгенерировать графики, которые показывают процесс обучения количеством увеличений итераций. Установите Verbose
к false
отключить печать таблицы выход, который соответствует данным, показанным в графике. Задайте Shuffle
как "every-epoch"
переставить обучающую последовательность в начале каждой эпохи. Задайте LearnRateSchedule
к "piecewise"
чтобы уменьшить темп обучения заданным фактором (0.1) каждый раз, определенное число эпох (2) передало.
Этот пример использует адаптивную оценку момента (ADAM) решатель. ADAM выполняет лучше с рекуррентными нейронными сетями (RNNs) как LSTMs, чем стохастический градиентный спуск по умолчанию с импульсом (SGDM) решатель.
miniBatchSize = 128; validationFrequency = floor(numel(genderTrain)/miniBatchSize); options = trainingOptions("adam", ... "MaxEpochs",4, ... "MiniBatchSize",miniBatchSize, ... "Plots","training-progress", ... "Verbose",false, ... "Shuffle","every-epoch", ... "LearnRateSchedule","piecewise", ... "LearnRateDropFactor",0.1, ... "LearnRateDropPeriod",2,... 'ValidationData',{featuresValidation,categorical(genderValidation)}, ... 'ValidationFrequency',validationFrequency);
Обучите сеть LSTM с заданными опциями обучения и архитектурой слоя с помощью trainNetwork
. Поскольку набор обучающих данных является большим, учебный процесс может занять несколько минут.
net = trainNetwork(featuresTrain,categorical(genderTrain),layers,options);
Главный подграфик графика процесса обучения представляет учебную точность, которая является точностью классификации на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно увеличивается к 100%. Нижний подграфик отображает учебную потерю, которая является перекрестной энтропийной потерей на каждом мини-пакете. Когда обучение прогрессирует успешно, это значение обычно уменьшается по направлению к нулю.
Если обучение не сходится, графики могут колебаться между значениями, не отклоняясь в определенном восходящем или нисходящем направлении. Это колебание означает, что учебная точность не улучшается, и учебная потеря не уменьшается. Эта ситуация может произойти в начале обучения, или после некоторого предварительного улучшения учебной точности. Во многих случаях изменение опций обучения может помочь сети достигнуть сходимости. Уменьшение MiniBatchSize
или уменьшение InitialLearnRate
может закончиться в более длительное учебное время, но оно может помочь сети учиться лучше.
Вычислите учебную точность, которая представляет точность классификатора на сигналах, на которых это было обучено. Во-первых, классифицируйте обучающие данные.
trainPred = classify(net,featuresTrain);
Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.
figure cm = confusionchart(categorical(genderTrain),trainPred,'title','Training Accuracy'); cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
Вычислите точность валидации. Во-первых, классифицируйте обучающие данные.
[valPred,valScores] = classify(net,featuresValidation);
Постройте матрицу беспорядка. Отобразите точность и отзыв для этих двух классов при помощи сводных данных строки и столбца.
figure cm = confusionchart(categorical(genderValidation),valPred,'title','Validation Set Accuracy'); cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
Пример сгенерировал несколько последовательностей из каждого учебного речевого файла. Более высокая точность может быть достигнута путем рассмотрения выходного класса всех последовательностей, соответствующих тому же файлу и применяющих решение "макс. правила", где класс с сегментом с самым высоким счетом уверенности выбран.
Определите количество последовательностей, сгенерированных на файл в наборе валидации.
sequencePerFile = zeros(size(valSegmentsPerFile)); valSequencePerSegmentMat = cell2mat(valSequencePerSegment); idx = 1; for ii = 1:numel(valSegmentsPerFile) sequencePerFile(ii) = sum(valSequencePerSegmentMat(idx:idx+valSegmentsPerFile(ii)-1)); idx = idx + valSegmentsPerFile(ii); end
Предскажите пол из каждого учебного файла путем считания выходных классов всех последовательностей сгенерированными из того же файла.
numFiles = numel(adsVal.Files); actualGender = categorical(adsVal.Labels); predictedGender = actualGender; scores = cell(1,numFiles); counter = 1; cats = unique(actualGender); for index = 1:numFiles scores{index} = valScores(counter: counter + sequencePerFile(index) - 1,:); m = max(mean(scores{index},1),[],1); if m(1) >= m(2) predictedGender(index) = cats(1); else predictedGender(index) = cats(2); end counter = counter + sequencePerFile(index); end
Визуализируйте матрицу беспорядка на прогнозах принципа большинства.
figure cm = confusionchart(actualGender,predictedGender,'title','Validation Set Accuracy - Max Rule'); cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
[1] https://voice.mozilla.org /
[2] Введение в аудио анализ: подход MATLAB, Джиэннэкопулос и Пикракис, Academic Press.