Этот пример показывает, что типичный рабочий процесс для выбора признаков применился к задаче речевого распознавания эмоции. Вы начинаете путем создания базовой точности, использующей общие функции аудио (MFCC). Вы затем увеличиваете свой набор данных, чтобы уменьшить сверхподбор кривой. Наконец, вы выполняете последовательный выбор признаков, чтобы выбрать лучший набор функций.
В последовательном выборе признаков вы обучаете сеть на данном наборе функций и затем инкрементно добавляете или удаляете функции, пока самая высокая точность не достигнута [1]. В этом примере вы применяете последовательный прямой выбор к задаче речевого распознавания эмоции с помощью Берлинской Базы данных Эмоциональной Речи [2].
Берлинская База данных Эмоциональной Речи содержит 535 произнесения, на котором говорят 10 агентов, предназначенных, чтобы передать одну из следующих эмоций: гнев, скука, отвращение, беспокойство/страх, счастье, печаль, или нейтральный. Эмоции являются независимым текстом. Загрузите базу данных с http://emodb.bilderbar.info/index-1280.html и затем установите PathToDatabase
к местоположению звуковых файлов. Создайте audioDatastore
это указывает на звуковые файлы.
datafolder = PathToDatabase;
ads = audioDatastore(fullfile(datafolder,"wav"));
Имена файлов являются кодами, указывающими на ID динамика, текст, на котором говорят, эмоция и версия. Веб-сайт содержит ключ для интерпретации кода и дополнительной информации о динамиках, таких как пол и возраст. Составьте таблицу с переменными Speaker
и Emotion
. Декодируйте имена файлов в таблицу.
filepaths = ads.Files; emotionCodes = cellfun(@(x)x(end-5),filepaths,'UniformOutput',false); emotions = replace(emotionCodes,{'W','L','E','A','F','T','N'}, ... {'Anger','Boredom','Disgust','Anxiety/Fear','Happiness','Sadness','Neutral'}); speakerCodes = cellfun(@(x)x(end-10:end-9),filepaths,'UniformOutput',false); labelTable = cell2table([speakerCodes,emotions],'VariableNames',{'Speaker','Emotion'}); labelTable.Emotion = categorical(labelTable.Emotion); labelTable.Speaker = categorical(labelTable.Speaker); head(labelTable)
ans = 8×2 table Speaker Emotion _______ _________ 03 Happiness 03 Neutral 03 Anger 03 Happiness 03 Neutral 03 Sadness 03 Anger 03 Anger
labelTable
находится в том же порядке как файлы в audioDatastore
. Установите Labels
свойство audioDatastore
к labelTable
.
ads.Labels = labelTable;
Можно теперь разделить меткой и подмножеством, чтобы изолировать фрагменты данных. Подмножество datastore, который содержит динамик 12
передача скуки. Слушайте файл и просмотрите форму волны временного интервала. Отобразите полную метку, соответствующую произнесению.
speaker = categorical("12"); emotion = categorical("Boredom"); adsSubset = subset(ads,ads.Labels.Speaker==speaker & ads.Labels.Emotion == emotion); [audio,adsInfo] = read(adsSubset); fs = adsInfo.SampleRate; sound(audio,fs) t = (0:size(audio,1)-1)/fs; figure plot(t,audio) grid on xlabel('Time (s)') ylabel('Amplitude')
Чтобы обеспечить точную оценку модели, вы создаете в этом примере, обучаете и подтверждаете перекрестную проверку k-сгиба отпуска один динамика (LOSO) использования. В этом методе вы обучаете использование k-1 динамики и затем подтверждаете на не учтенном динамике. Вы повторяете эту процедуру k времена. Итоговая точность валидации является средним значением сгибов k.
Создайте переменную, которая содержит идентификаторы динамика. Определите количество сгибов: 1 для каждого динамика. База данных содержит произнесение от 10 уникальных докладчиков. Используйте summary
чтобы отобразить идентификаторы динамика (левый столбец) и количество произнесения, они способствуют базе данных (правый столбец).
speaker = ads.Labels.Speaker; numFolds = numel(speaker); summary(speaker)
03 49 08 58 09 43 10 38 11 55 12 35 13 61 14 69 15 56 16 71
Как первый шаг к разработке модели машинного обучения, определите базовую точность.
Примите, что 10-кратная точность перекрестной проверки первой попытки обучения составляет приблизительно 60% из-за недостаточных обучающих данных, и что модель, обученная на недостаточных данных, сверхсоответствует некоторым сгибам и underfits другим. Чтобы улучшить полную подгонку, увеличьте размер набора данных к 50 разам с помощью audioDataAugmenter
.
Создайте audioDataAugmenter
объект. Установите вероятность применения перехода подачи к 0.5
и используйте область значений по умолчанию. Установите вероятность применения смещения во времени к 1
и используйте область значений [-0.3,0.3]
секунды. Установите вероятность добавления шума к 1
и укажите диапазон ОСШ как [-20,40]
дБ.
augmenter = audioDataAugmenter('NumAugmentations',50, ... 'TimeStretchProbability',0, ... 'VolumeControlProbability',0, ... ... 'PitchShiftProbability',0.5, ... ... 'TimeShiftProbability',1, ... 'TimeShiftRange',[-0.3,0.3], ... ... 'AddNoiseProbability',1, ... 'SNRRange', [-20,40]);
Создайте новую папку в своей текущей папке, чтобы содержать увеличенный набор данных.
currentDir = pwd;
writeDirectory = [currentDir,'\augmentedData'];
mkdir(writeDirectory)
Для каждого файла в аудио datastore:
Создайте 50 увеличений.
Нормируйте аудио, чтобы иметь макс. абсолютное значение 1
.
Запишите увеличенные аудиоданные как файл WAV. Добавьте _augK
к каждым из имен файлов, где K является номером увеличения. Чтобы ускорить обработку, используйте parfor
и раздел datastore.
reset(ads) numPartitions = 6; for ii = 1:numPartitions adsPart = partition(ads,numPartitions,ii); while hasdata(adsPart) [x,adsInfo] = read(adsPart); data = augment(augmenter,x,fs); [~,fn] = fileparts(adsInfo.FileName); for i = 1:size(data,1) augmentedAudio = data.Audio{i}; augmentedAudio = augmentedAudio/max(abs(augmentedAudio),[],'all'); augNum = num2str(i); if numel(augNum)==1 iString = ['0',augNum]; else iString = augNum; end audiowrite([writeDirectory,'\',sprintf('%s_aug%s.wav',fn,iString)],augmentedAudio,fs); end end end
Создайте аудио datastore, который указывает на увеличенный набор данных. Реплицируйте строки таблицы метки исходного datastore NumAugmentations
времена, чтобы определить метки увеличенного datastore.
augads = audioDatastore(writeDirectory); augads.Labels = repelem(ads.Labels,augmenter.NumAugmentations,1);
Mel-частота cepstral коэффициенты (MFCC), дельта-MFCC и дельта дельты MFCC являются популярными функциями аудио. Как базовая линия, используйте MFCC, дельту-MFCC и дельту дельты MFCC с 30 MS Windows и никаким перекрытием.
Создайте audioFeatureExtractor
объект. Установите Window
к периодическому Окну Хэмминга на 30 мс, OverlapLength
к 0
, и SampleRate
к частоте дискретизации базы данных. Установите mfcc
, mfccDelta
, и mfccDeltaDelta
к true
извлекать их.
win = hamming(round(0.03*fs),"periodic"); overlapLength = 0; extractor = audioFeatureExtractor( ... 'Window',win, ... 'OverlapLength',overlapLength, ... 'SampleRate',fs, ... ... 'mfcc',true, ... 'mfccDelta',true, ... 'mfccDeltaDelta',true)
extractor = audioFeatureExtractor with properties: Properties Window: [480×1 double] OverlapLength: 0 SampleRate: 16000 FFTLength: [] SpectralDescriptorInput: 'linearSpectrum' Enabled Features mfcc, mfccDelta, mfccDeltaDelta Disabled Features linearSpectrum, melSpectrum, barkSpectrum, erbSpectrum, gtcc, gtccDelta gtccDeltaDelta, spectralCentroid, spectralCrest, spectralDecrease, spectralEntropy, spectralFlatness spectralFlux, spectralKurtosis, spectralRolloffPoint, spectralSkewness, spectralSlope, spectralSpread pitch, harmonicRatio To extract a feature, set the corresponding property to true. For example, obj.mfcc = true, adds mfcc to the list of enabled features.
Шаги для каждого сгиба следуют:
Разделите аудио datastore на наборы обучения и валидации.
Извлеките характеристические векторы из наборов обучения и валидации.
Нормируйте характеристические векторы.
Буферизуйте характеристические векторы в последовательности 20 с перекрытиями 10.
Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с характеристическими векторами.
Задайте опции обучения.
Задайте сеть.
Обучите сеть.
Оцените сеть.
1. Разделите аудио datastore на наборы обучения и валидации. Для набора разработки пропустите первый динамик. Для набора валидации используйте только произнесение от первого докладчика. Преобразуйте данные в длинные массивы.
adsTrain = subset(augads,augads.Labels.Speaker~=speaker(1)); adsTrain.Labels = adsTrain.Labels.Emotion; tallTrain = tall(adsTrain); adsValidation = subset(ads,ads.Labels.Speaker==speaker(1)); adsValidation.Labels = adsValidation.Labels.Emotion; tallValidation = tall(adsValidation);
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
2. Извлеките функции из наборов обучения и валидации. Переориентируйте функции так, чтобы время приехало строки, чтобы быть совместимым с sequenceInputLayer
.
featuresTallTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false); featuresTallTrain = cellfun(@(x)x',featuresTallTrain,"UniformOutput",false); featuresTrain = gather(featuresTallTrain); featuresTallValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false); featuresTallValidation = cellfun(@(x)x',featuresTallValidation,"UniformOutput",false); featuresValidation = gather(featuresTallValidation);
Evaluating tall expression using the Parallel Pool 'local': - Pass 1 of 1: Completed in 2 min 18 sec Evaluation completed in 2 min 18 sec Evaluating tall expression using the Parallel Pool 'local': - Pass 1 of 1: Completed in 1.2 sec Evaluation completed in 1.2 sec
3. Используйте набор обучающих данных, чтобы определить среднее и стандартное отклонение каждой функции. Нормируйте наборы обучения и валидации.
allFeatures = cat(2,featuresTrain{:}); M = mean(allFeatures,2,'omitnan'); S = std(allFeatures,0,2,'omitnan'); featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false); featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);
4. Буферизуйте характеристические векторы в последовательности так, чтобы каждая последовательность состояла из 20 характеристических векторов с перекрытиями 10 характеристических векторов.
featureVectorsPerSequence = 20; featureVectorOverlap = 10; [sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap); [sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);
5. Реплицируйте метки наборов обучения и валидации так, чтобы они были во взаимно-однозначном соответствии с последовательностями. Не у всех динамиков есть произнесение для всех эмоций. Создайте пустой categorical
массив, который содержит все эмоциональные категории и добавляет его к меткам валидации так, чтобы категориальный массив содержал все эмоции.
labelsTrain = repelem(adsTrain.Labels,[sequencePerFileTrain{:}]); emptyEmotions = ads.Labels.Emotion; emptyEmotions(:) = []; labelsValidation = [emptyEmotions;adsValidation.Labels]; labelsValidation = repelem(labelsValidation,[sequencePerFileValidation{:}]);
6. Задайте сеть BiLSTM с помощью bilstmLayer
. Поместите dropoutLayer
до и после bilstmLayer
помочь предотвратить сверхподбор кривой.
dropoutProb1 = 0.3; numUnits = 200; dropoutProb2 = 0.6; layers = [ ... sequenceInputLayer(size(sequencesTrain{1},1)) dropoutLayer(dropoutProb1) bilstmLayer(numUnits,"OutputMode","last") dropoutLayer(dropoutProb2) fullyConnectedLayer(numel(categories(emptyEmotions))) softmaxLayer classificationLayer];
7. Задайте опции обучения с помощью trainingOptions
.
miniBatchSize = 512; initialLearnRate = 0.005; learnRateDropPeriod = 2; maxEpochs = 3; options = trainingOptions("adam", ... "MiniBatchSize",miniBatchSize, ... "InitialLearnRate",initialLearnRate, ... "LearnRateDropPeriod",learnRateDropPeriod, ... "LearnRateSchedule","piecewise", ... "MaxEpochs",maxEpochs, ... "Shuffle","every-epoch", ... "ValidationData",{sequencesValidation,labelsValidation}, ... "Verbose",false, ... "Plots","Training-Progress");
8. Обучите сеть с помощью trainNetwork
.
net = trainNetwork(sequencesTrain,labelsTrain,layers,options);
9. Оцените сеть. Вызовите classify
получить предсказанные метки на последовательность. Заставьте режим предсказанных меток каждой последовательности получать предсказанные метки каждого файла. Постройте график беспорядка истинных меток и предсказанных меток.
predictedLabelsPerSequence = classify(net,sequencesValidation); labelsTrue = adsValidation.Labels; labelsPred = labelsTrue; idx = 1; for ii = 1:numel(labelsTrue) labelsPred(ii,:) = mode(predictedLabelsPerSequence(idx:idx + sequencePerFileValidation{ii} - 1,:),1); idx = idx + sequencePerFileValidation{ii}; end figure cm = confusionchart(labelsTrue,labelsPred); valAccuracy = mean(labelsTrue==labelsPred)*100; cm.Title = sprintf('Confusion Matrix for Fold 1\nAccuracy = %0.1f',valAccuracy); sortClasses(cm,categories(emptyEmotions)) cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
Функция помощника HelperTrainAndValidateNetwork
выполняет шаги, обрисованные в общих чертах выше для всех 10 сгибов, и возвращает истинные и предсказанные метки для каждого сгиба. Вызовите HelperTrainAndValidateNetwork
с audioDatastore
, увеличенный audioDatastore
, и audioFeatureExtractor
.
[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor);
Распечатайте точность на сгиб и постройте 10-кратный график беспорядка.
for ii = 1:numel(labelsTrue) foldAcc = mean(labelsTrue{ii}==labelsPred{ii})*100; fprintf('Fold %1.0f, Accuracy = %0.1f\n',ii,foldAcc); end labelsTrueMat = cat(1,labelsTrue{:}); labelsPredMat = cat(1,labelsPred{:}); figure cm = confusionchart(labelsTrueMat,labelsPredMat); valAccuracy = mean(labelsTrueMat==labelsPredMat)*100; cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy); sortClasses(cm,categories(emptyEmotions)) cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
Fold 1, Accuracy = 87.8 Fold 2, Accuracy = 86.2 Fold 3, Accuracy = 72.1 Fold 4, Accuracy = 84.2 Fold 5, Accuracy = 76.4 Fold 6, Accuracy = 65.7 Fold 7, Accuracy = 68.9 Fold 8, Accuracy = 85.5 Fold 9, Accuracy = 75.0 Fold 10, Accuracy = 64.8
Затем попытайтесь далее улучшить точность путем выбора лучшего набора функций. Последовательный выбор признаков может быть трудоемким. Чтобы уменьшать время выбора признаков, уменьшайте увеличенный набор аудиоданных так, чтобы было только 10 увеличений для каждого исходного файла. Используйте этот уменьшаемый набор данных, чтобы выбрать функции. Если лучший набор выбран, вы обучаетесь на полном увеличенном наборе данных, который в 50 раз больше для итоговой оценки.
augads10 = subset(augads,1:5:numel(augads.Files));
Создайте новый audioFeatureExtractor
объект. Используйте то же окно и длину перекрытия как ранее. Установите все функции, которые вы хотите протестировать к true
.
extractor = audioFeatureExtractor( ... 'Window', win, ... 'OverlapLength',overlapLength, ... 'SampleRate', fs, ... ... 'linearSpectrum', false, ... 'melSpectrum', false, ... 'barkSpectrum', false, ... 'erbSpectrum', false, ... ... 'mfcc', true, ... 'mfccDelta', true, ... 'mfccDeltaDelta', true, ... 'gtcc', true, ... 'gtccDelta', true, ... 'gtccDeltaDelta', true, ... ... 'SpectralDescriptorInput','melSpectrum', ... 'spectralCentroid', true, ... 'spectralCrest', true, ... 'spectralDecrease', true, ... 'spectralEntropy', true, ... 'spectralFlatness', true, ... 'spectralFlux', true, ... 'spectralKurtosis', true, ... 'spectralRolloffPoint',true, ... 'spectralSkewness', true, ... 'spectralSlope', true, ... 'spectralSpread', true, ... ... 'pitch', true, ... 'harmonicRatio', true);
В прямом выборе вы запускаете путем оценки каждой функции отдельно. Если лучшая одна функция понята, вы содержите ее и оцениваете каждую пару функций. Если лучшие две функции поняты, вы содержите их и оцениваете для лучших трех функций и т.д. Когда точность прекращает улучшаться, вы заканчиваете выбор признаков.
[logbook,bestFeatures] = ... HelperSFS(ads,augads10,extractor,'forward');
Смотрите верхние и нижние настройки функции.
head(logbook) tail(logbook)
ans = 8×2 table Features Accuracy _________________________________________________________________ ________ "mfccDelta, gtcc, gtccDelta, spectralCrest" 75.327 "mfccDelta, gtcc, gtccDelta, gtccDeltaDelta, spectralCrest" 74.393 "mfccDelta, gtcc, gtccDelta, spectralDecrease" 74.019 "mfccDelta, gtcc, gtccDelta, spectralCrest, spectralRolloffPoint" 74.019 "mfccDelta, gtcc, gtccDelta, spectralCrest, harmonicRatio" 74.019 "mfccDelta, gtcc, gtccDelta" 73.458 "mfccDelta, gtcc, gtccDelta, spectralCentroid" 73.458 "mfcc, mfccDelta, gtcc, gtccDelta" 73.271 ans = 8×2 table Features Accuracy ______________________ ________ "pitch" 31.963 "spectralFlux" 31.589 "spectralEntropy" 27.85 "spectralRolloffPoint" 25.794 "spectralCrest" 24.486 "spectralDecrease" 24.486 "harmonicRatio" 23.364 "spectralFlatness" 21.121
Установите лучшую настройку функции, как определено последовательным выбором признаков, на audioFeatureExtractor
объект.
set(extractor,bestFeatures)
Протестируйте 10-кратную точность перекрестной проверки LOSO выбранного набора функций с помощью полного увеличенного набора данных.
[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor); labelsTrueMat = cat(1,labelsTrue{:}); labelsPredMat = cat(1,labelsPred{:}); figure cm = confusionchart(labelsTrueMat,labelsPredMat); valAccuracy = mean(labelsTrueMat==labelsPredMat)*100; cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy); sortClasses(cm,categories(emptyEmotions)) cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized';
[1] Джайн, А., и Д. Зонгкер. "Выбор признаков: Оценка, Приложение и Производительность Небольшой выборки". Транзакции IEEE согласно Анализу Шаблона и Искусственному интеллекту. Издание 19, Выпуск 2, 1997, стр 153-158.
[2] Burkhardt, F., А. Пэешк, М. Рольфес, В.Ф. Сендлмайер и Б. Вайс, "База данных немецкой эмоциональной речи". В межречи 2005 продолжений. Лиссабон, Португалия: международная речевая коммуникационная ассоциация, 2005.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% function [sequences,sequencePerFile] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap) % Copyright 2019 MathWorks, Inc. if featureVectorsPerSequence <= featureVectorOverlap error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.') end hopLength = featureVectorsPerSequence - featureVectorOverlap; idx1 = 1; sequences = {}; sequencePerFile = cell(numel(features),1); for ii = 1:numel(features) sequencePerFile{ii} = floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1; idx2 = 1; for j = 1:sequencePerFile{ii} sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW> idx1 = idx1 + 1; idx2 = idx2 + hopLength; end end end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% function [trueLabelsCrossFold,predictedLabelsCrossFold] = HelperTrainAndValidateNetwork(varargin) % Copyright 2019 The MathWorks, Inc. if nargin == 3 ads = varargin{1}; augads = varargin{2}; extractor = varargin{3}; elseif nargin == 2 ads = varargin{1}; augads = varargin{1}; extractor = varargin{2}; end speaker = categories(ads.Labels.Speaker); numFolds = numel(speaker); emptyEmotions = categorical(ads.Labels.Emotion); emptyEmotions(:) = []; % Loop over each fold trueLabelsCrossFold = {}; predictedLabelsCrossFold = {}; for i = 1:numFolds % 1. Divide the audio datastore into training and validation sets. % Convert the data to tall arrays. idxTrain = augads.Labels.Speaker~=speaker(i); augadsTrain = subset(augads,idxTrain); augadsTrain.Labels = augadsTrain.Labels.Emotion; tallTrain = tall(augadsTrain); idxValidation = ads.Labels.Speaker==speaker(i); adsValidation = subset(ads,idxValidation); adsValidation.Labels = adsValidation.Labels.Emotion; tallValidation = tall(adsValidation); % 2. Extract features from the training set. Reorient the features % so that time is along rows to be compatible with % sequenceInputLayer. tallTrain = cellfun(@(x)x/max(abs(x),[],'all'),tallTrain,"UniformOutput",false); tallFeaturesTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false); tallFeaturesTrain = cellfun(@(x)x',tallFeaturesTrain,"UniformOutput",false); %#ok<NASGU> [~,featuresTrain] = evalc('gather(tallFeaturesTrain)'); % Use evalc to suppress command-line output. tallValidation = cellfun(@(x)x/max(abs(x),[],'all'),tallValidation,"UniformOutput",false); tallFeaturesValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false); tallFeaturesValidation = cellfun(@(x)x',tallFeaturesValidation,"UniformOutput",false); %#ok<NASGU> [~,featuresValidation] = evalc('gather(tallFeaturesValidation)'); % Use evalc to suppress command-line output. % 3. Use the training set to determine the mean and standard % deviation of each feature. Normalize the training and validation % sets. allFeatures = cat(2,featuresTrain{:}); M = mean(allFeatures,2,'omitnan'); S = std(allFeatures,0,2,'omitnan'); featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false); for ii = 1:numel(featuresTrain) idx = find(isnan(featuresTrain{ii})); if ~isempty(idx) featuresTrain{ii}(idx) = 0; end end featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false); for ii = 1:numel(featuresValidation) idx = find(isnan(featuresValidation{ii})); if ~isempty(idx) featuresValidation{ii}(idx) = 0; end end % 4. Buffer the sequences so that each sequence consists of twenty % feature vectors with overlaps of 10 feature vectors. featureVectorsPerSequence = 20; featureVectorOverlap = 10; [sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap); [sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap); % 5. Replicate the labels of the train and validation sets so that % they are in one-to-one correspondence with the sequences. labelsTrain = [emptyEmotions;augadsTrain.Labels]; labelsTrain = labelsTrain(:); labelsTrain = repelem(labelsTrain,[sequencePerFileTrain{:}]); % 6. Define a BiLSTM network. dropoutProb1 = 0.3; numUnits = 200; dropoutProb2 = 0.6; layers = [ ... sequenceInputLayer(size(sequencesTrain{1},1)) dropoutLayer(dropoutProb1) bilstmLayer(numUnits,"OutputMode","last") dropoutLayer(dropoutProb2) fullyConnectedLayer(numel(categories(emptyEmotions))) softmaxLayer classificationLayer]; % 7. Define training options. miniBatchSize = 512; initialLearnRate = 0.005; learnRateDropPeriod = 2; maxEpochs = 3; options = trainingOptions("adam", ... "MiniBatchSize",miniBatchSize, ... "InitialLearnRate",initialLearnRate, ... "LearnRateDropPeriod",learnRateDropPeriod, ... "LearnRateSchedule","piecewise", ... "MaxEpochs",maxEpochs, ... "Shuffle","every-epoch", ... "Verbose",false); % 8. Train the network. net = trainNetwork(sequencesTrain,labelsTrain,layers,options); % 9. Evaluate the network. Call classify to get the predicted labels % for each sequence. Get the mode of the predicted labels of each % sequence to get the predicted labels of each file. predictedLabelsPerSequence = classify(net,sequencesValidation); trueLabels = categorical(adsValidation.Labels); predictedLabels = trueLabels; idx1 = 1; for ii = 1:numel(trueLabels) predictedLabels(ii,:) = mode(predictedLabelsPerSequence(idx1:idx1 + sequencePerFileValidation{ii} - 1,:),1); idx1 = idx1 + sequencePerFileValidation{ii}; end trueLabelsCrossFold{i} = trueLabels; %#ok<AGROW> predictedLabelsCrossFold{i} = predictedLabels; %#ok<AGROW> end end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% function [logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction) % logbook = HelperSFS(ads,adsAug,extractor,direction) % returns a table, logbook, that contains the feature configurations tested % and associated validation accuracy. % ads - audioDatastore object that points to the original dataset (used for val). % adsAug - audioDatastore object that points to the augmented dataset (used for dev). % extractor - audioFeatureExtractor object. Set all features to test to true. % direction - specify as 'forward' or 'backward' % %[logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction) % also returns a struct, bestFeatures, containing the best feature % configuration for audioFeatureExtractor. % Copyright 2019 The MathWorks, Inc. featuresToTest = fieldnames(info(extractor)); N = numel(featuresToTest); % --------------------------------------------------------------------- % Set the initial feature configuration: all on for backward selection % or all off for forward selection. featureConfig = info(extractor); for i = 1:N if strcmpi(direction,"backward") featureConfig.(featuresToTest{i}) = true; else featureConfig.(featuresToTest{i}) = false; end end % --------------------------------------------------------------------- % Initialize logbook to track feature configuration and accuracy. logbook = table(featureConfig,0,'VariableNames',["Feature Configuration","Accuracy"]); %% Perform sequential feature evaluation wrapperIdx = 1; bestAccuracy = 0; while wrapperIdx <= N % ----------------------------------------------------------------- % Create a cell array containing all feature configurations to test % in the current loop. featureConfigsToTest = cell(numel(featuresToTest),1); for ii = 1:numel(featuresToTest) if strcmpi(direction,"backward") featureConfig.(featuresToTest{ii}) = false; else featureConfig.(featuresToTest{ii}) = true; end featureConfigsToTest{ii} = featureConfig; if strcmpi(direction,"backward") featureConfig.(featuresToTest{ii}) = true; else featureConfig.(featuresToTest{ii}) = false; end end % ----------------------------------------------------------------- % Loop over every feature set. for ii = 1:numel(featureConfigsToTest) % ------------------------------------------------------------- % Determine the current feature configuration to test. Update % the feature extractor. currentConfig = featureConfigsToTest{ii}; set(extractor,currentConfig) % ------------------------------------------------------------- % ------------------------------------------------------------- % Train and get k-fold cross-validation accuracy for current % feature configuration. [trueLabels,predictedLabels] = HelperTrainAndValidateNetwork(ads,adsAug,extractor); trueLabelsMat = cat(1,trueLabels{:}); predictedLabelsMat = cat(1,predictedLabels{:}); valAccuracy = mean(trueLabelsMat==predictedLabelsMat)*100; % ------------------------------------------------------------- % Update Logbook ---------------------------------------------- result = table(currentConfig,valAccuracy, ... 'VariableNames',["Feature Configuration","Accuracy"]); logbook = [logbook;result]; % ------------------------------------------------------------- end % ----------------------------------------------------------------- % Determine and print the setting with the best accuracy. If % accuracy did not improve, end the run. [a,b] = max(logbook{:,'Accuracy'}); if a <= bestAccuracy wrapperIdx = inf; else wrapperIdx = wrapperIdx + 1; end bestAccuracy = a; % ----------------------------------------------------------------- % ----------------------------------------------------------------- % Update the features-to-test based on the most recent winner. winner = logbook{b,'Feature Configuration'}; fn = fieldnames(winner); tf = structfun(@(x)(x),winner); if strcmpi(direction,"backward") featuresToRemove = fn(~tf); else featuresToRemove = fn(tf); end for ii = 1:numel(featuresToRemove) loc = strcmp(featuresToTest,featuresToRemove{ii}); featuresToTest(loc) = []; if strcmpi(direction,"backward") featureConfig.(featuresToRemove{ii}) = false; else featureConfig.(featuresToRemove{ii}) = true; end end % ----------------------------------------------------------------- end % --------------------------------------------------------------------- % Sort the logbook and make it more readable logbook(1,:) = []; % Delete placeholder first row logbook = sortrows(logbook,{'Accuracy'},{'descend'}); bestFeatures = logbook{1,'Feature Configuration'}; m = logbook{:,'Feature Configuration'}; fn = fieldnames(m); myString = strings(numel(m),1); for wrapperIdx = 1:numel(m) tf = structfun(@(x)(x),logbook{wrapperIdx,'Feature Configuration'}); myString(wrapperIdx) = strjoin(fn(tf),", "); end logbook = table(myString,logbook{:,'Accuracy'},'VariableNames',["Features","Accuracy"]); % --------------------------------------------------------------------- end %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%