Последовательный выбор признаков для речевого распознавания эмоции

Этот пример показывает, что типичный рабочий процесс для выбора признаков применился к задаче речевого распознавания эмоции. Вы начинаете путем создания базовой точности, использующей общие функции аудио (MFCC). Вы затем увеличиваете свой набор данных, чтобы уменьшить сверхподбор кривой. Наконец, вы выполняете последовательный выбор признаков, чтобы выбрать лучший набор функций.

В последовательном выборе признаков вы обучаете сеть на данном наборе функций и затем инкрементно добавляете или удаляете функции, пока самая высокая точность не достигнута [1]. В этом примере вы применяете последовательный прямой выбор к задаче речевого распознавания эмоции с помощью Берлинской Базы данных Эмоциональной Речи [2].

Загрузите набор данных

Берлинская База данных Эмоциональной Речи содержит 535 произнесения, на котором говорят 10 агентов, предназначенных, чтобы передать одну из следующих эмоций: гнев, скука, отвращение, беспокойство/страх, счастье, печаль, или нейтральный. Эмоции являются независимым текстом. Загрузите базу данных с http://emodb.bilderbar.info/index-1280.html и затем установите PathToDatabase к местоположению звуковых файлов. Создайте audioDatastore это указывает на звуковые файлы.

datafolder = PathToDatabase;
ads = audioDatastore(fullfile(datafolder,"wav"));

Имена файлов являются кодами, указывающими на ID динамика, текст, на котором говорят, эмоция и версия. Веб-сайт содержит ключ для интерпретации кода и дополнительной информации о динамиках, таких как пол и возраст. Составьте таблицу с переменными Speaker и Emotion. Декодируйте имена файлов в таблицу.

filepaths = ads.Files;
emotionCodes = cellfun(@(x)x(end-5),filepaths,'UniformOutput',false);
emotions = replace(emotionCodes,{'W','L','E','A','F','T','N'}, ...
    {'Anger','Boredom','Disgust','Anxiety/Fear','Happiness','Sadness','Neutral'});

speakerCodes = cellfun(@(x)x(end-10:end-9),filepaths,'UniformOutput',false);
labelTable = cell2table([speakerCodes,emotions],'VariableNames',{'Speaker','Emotion'});
labelTable.Emotion = categorical(labelTable.Emotion);
labelTable.Speaker = categorical(labelTable.Speaker);
head(labelTable)
ans =

  8×2 table

    Speaker     Emotion 
    _______    _________

      03       Happiness
      03       Neutral  
      03       Anger    
      03       Happiness
      03       Neutral  
      03       Sadness  
      03       Anger    
      03       Anger    

labelTable находится в том же порядке как файлы в audioDatastore. Установите Labels свойство audioDatastore к labelTable.

ads.Labels = labelTable;

Можно теперь разделить меткой и подмножеством, чтобы изолировать фрагменты данных. Подмножество datastore, который содержит динамик 12 передача скуки. Слушайте файл и просмотрите форму волны временного интервала. Отобразите полную метку, соответствующую произнесению.

speaker = categorical("12");
emotion = categorical("Boredom");
adsSubset = subset(ads,ads.Labels.Speaker==speaker & ads.Labels.Emotion == emotion);

[audio,adsInfo] = read(adsSubset);
fs = adsInfo.SampleRate;
sound(audio,fs)

t = (0:size(audio,1)-1)/fs;
figure
plot(t,audio)
grid on
xlabel('Time (s)')
ylabel('Amplitude')

Чтобы обеспечить точную оценку модели, вы создаете в этом примере, обучаете и подтверждаете перекрестную проверку k-сгиба отпуска один динамика (LOSO) использования. В этом методе вы обучаете использование k-1 динамики и затем подтверждаете на не учтенном динамике. Вы повторяете эту процедуру k времена. Итоговая точность валидации является средним значением сгибов k.

Создайте переменную, которая содержит идентификаторы динамика. Определите количество сгибов: 1 для каждого динамика. База данных содержит произнесение от 10 уникальных докладчиков. Используйте summary чтобы отобразить идентификаторы динамика (левый столбец) и количество произнесения, они способствуют базе данных (правый столбец).

speaker = ads.Labels.Speaker;
numFolds = numel(speaker);
summary(speaker)
     03      49 
     08      58 
     09      43 
     10      38 
     11      55 
     12      35 
     13      61 
     14      69 
     15      56 
     16      71 

Сгенерируйте базовую точность валидации

Как первый шаг к разработке модели машинного обучения, определите базовую точность.

Примите, что 10-кратная точность перекрестной проверки первой попытки обучения составляет приблизительно 60% из-за недостаточных обучающих данных, и что модель, обученная на недостаточных данных, сверхсоответствует некоторым сгибам и underfits другим. Чтобы улучшить полную подгонку, увеличьте размер набора данных к 50 разам с помощью audioDataAugmenter.

Создайте audioDataAugmenter объект. Установите вероятность применения перехода подачи к 0.5 и используйте область значений по умолчанию. Установите вероятность применения смещения во времени к 1 и используйте область значений [-0.3,0.3] секунды. Установите вероятность добавления шума к 1 и укажите диапазон ОСШ как [-20,40] дБ.

augmenter = audioDataAugmenter('NumAugmentations',50, ...
    'TimeStretchProbability',0, ...
    'VolumeControlProbability',0, ...
    ...
    'PitchShiftProbability',0.5, ...
    ...
    'TimeShiftProbability',1, ...
    'TimeShiftRange',[-0.3,0.3], ...
    ...
    'AddNoiseProbability',1, ...
    'SNRRange', [-20,40]);

Создайте новую папку в своей текущей папке, чтобы содержать увеличенный набор данных.

currentDir = pwd;
writeDirectory = [currentDir,'\augmentedData'];
mkdir(writeDirectory)

Для каждого файла в аудио datastore:

  1. Создайте 50 увеличений.

  2. Нормируйте аудио, чтобы иметь макс. абсолютное значение 1.

  3. Запишите увеличенные аудиоданные как файл WAV. Добавьте _augK к каждым из имен файлов, где K является номером увеличения. Чтобы ускорить обработку, используйте parfor и раздел datastore.

reset(ads)
numPartitions = 6;
for ii = 1:numPartitions
    adsPart = partition(ads,numPartitions,ii);
    while hasdata(adsPart)
        [x,adsInfo] = read(adsPart);
        data = augment(augmenter,x,fs);

        [~,fn] = fileparts(adsInfo.FileName);
        for i = 1:size(data,1)
            augmentedAudio = data.Audio{i};
            augmentedAudio = augmentedAudio/max(abs(augmentedAudio),[],'all');
            augNum = num2str(i);
            if numel(augNum)==1
                iString = ['0',augNum];
            else
                iString = augNum;
            end
            audiowrite([writeDirectory,'\',sprintf('%s_aug%s.wav',fn,iString)],augmentedAudio,fs);
        end
    end
end

Создайте аудио datastore, который указывает на увеличенный набор данных. Реплицируйте строки таблицы метки исходного datastore NumAugmentations времена, чтобы определить метки увеличенного datastore.

augads = audioDatastore(writeDirectory);
augads.Labels = repelem(ads.Labels,augmenter.NumAugmentations,1);

Mel-частота cepstral коэффициенты (MFCC), дельта-MFCC и дельта дельты MFCC являются популярными функциями аудио. Как базовая линия, используйте MFCC, дельту-MFCC и дельту дельты MFCC с 30 MS Windows и никаким перекрытием.

Создайте audioFeatureExtractor объект. Установите Window к периодическому Окну Хэмминга на 30 мс, OverlapLength к 0, и SampleRate к частоте дискретизации базы данных. Установите mfcc, mfccDelta, и mfccDeltaDelta к true извлекать их.

win = hamming(round(0.03*fs),"periodic");
overlapLength = 0;

extractor = audioFeatureExtractor( ...
    'Window',win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',fs, ...
    ...
    'mfcc',true, ...
    'mfccDelta',true, ...
    'mfccDeltaDelta',true)
extractor = 

  audioFeatureExtractor with properties:

   Properties
                     Window: [480×1 double]
              OverlapLength: 0
                 SampleRate: 16000
                  FFTLength: []
    SpectralDescriptorInput: 'linearSpectrum'

   Enabled Features
     mfcc, mfccDelta, mfccDeltaDelta

   Disabled Features
     linearSpectrum, melSpectrum, barkSpectrum, erbSpectrum, gtcc, gtccDelta
     gtccDeltaDelta, spectralCentroid, spectralCrest, spectralDecrease, spectralEntropy, spectralFlatness
     spectralFlux, spectralKurtosis, spectralRolloffPoint, spectralSkewness, spectralSlope, spectralSpread
     pitch, harmonicRatio


   To extract a feature, set the corresponding property to true.
   For example, obj.mfcc = true, adds mfcc to the list of enabled features.

Шаги для каждого сгиба следуют:

  1. Разделите аудио datastore на наборы обучения и валидации.

  2. Извлеките характеристические векторы из наборов обучения и валидации.

  3. Нормируйте характеристические векторы.

  4. Буферизуйте характеристические векторы в последовательности 20 с перекрытиями 10.

  5. Реплицируйте метки так, чтобы они были во взаимно-однозначном соответствии с характеристическими векторами.

  6. Задайте опции обучения.

  7. Задайте сеть.

  8. Обучите сеть.

  9. Оцените сеть.

1. Разделите аудио datastore на наборы обучения и валидации. Для набора разработки пропустите первый динамик. Для набора валидации используйте только произнесение от первого докладчика. Преобразуйте данные в длинные массивы.

adsTrain = subset(augads,augads.Labels.Speaker~=speaker(1));
adsTrain.Labels = adsTrain.Labels.Emotion;
tallTrain = tall(adsTrain);

adsValidation = subset(ads,ads.Labels.Speaker==speaker(1));
adsValidation.Labels = adsValidation.Labels.Emotion;
tallValidation = tall(adsValidation);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

2. Извлеките функции из наборов обучения и валидации. Переориентируйте функции так, чтобы время приехало строки, чтобы быть совместимым с sequenceInputLayer.

featuresTallTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false);
featuresTallTrain = cellfun(@(x)x',featuresTallTrain,"UniformOutput",false);
featuresTrain     = gather(featuresTallTrain);

featuresTallValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false);
featuresTallValidation = cellfun(@(x)x',featuresTallValidation,"UniformOutput",false);
featuresValidation = gather(featuresTallValidation);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 2 min 18 sec
Evaluation completed in 2 min 18 sec
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1.2 sec
Evaluation completed in 1.2 sec

3. Используйте набор обучающих данных, чтобы определить среднее и стандартное отклонение каждой функции. Нормируйте наборы обучения и валидации.

allFeatures = cat(2,featuresTrain{:});
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');

featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);
featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);

4. Буферизуйте характеристические векторы в последовательности так, чтобы каждая последовательность состояла из 20 характеристических векторов с перекрытиями 10 характеристических векторов.

featureVectorsPerSequence = 20;
featureVectorOverlap = 10;
[sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);
[sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);

5. Реплицируйте метки наборов обучения и валидации так, чтобы они были во взаимно-однозначном соответствии с последовательностями. Не у всех динамиков есть произнесение для всех эмоций. Создайте пустой categorical массив, который содержит все эмоциональные категории и добавляет его к меткам валидации так, чтобы категориальный массив содержал все эмоции.

labelsTrain = repelem(adsTrain.Labels,[sequencePerFileTrain{:}]);

emptyEmotions = ads.Labels.Emotion;
emptyEmotions(:) = [];
labelsValidation = [emptyEmotions;adsValidation.Labels];
labelsValidation = repelem(labelsValidation,[sequencePerFileValidation{:}]);

6. Задайте сеть BiLSTM с помощью bilstmLayer. Поместите dropoutLayer до и после bilstmLayer помочь предотвратить сверхподбор кривой.

dropoutProb1 = 0.3;
numUnits = 200;
dropoutProb2 = 0.6;
layers = [ ...
    sequenceInputLayer(size(sequencesTrain{1},1))
    dropoutLayer(dropoutProb1)
    bilstmLayer(numUnits,"OutputMode","last")
    dropoutLayer(dropoutProb2)
    fullyConnectedLayer(numel(categories(emptyEmotions)))
    softmaxLayer
    classificationLayer];

7. Задайте опции обучения с помощью trainingOptions.

miniBatchSize = 512;
initialLearnRate = 0.005;
learnRateDropPeriod = 2;
maxEpochs = 3;
options = trainingOptions("adam", ...
    "MiniBatchSize",miniBatchSize, ...
    "InitialLearnRate",initialLearnRate, ...
    "LearnRateDropPeriod",learnRateDropPeriod, ...
    "LearnRateSchedule","piecewise", ...
    "MaxEpochs",maxEpochs, ...
    "Shuffle","every-epoch", ...
    "ValidationData",{sequencesValidation,labelsValidation}, ...
    "Verbose",false, ...
    "Plots","Training-Progress");

8. Обучите сеть с помощью trainNetwork.

net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

9. Оцените сеть. Вызовите classify получить предсказанные метки на последовательность. Заставьте режим предсказанных меток каждой последовательности получать предсказанные метки каждого файла. Постройте график беспорядка истинных меток и предсказанных меток.

predictedLabelsPerSequence = classify(net,sequencesValidation);

labelsTrue = adsValidation.Labels;
labelsPred = labelsTrue;
idx = 1;
for ii = 1:numel(labelsTrue)
    labelsPred(ii,:) = mode(predictedLabelsPerSequence(idx:idx + sequencePerFileValidation{ii} - 1,:),1);
    idx = idx + sequencePerFileValidation{ii};
end

figure
cm = confusionchart(labelsTrue,labelsPred);
valAccuracy = mean(labelsTrue==labelsPred)*100;
cm.Title = sprintf('Confusion Matrix for Fold 1\nAccuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Функция помощника HelperTrainAndValidateNetwork выполняет шаги, обрисованные в общих чертах выше для всех 10 сгибов, и возвращает истинные и предсказанные метки для каждого сгиба. Вызовите HelperTrainAndValidateNetwork с audioDatastore, увеличенный audioDatastore, и audioFeatureExtractor.

[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor);

Распечатайте точность на сгиб и постройте 10-кратный график беспорядка.

for ii = 1:numel(labelsTrue)
    foldAcc = mean(labelsTrue{ii}==labelsPred{ii})*100;
    fprintf('Fold %1.0f, Accuracy = %0.1f\n',ii,foldAcc);
end

labelsTrueMat = cat(1,labelsTrue{:});
labelsPredMat = cat(1,labelsPred{:});
figure
cm = confusionchart(labelsTrueMat,labelsPredMat);
valAccuracy = mean(labelsTrueMat==labelsPredMat)*100;
cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
Fold 1, Accuracy = 87.8
Fold 2, Accuracy = 86.2
Fold 3, Accuracy = 72.1
Fold 4, Accuracy = 84.2
Fold 5, Accuracy = 76.4
Fold 6, Accuracy = 65.7
Fold 7, Accuracy = 68.9
Fold 8, Accuracy = 85.5
Fold 9, Accuracy = 75.0
Fold 10, Accuracy = 64.8

Последовательный выбор признаков

Затем попытайтесь далее улучшить точность путем выбора лучшего набора функций. Последовательный выбор признаков может быть трудоемким. Чтобы уменьшать время выбора признаков, уменьшайте увеличенный набор аудиоданных так, чтобы было только 10 увеличений для каждого исходного файла. Используйте этот уменьшаемый набор данных, чтобы выбрать функции. Если лучший набор выбран, вы обучаетесь на полном увеличенном наборе данных, который в 50 раз больше для итоговой оценки.

augads10 = subset(augads,1:5:numel(augads.Files));

Создайте новый audioFeatureExtractor объект. Используйте то же окно и длину перекрытия как ранее. Установите все функции, которые вы хотите протестировать к true.

extractor = audioFeatureExtractor( ...
    'Window',       win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',   fs, ...
    ...
    'linearSpectrum',      false, ...
    'melSpectrum',         false, ...
    'barkSpectrum',        false, ...
    'erbSpectrum',         false, ...
    ...
    'mfcc',                true, ...
    'mfccDelta',           true, ...
    'mfccDeltaDelta',      true, ...
    'gtcc',                true, ...
    'gtccDelta',           true, ...
    'gtccDeltaDelta',      true, ...
    ...
    'SpectralDescriptorInput','melSpectrum', ...
    'spectralCentroid',    true, ...
    'spectralCrest',       true, ...
    'spectralDecrease',    true, ...
    'spectralEntropy',     true, ...
    'spectralFlatness',    true, ...
    'spectralFlux',        true, ...
    'spectralKurtosis',    true, ...
    'spectralRolloffPoint',true, ...
    'spectralSkewness',    true, ...
    'spectralSlope',       true, ...
    'spectralSpread',      true, ...
    ...
    'pitch',               true, ...
    'harmonicRatio',       true);

Передайте выбор

В прямом выборе вы запускаете путем оценки каждой функции отдельно. Если лучшая одна функция понята, вы содержите ее и оцениваете каждую пару функций. Если лучшие две функции поняты, вы содержите их и оцениваете для лучших трех функций и т.д. Когда точность прекращает улучшаться, вы заканчиваете выбор признаков.

[logbook,bestFeatures] = ...
    HelperSFS(ads,augads10,extractor,'forward');

Смотрите верхние и нижние настройки функции.

head(logbook)
tail(logbook)
ans =

  8×2 table

                                Features                                 Accuracy
    _________________________________________________________________    ________

    "mfccDelta, gtcc, gtccDelta, spectralCrest"                           75.327 
    "mfccDelta, gtcc, gtccDelta, gtccDeltaDelta, spectralCrest"           74.393 
    "mfccDelta, gtcc, gtccDelta, spectralDecrease"                        74.019 
    "mfccDelta, gtcc, gtccDelta, spectralCrest, spectralRolloffPoint"     74.019 
    "mfccDelta, gtcc, gtccDelta, spectralCrest, harmonicRatio"            74.019 
    "mfccDelta, gtcc, gtccDelta"                                          73.458 
    "mfccDelta, gtcc, gtccDelta, spectralCentroid"                        73.458 
    "mfcc, mfccDelta, gtcc, gtccDelta"                                    73.271 


ans =

  8×2 table

           Features           Accuracy
    ______________________    ________

    "pitch"                    31.963 
    "spectralFlux"             31.589 
    "spectralEntropy"           27.85 
    "spectralRolloffPoint"     25.794 
    "spectralCrest"            24.486 
    "spectralDecrease"         24.486 
    "harmonicRatio"            23.364 
    "spectralFlatness"         21.121 

Протестируйте выбранные функции на увеличенном наборе данных

Установите лучшую настройку функции, как определено последовательным выбором признаков, на audioFeatureExtractor объект.

set(extractor,bestFeatures)

Протестируйте 10-кратную точность перекрестной проверки LOSO выбранного набора функций с помощью полного увеличенного набора данных.

[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,augads,extractor);

labelsTrueMat = cat(1,labelsTrue{:});
labelsPredMat = cat(1,labelsPred{:});
figure
cm = confusionchart(labelsTrueMat,labelsPredMat);
valAccuracy = mean(labelsTrueMat==labelsPredMat)*100;
cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Ссылки

[1] Джайн, А., и Д. Зонгкер. "Выбор признаков: Оценка, Приложение и Производительность Небольшой выборки". Транзакции IEEE согласно Анализу Шаблона и Искусственному интеллекту. Издание 19, Выпуск 2, 1997, стр 153-158.

[2] Burkhardt, F., А. Пэешк, М. Рольфес, В.Ф. Сендлмайер и Б. Вайс, "База данных немецкой эмоциональной речи". В межречи 2005 продолжений. Лиссабон, Португалия: международная речевая коммуникационная ассоциация, 2005.

Приложение - поддерживание функций

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [sequences,sequencePerFile] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
    % Copyright 2019 MathWorks, Inc.
    if featureVectorsPerSequence <= featureVectorOverlap
        error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
    end

    hopLength = featureVectorsPerSequence - featureVectorOverlap;
    idx1 = 1;
    sequences = {};
    sequencePerFile = cell(numel(features),1);
    for ii = 1:numel(features)
        sequencePerFile{ii} = floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1;
        idx2 = 1;
        for j = 1:sequencePerFile{ii}
            sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
            idx1 = idx1 + 1;
            idx2 = idx2 + hopLength;
        end
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [trueLabelsCrossFold,predictedLabelsCrossFold] = HelperTrainAndValidateNetwork(varargin)
    % Copyright 2019 The MathWorks, Inc.
    if nargin == 3
        ads = varargin{1};
        augads = varargin{2};
        extractor = varargin{3};
    elseif nargin == 2
        ads = varargin{1};
        augads = varargin{1};
        extractor = varargin{2};
    end
    speaker = categories(ads.Labels.Speaker);
    numFolds = numel(speaker);
    emptyEmotions = categorical(ads.Labels.Emotion);
    emptyEmotions(:) = [];

    % Loop over each fold
    trueLabelsCrossFold = {};
    predictedLabelsCrossFold = {};
    
    for i = 1:numFolds
        
        % 1. Divide the audio datastore into training and validation sets.
        % Convert the data to tall arrays.
        idxTrain           = augads.Labels.Speaker~=speaker(i);
        augadsTrain        = subset(augads,idxTrain);
        augadsTrain.Labels = augadsTrain.Labels.Emotion;
        tallTrain          = tall(augadsTrain);
        idxValidation        = ads.Labels.Speaker==speaker(i);
        adsValidation        = subset(ads,idxValidation);
        adsValidation.Labels = adsValidation.Labels.Emotion;
        tallValidation       = tall(adsValidation);

        % 2. Extract features from the training set. Reorient the features
        % so that time is along rows to be compatible with
        % sequenceInputLayer.
        tallTrain         = cellfun(@(x)x/max(abs(x),[],'all'),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)x',tallFeaturesTrain,"UniformOutput",false);  %#ok<NASGU>
        [~,featuresTrain] = evalc('gather(tallFeaturesTrain)'); % Use evalc to suppress command-line output.
        tallValidation         = cellfun(@(x)x/max(abs(x),[],'all'),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)x',tallFeaturesValidation,"UniformOutput",false); %#ok<NASGU>
        [~,featuresValidation] = evalc('gather(tallFeaturesValidation)'); % Use evalc to suppress command-line output.

        % 3. Use the training set to determine the mean and standard
        % deviation of each feature. Normalize the training and validation
        % sets.
        allFeatures = cat(2,featuresTrain{:});
        M = mean(allFeatures,2,'omitnan');
        S = std(allFeatures,0,2,'omitnan');
        featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);
        for ii = 1:numel(featuresTrain)
            idx = find(isnan(featuresTrain{ii}));
            if ~isempty(idx)
                featuresTrain{ii}(idx) = 0;
            end
        end
        featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);
        for ii = 1:numel(featuresValidation)
            idx = find(isnan(featuresValidation{ii}));
            if ~isempty(idx)
                featuresValidation{ii}(idx) = 0;
            end
        end

        % 4. Buffer the sequences so that each sequence consists of twenty
        % feature vectors with overlaps of 10 feature vectors.
        featureVectorsPerSequence = 20;
        featureVectorOverlap = 10;
        [sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);
        [sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);

        % 5. Replicate the labels of the train and validation sets so that
        % they are in one-to-one correspondence with the sequences.
        labelsTrain = [emptyEmotions;augadsTrain.Labels];
        labelsTrain = labelsTrain(:);
        labelsTrain = repelem(labelsTrain,[sequencePerFileTrain{:}]);

        % 6. Define a BiLSTM network.
        dropoutProb1 = 0.3;
        numUnits     = 200;
        dropoutProb2 = 0.6;
        layers = [ ...
            sequenceInputLayer(size(sequencesTrain{1},1))
            dropoutLayer(dropoutProb1)
            bilstmLayer(numUnits,"OutputMode","last")
            dropoutLayer(dropoutProb2)
            fullyConnectedLayer(numel(categories(emptyEmotions)))
            softmaxLayer
            classificationLayer];

        % 7. Define training options.
        miniBatchSize       = 512;
        initialLearnRate    = 0.005;
        learnRateDropPeriod = 2;
        maxEpochs           = 3;
        options = trainingOptions("adam", ...
            "MiniBatchSize",miniBatchSize, ...
            "InitialLearnRate",initialLearnRate, ...
            "LearnRateDropPeriod",learnRateDropPeriod, ...
            "LearnRateSchedule","piecewise", ...
            "MaxEpochs",maxEpochs, ...
            "Shuffle","every-epoch", ...
            "Verbose",false);

        % 8. Train the network.
        net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

        % 9. Evaluate the network. Call classify to get the predicted labels
        % for each sequence. Get the mode of the predicted labels of each
        % sequence to get the predicted labels of each file.
        predictedLabelsPerSequence = classify(net,sequencesValidation);
        trueLabels = categorical(adsValidation.Labels);
        predictedLabels = trueLabels;
        idx1 = 1;
        for ii = 1:numel(trueLabels)
            predictedLabels(ii,:) = mode(predictedLabelsPerSequence(idx1:idx1 + sequencePerFileValidation{ii} - 1,:),1);
            idx1 = idx1 + sequencePerFileValidation{ii};
        end
        trueLabelsCrossFold{i} = trueLabels; %#ok<AGROW>
        predictedLabelsCrossFold{i} = predictedLabels; %#ok<AGROW>
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction)
    % logbook = HelperSFS(ads,adsAug,extractor,direction)
    % returns a table, logbook, that contains the feature configurations tested
    % and associated validation accuracy.
    %   ads       - audioDatastore object that points to the original dataset (used for val).
    %   adsAug    - audioDatastore object that points to the augmented dataset (used for dev).
    %   extractor - audioFeatureExtractor object. Set all features to test to true.
    %   direction - specify as 'forward' or 'backward'
    %
    %[logbook,bestFeatures] = HelperSFS(ads,adsAug,extractor,direction)
    % also returns a struct, bestFeatures, containing the best feature
    % configuration for audioFeatureExtractor.

    % Copyright 2019 The MathWorks, Inc.

    featuresToTest = fieldnames(info(extractor));
    N = numel(featuresToTest);

    % ---------------------------------------------------------------------
    % Set the initial feature configuration: all on for backward selection
    % or all off for forward selection.
    featureConfig  = info(extractor);
    for i = 1:N
        if strcmpi(direction,"backward")
            featureConfig.(featuresToTest{i}) = true;
        else
            featureConfig.(featuresToTest{i}) = false;
        end
    end
    % ---------------------------------------------------------------------

    % Initialize logbook to track feature configuration and accuracy.
    logbook = table(featureConfig,0,'VariableNames',["Feature Configuration","Accuracy"]);

    %% Perform sequential feature evaluation
    wrapperIdx = 1;
    bestAccuracy = 0;
    while wrapperIdx <= N
        % -----------------------------------------------------------------
        % Create a cell array containing all feature configurations to test
        % in the current loop.
        featureConfigsToTest = cell(numel(featuresToTest),1);
        for ii = 1:numel(featuresToTest)
            if strcmpi(direction,"backward")
                featureConfig.(featuresToTest{ii}) = false;
            else
                featureConfig.(featuresToTest{ii}) = true;
            end
            featureConfigsToTest{ii} = featureConfig;
            if strcmpi(direction,"backward")
                featureConfig.(featuresToTest{ii}) = true;
            else
                featureConfig.(featuresToTest{ii}) = false;
            end
        end
        % -----------------------------------------------------------------

        % Loop over every feature set.
        for ii = 1:numel(featureConfigsToTest)

            % -------------------------------------------------------------
            % Determine the current feature configuration to test. Update
            % the feature extractor.
            currentConfig = featureConfigsToTest{ii};
            set(extractor,currentConfig)
            % -------------------------------------------------------------

            % -------------------------------------------------------------
            % Train and get k-fold cross-validation accuracy for current
            % feature configuration.
            [trueLabels,predictedLabels] = HelperTrainAndValidateNetwork(ads,adsAug,extractor);
            trueLabelsMat = cat(1,trueLabels{:});
            predictedLabelsMat = cat(1,predictedLabels{:});
            valAccuracy = mean(trueLabelsMat==predictedLabelsMat)*100;
            % -------------------------------------------------------------

            % Update Logbook ----------------------------------------------
            result = table(currentConfig,valAccuracy, ...
                'VariableNames',["Feature Configuration","Accuracy"]);
            logbook = [logbook;result];
            % -------------------------------------------------------------

        end

        % -----------------------------------------------------------------
        % Determine and print the setting with the best accuracy. If
        % accuracy did not improve, end the run.
        [a,b] = max(logbook{:,'Accuracy'});
        if a <= bestAccuracy
            wrapperIdx = inf;
        else
            wrapperIdx = wrapperIdx + 1;
        end
        bestAccuracy = a;
        % -----------------------------------------------------------------

        % -----------------------------------------------------------------
        % Update the features-to-test based on the most recent winner.
        winner = logbook{b,'Feature Configuration'};
        fn = fieldnames(winner);
        tf = structfun(@(x)(x),winner);
        if strcmpi(direction,"backward")
            featuresToRemove = fn(~tf);
        else
            featuresToRemove = fn(tf);
        end
        for ii = 1:numel(featuresToRemove)
            loc =  strcmp(featuresToTest,featuresToRemove{ii});
            featuresToTest(loc) = [];
            if strcmpi(direction,"backward")
                featureConfig.(featuresToRemove{ii}) = false;
            else
                featureConfig.(featuresToRemove{ii}) = true;
            end
        end
        % -----------------------------------------------------------------

    end
    
    % ---------------------------------------------------------------------
    % Sort the logbook and make it more readable
    logbook(1,:) = []; % Delete placeholder first row
    logbook = sortrows(logbook,{'Accuracy'},{'descend'});
    bestFeatures = logbook{1,'Feature Configuration'};
    m = logbook{:,'Feature Configuration'};
    fn = fieldnames(m);
    myString = strings(numel(m),1);
    for wrapperIdx = 1:numel(m)
        tf = structfun(@(x)(x),logbook{wrapperIdx,'Feature Configuration'});
        myString(wrapperIdx) = strjoin(fn(tf),", ");
    end
    logbook = table(myString,logbook{:,'Accuracy'},'VariableNames',["Features","Accuracy"]);
    % ---------------------------------------------------------------------
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

Смотрите также

| | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте