Речевое распознавание эмоции

Этот пример иллюстрирует простую систему речевого распознавания эмоции (SER) с помощью сети BiLSTM. Вы начинаете путем загрузки набора данных и затем тестирования обучившего сеть на отдельных файлах. Сеть была обучена на маленькой немецкоязычной базе данных [1].

Пример обходит вас посредством обучения сети, которая включает загрузку, увеличение и обучение набор данных. Наконец, вы выполняете отпуск один динамик (LOSO) 10-кратная перекрестная проверка, чтобы оценить сетевую архитектуру.

Признаки, используемые в этом примере, были выбраны с помощью последовательного выбора признаков, похожего на метод, описанный в Последовательном Выборе признаков для Функций аудио (Audio Toolbox).

Загрузите набор данных

Загрузите Берлинскую Базу данных Эмоциональной Речи [1]. База данных содержит 535 произнесения, на котором говорят 10 агентов, предназначенных, чтобы передать одну из следующих эмоций: гнев, скука, отвращение, беспокойство/страх, счастье, печаль, или нейтральный. Эмоции являются независимым текстом.

url = "http://emodb.bilderbar.info/download/download.zip";
downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,"Emo-DB");

if ~exist(datasetFolder,'dir')
    disp('Downloading Emo-DB (40.5 MB) ...')
    unzip(url,datasetFolder)
end

Создайте audioDatastore (Audio Toolbox), который указывает на звуковые файлы.

ads = audioDatastore(fullfile(datasetFolder,"wav"));

Имена файлов являются кодами, указывающими на ID динамика, текст, на котором говорят, эмоция и версия. Веб-сайт содержит ключ для интерпретации кода и дополнительной информации о динамиках, таких как пол и возраст. Составьте таблицу с переменными Speaker и Emotion. Декодируйте имена файлов в таблицу.

filepaths = ads.Files;
emotionCodes = cellfun(@(x)x(end-5),filepaths,'UniformOutput',false);
emotions = replace(emotionCodes,{'W','L','E','A','F','T','N'}, ...
    {'Anger','Boredom','Disgust','Anxiety/Fear','Happiness','Sadness','Neutral'});

speakerCodes = cellfun(@(x)x(end-10:end-9),filepaths,'UniformOutput',false);
labelTable = cell2table([speakerCodes,emotions],'VariableNames',{'Speaker','Emotion'});
labelTable.Emotion = categorical(labelTable.Emotion);
labelTable.Speaker = categorical(labelTable.Speaker);
summary(labelTable)
Variables:

    Speaker: 535×1 categorical

        Values:

            03       49   
            08       58   
            09       43   
            10       38   
            11       55   
            12       35   
            13       61   
            14       69   
            15       56   
            16       71   

    Emotion: 535×1 categorical

        Values:

            Anger             127   
            Anxiety/Fear       69   
            Boredom            81   
            Disgust            46   
            Happiness          71   
            Neutral            79   
            Sadness            62   

labelTable находится в том же порядке как файлы в audioDatastore. Установите Labels свойство audioDatastore к labelTable.

ads.Labels = labelTable;

Выполните речевое распознавание эмоции

Загрузите и загрузите предварительно обученную сеть, audioFeatureExtractor Объект (Audio Toolbox) раньше обучал сеть и коэффициенты нормализации для функций. Эта сеть была обучена с помощью всех динамиков в наборе данных кроме динамика 03.

url = 'http://ssd.mathworks.com/supportfiles/audio/SpeechEmotionRecognition.zip';
    downloadNetFolder = tempdir;
    netFolder = fullfile(downloadNetFolder,'SpeechEmotionRecognition');
    if ~exist(netFolder,'dir')
        disp('Downloading pretrained network (1 file - 1.5 MB) ...')
        unzip(url,downloadNetFolder)
    end
load(fullfile(netFolder,'network_Audio_SER.mat'));

Частота дискретизации установлена на audioFeatureExtractor соответствует частоте дискретизации набора данных.

fs = afe.SampleRate;

Выберите динамик и эмоцию, затем подмножество datastore, чтобы только включать выбранный динамик и эмоцию. Читайте из datastore и слушайте файл.

speaker = categorical("03");
эмоция = categorical("Disgust");

adsSubset = подмножество (объявления, ads.Labels.Speaker == динамик & ads.Labels.Emotion == эмоция);

аудио = читало (adsSubset);
звук (аудио, фс)

Используйте audioFeatureExtractor возразите, чтобы извлечь функции и затем транспонировать их так, чтобы время приехало строки. Нормируйте функции и затем преобразуйте их в последовательности с 20 элементами с перекрытием с 10 элементами, которое соответствует приблизительно 600 MS Windows с перекрытием на 300 мс. Используйте функцию поддержки, HelperFeatureVector2Sequence, чтобы преобразовать массив характеристических векторов к последовательностям.

features = (extract(afe,audio))';

featuresNormalized = (features - normalizers.Mean)./normalizers.StandardDeviation;

numOverlap = 10;
featureSequences = HelperFeatureVector2Sequence (featuresNormalized, 20, numOverlap);

Подайте последовательности функции в сеть для предсказания. Вычислите среднее предсказание и постройте вероятностное распределение выбранных эмоций как круговая диаграмма. Можно судить различные динамики, эмоции, перекрытие последовательности и среднее значение предсказания, чтобы проверить производительность сети. Чтобы получить реалистическое приближение эффективности сети, используйте динамик 03, на котором не была обучена сеть.

YPred = double(predict(net,featureSequences));

average = "mode";
switch среднее значение
    case 'mean'
        probs = среднее значение (YPred, 1);
    case 'median'
        probs = медиана (YPred, 1);
    case 'mode'
        probs = режим (YPred, 1);
end

круг (probs./sum (probs), строка (net.Layers (конец).Classes))

Остаток от примера иллюстрирует, как сеть была обучена и подтверждена.

Обучение сети

10-кратная точность перекрестной проверки первой попытки обучения составляла приблизительно 60% из-за недостаточных обучающих данных. Модель, обученная на недостаточных данных, сверхсоответствует некоторым сгибам и underfits другим. Чтобы улучшить полную подгонку, увеличьте размер набора данных с помощью audioDataAugmenter (Audio Toolbox). 50 увеличений на файл были выбраны опытным путем в качестве хорошего компромисса между улучшением точности и временем вычислений. Можно сократить число увеличений, чтобы ускорить пример.

Создайте audioDataAugmenter объект. Установите вероятность применения перехода тангажа к 0.5 и используйте область значений по умолчанию. Установите вероятность применения смещения во времени к 1 и используйте область значений [-0.3,0.3] секунды. Установите вероятность добавления шума к 1 и укажите диапазон ОСШ как [-20,40] дБ.

numAugmentations = 50;
увеличение = audioDataAugmenter ('NumAugmentations', numAugmentations, ...
    'TimeStretchProbability',0, ...
    'VolumeControlProbability',0, ...
    ...
    'PitchShiftProbability',0.5, ...
    ...
    'TimeShiftProbability',1, ...
    'TimeShiftRange',[-0.3,0.3], ...
    ...
    'AddNoiseProbability',1, ...
    'SNRRange', [-20,40]);

Создайте новую папку в своей текущей папке, чтобы содержать увеличенный набор данных.

currentDir = pwd;
writeDirectory = fullfile(currentDir,'augmentedData');
mkdir(writeDirectory)

Для каждого файла в аудио datastore:

  1. Создайте 50 увеличений.

  2. Нормируйте аудио, чтобы иметь макс. абсолютное значение 1.

  3. Запишите увеличенные аудиоданные как файл WAV. ДобавлениеaugK к каждым из имен файлов, где K является номером увеличения. Чтобы ускорить обработку, используйте parfor и раздел datastore.

Этот метод увеличения базы данных является трудоемким (приблизительно 1 час) и пробел, использующий (приблизительно 26 Гбайт). Однако при итерации при выборе сетевой архитектуры или трубопровода извлечения признаков, эта оплачиваемая авансом стоимость обычно выгодна.

N = numel(ads.Files)*numAugmentations;
myWaitBar = HelperPoolWaitbar(N,"Augmenting Dataset...");

reset(ads)

numPartitions = 18;

tic
parfor ii = 1:numPartitions
    adsPart = partition(ads,numPartitions,ii);
    while hasdata(adsPart)
        [x,adsInfo] = read(adsPart);
        data = augment(augmenter,x,fs);

        [~,fn] = fileparts(adsInfo.FileName);
        for i = 1:size(data,1)
            augmentedAudio = data.Audio{i};
            augmentedAudio = augmentedAudio/max(abs(augmentedAudio),[],'all');
            augNum = num2str(i);
            if numel(augNum)==1
                iString = ['0',augNum];
            else
                iString = augNum;
            end
            audiowrite(fullfile(writeDirectory,sprintf('%s_aug%s.wav',fn,iString)),augmentedAudio,fs);
            increment(myWaitBar)
        end
    end
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
delete(myWaitBar)
fprintf('Augmentation complete (%0.2f minutes).\n',toc/60)
Augmentation complete (6.28 minutes).

Создайте аудио datastore, который указывает на увеличенный набор данных. Реплицируйте строки таблицы метки исходного datastore NumAugmentations времена, чтобы определить метки увеличенного datastore.

adsAug = audioDatastore(writeDirectory);
adsAug.Labels = repelem(ads.Labels,augmenter.NumAugmentations,1);

Создайте audioFeatureExtractor Объект (Audio Toolbox). Установите Window к периодическому Окну Хэмминга на 30 мс, OverlapLength к 0, и SampleRate к частоте дискретизации базы данных. Установите gtcc, gtccDelta, mfccDelta, и spectralCrest к true извлекать их. Установите SpectralDescriptorInput к melSpectrum так, чтобы spectralCrest вычисляется для mel спектра.

win = hamming(round(0.03*fs),"periodic");
overlapLength = 0;

afe = audioFeatureExtractor( ...
    'Window',win, ...
    'OverlapLength',overlapLength, ...
    'SampleRate',fs, ...
    ...
    'gtcc',true, ...
    'gtccDelta',true, ...
    'mfccDelta',true, ...
    ...
    'SpectralDescriptorInput','melSpectrum', ...
    'spectralCrest',true);

Обучайтесь для развертывания

Когда вы будете обучаться для развертывания, используйте все доступные динамики в наборе данных. Установите учебный datastore на увеличенный datastore.

adsTrain = adsAug;

Преобразуйте учебный аудио datastore в длинный массив. Если у вас есть Parallel Computing Toolbox™, экстракция автоматически параллелизируется. Если у вас нет Parallel Computing Toolbox™, код продолжает запускаться.

tallTrain = tall(adsTrain);

Извлеките учебные функции и переориентируйте функции так, чтобы время приехало строки, чтобы быть совместимым с sequenceInputLayer.

featuresTallTrain = cellfun(@(x)extract(afe,x),tallTrain,"UniformOutput",false);
featuresTallTrain = cellfun(@(x)x',featuresTallTrain,"UniformOutput",false);
featuresTrain = gather(featuresTallTrain);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1 min 7 sec
Evaluation completed in 1 min 7 sec

Используйте набор обучающих данных, чтобы определить среднее и стандартное отклонение каждой функции.

allFeatures = cat(2,featuresTrain{:});
M = mean(allFeatures,2,'omitnan');
S = std(allFeatures,0,2,'omitnan');

featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);

Буферизуйте характеристические векторы в последовательности так, чтобы каждая последовательность состояла из 20 характеристических векторов с перекрытиями 10 характеристических векторов.

featureVectorsPerSequence = 20;
featureVectorOverlap = 10;
[sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);

Реплицируйте метки наборов обучения и валидации так, чтобы они были во взаимно-однозначном соответствии с последовательностями. Не у всех динамиков есть произнесение для всех эмоций. Создайте пустой categorical массив, который содержит все эмоциональные категории и добавляет его к меткам валидации так, чтобы категориальный массив содержал все эмоции.

labelsTrain = repelem(adsTrain.Labels.Emotion,[sequencePerFileTrain{:}]);

emptyEmotions = ads.Labels.Emotion;
emptyEmotions(:) = [];

Задайте сеть BiLSTM с помощью bilstmLayer. Поместите dropoutLayer до и после bilstmLayer помочь предотвратить сверхподбор кривой.

dropoutProb1 = 0.3;
numUnits = 200;
dropoutProb2 = 0.6;
layers = [ ...
    sequenceInputLayer(size(sequencesTrain{1},1))
    dropoutLayer(dropoutProb1)
    bilstmLayer(numUnits,"OutputMode","last")
    dropoutLayer(dropoutProb2)
    fullyConnectedLayer(numel(categories(emptyEmotions)))
    softmaxLayer
    classificationLayer];

Задайте опции обучения с помощью trainingOptions.

miniBatchSize = 512;
initialLearnRate = 0.005;
learnRateDropPeriod = 2;
maxEpochs = 3;
options = trainingOptions("adam", ...
    "MiniBatchSize",miniBatchSize, ...
    "InitialLearnRate",initialLearnRate, ...
    "LearnRateDropPeriod",learnRateDropPeriod, ...
    "LearnRateSchedule","piecewise", ...
    "MaxEpochs",maxEpochs, ...
    "Shuffle","every-epoch", ...
    "Verbose",false, ...
    "Plots","Training-Progress");

Обучите сеть с помощью trainNetwork.

net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

Чтобы сохранить сеть, сконфигурировал audioFeatureExtractor, и коэффициенты нормализации, набор saveSERSystem к true.

saveSERSystem = false;
if saveSERSystem
    normalizers.Mean = M;
    normalizers.StandardDeviation = S;
    сохранение'network_Audio_SER.mat','net','afe','normalizers')
end

Обучение системной валидации

Чтобы обеспечить точную оценку модели, вы создали в этом примере, обучите и подтвердите перекрестную проверку k-сгиба отпуска один динамика (LOSO) использования. В этом методе вы обучаете использование k-1 динамики и затем подтверждают на не учтенном динамике. Вы повторяете эту процедуру k времена. Итоговая точность валидации является средним значением сгибов k.

Создайте переменную, которая содержит идентификаторы динамика. Определите количество сгибов: 1 для каждого динамика. База данных содержит произнесение от 10 уникальных докладчиков. Используйте summary чтобы отобразить идентификаторы динамика (левый столбец) и количество произнесения, они способствуют базе данных (правый столбец).

speaker = ads.Labels.Speaker;
numFolds = numel(speaker);
summary(speaker)
     03      49 
     08      58 
     09      43 
     10      38 
     11      55 
     12      35 
     13      61 
     14      69 
     15      56 
     16      71 

Функция помощника HelperTrainAndValidateNetwork выполняет шаги, обрисованные в общих чертах выше для всех 10 сгибов, и возвращает true и предсказанные метки для каждого сгиба. Вызовите HelperTrainAndValidateNetwork с audioDatastore, увеличенный audioDatastore, и audioFeatureExtractor.

[labelsTrue,labelsPred] = HelperTrainAndValidateNetwork(ads,adsAug,afe);

Распечатайте точность на сгиб и постройте 10-кратный график беспорядка.

for ii = 1:numel(labelsTrue)
    foldAcc = mean(labelsTrue{ii}==labelsPred{ii})*100;
    fprintf('Fold %1.0f, Accuracy = %0.1f\n',ii,foldAcc);
end
Fold 1, Accuracy = 73.5
Fold 2, Accuracy = 77.6
Fold 3, Accuracy = 74.4
Fold 4, Accuracy = 68.4
Fold 5, Accuracy = 76.4
Fold 6, Accuracy = 80.0
Fold 7, Accuracy = 73.8
Fold 8, Accuracy = 87.0
Fold 9, Accuracy = 69.6
Fold 10, Accuracy = 70.4
labelsTrueMat = cat(1,labelsTrue{:});
labelsPredMat = cat(1,labelsPred{:});
figure
cm = confusionchart(labelsTrueMat,labelsPredMat);
valAccuracy = mean(labelsTrueMat==labelsPredMat)*100;
cm.Title = sprintf('Confusion Matrix for 10-Fold Cross-Validation\nAverage Accuracy = %0.1f',valAccuracy);
sortClasses(cm,categories(emptyEmotions))
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Вспомогательные Функции

Преобразуйте массив характеристических векторов к последовательностям

function [sequences,sequencePerFile] = HelperFeatureVector2Sequence(features,featureVectorsPerSequence,featureVectorOverlap)
    % Copyright 2019 MathWorks, Inc.
    if featureVectorsPerSequence <= featureVectorOverlap
        error('The number of overlapping feature vectors must be less than the number of feature vectors per sequence.')
    end

    if ~iscell(features)
        features = {features};
    end
    hopLength = featureVectorsPerSequence - featureVectorOverlap;
    idx1 = 1;
    sequences = {};
    sequencePerFile = cell(numel(features),1);
    for ii = 1:numel(features)
        sequencePerFile{ii} = floor((size(features{ii},2) - featureVectorsPerSequence)/hopLength) + 1;
        idx2 = 1;
        for j = 1:sequencePerFile{ii}
            sequences{idx1,1} = features{ii}(:,idx2:idx2 + featureVectorsPerSequence - 1); %#ok<AGROW>
            idx1 = idx1 + 1;
            idx2 = idx2 + hopLength;
        end
    end
end

Обучите и проверьте сеть

function [trueLabelsCrossFold,predictedLabelsCrossFold] = HelperTrainAndValidateNetwork(varargin)
    % Copyright 2019 The MathWorks, Inc.
    if nargin == 3
        ads = varargin{1};
        augads = varargin{2};
        extractor = varargin{3};
    elseif nargin == 2
        ads = varargin{1};
        augads = varargin{1};
        extractor = varargin{2};
    end
    speaker = categories(ads.Labels.Speaker);
    numFolds = numel(speaker);
    emptyEmotions = (ads.Labels.Emotion);
    emptyEmotions(:) = [];

    % Loop over each fold.
    trueLabelsCrossFold = {};
    predictedLabelsCrossFold = {};
    
    for i = 1:numFolds
        
        % 1. Divide the audio datastore into training and validation sets.
        % Convert the data to tall arrays.
        idxTrain           = augads.Labels.Speaker~=speaker(i);
        augadsTrain        = subset(augads,idxTrain);
        augadsTrain.Labels = augadsTrain.Labels.Emotion;
        tallTrain          = tall(augadsTrain);
        idxValidation        = ads.Labels.Speaker==speaker(i);
        adsValidation        = subset(ads,idxValidation);
        adsValidation.Labels = adsValidation.Labels.Emotion;
        tallValidation       = tall(adsValidation);

        % 2. Extract features from the training set. Reorient the features
        % so that time is along rows to be compatible with
        % sequenceInputLayer.
        tallTrain         = cellfun(@(x)x/max(abs(x),[],'all'),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)extract(extractor,x),tallTrain,"UniformOutput",false);
        tallFeaturesTrain = cellfun(@(x)x',tallFeaturesTrain,"UniformOutput",false);  %#ok<NASGU>
        [~,featuresTrain] = evalc('gather(tallFeaturesTrain)'); % Use evalc to suppress command-line output.
        tallValidation         = cellfun(@(x)x/max(abs(x),[],'all'),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)extract(extractor,x),tallValidation,"UniformOutput",false);
        tallFeaturesValidation = cellfun(@(x)x',tallFeaturesValidation,"UniformOutput",false); %#ok<NASGU>
        [~,featuresValidation] = evalc('gather(tallFeaturesValidation)'); % Use evalc to suppress command-line output.

        % 3. Use the training set to determine the mean and standard
        % deviation of each feature. Normalize the training and validation
        % sets.
        allFeatures = cat(2,featuresTrain{:});
        M = mean(allFeatures,2,'omitnan');
        S = std(allFeatures,0,2,'omitnan');
        featuresTrain = cellfun(@(x)(x-M)./S,featuresTrain,'UniformOutput',false);
        for ii = 1:numel(featuresTrain)
            idx = find(isnan(featuresTrain{ii}));
            if ~isempty(idx)
                featuresTrain{ii}(idx) = 0;
            end
        end
        featuresValidation = cellfun(@(x)(x-M)./S,featuresValidation,'UniformOutput',false);
        for ii = 1:numel(featuresValidation)
            idx = find(isnan(featuresValidation{ii}));
            if ~isempty(idx)
                featuresValidation{ii}(idx) = 0;
            end
        end

        % 4. Buffer the sequences so that each sequence consists of twenty
        % feature vectors with overlaps of 10 feature vectors.
        featureVectorsPerSequence = 20;
        featureVectorOverlap = 10;
        [sequencesTrain,sequencePerFileTrain] = HelperFeatureVector2Sequence(featuresTrain,featureVectorsPerSequence,featureVectorOverlap);
        [sequencesValidation,sequencePerFileValidation] = HelperFeatureVector2Sequence(featuresValidation,featureVectorsPerSequence,featureVectorOverlap);

        % 5. Replicate the labels of the train and validation sets so that
        % they are in one-to-one correspondence with the sequences.
        labelsTrain = [emptyEmotions;augadsTrain.Labels];
        labelsTrain = labelsTrain(:);
        labelsTrain = repelem(labelsTrain,[sequencePerFileTrain{:}]);

        % 6. Define a BiLSTM network.
        dropoutProb1 = 0.3;
        numUnits     = 200;
        dropoutProb2 = 0.6;
        layers = [ ...
            sequenceInputLayer(size(sequencesTrain{1},1))
            dropoutLayer(dropoutProb1)
            bilstmLayer(numUnits,"OutputMode","last")
            dropoutLayer(dropoutProb2)
            fullyConnectedLayer(numel(categories(emptyEmotions)))
            softmaxLayer
            classificationLayer];

        % 7. Define training options.
        miniBatchSize       = 512;
        initialLearnRate    = 0.005;
        learnRateDropPeriod = 2;
        maxEpochs           = 3;
        options = trainingOptions("adam", ...
            "MiniBatchSize",miniBatchSize, ...
            "InitialLearnRate",initialLearnRate, ...
            "LearnRateDropPeriod",learnRateDropPeriod, ...
            "LearnRateSchedule","piecewise", ...
            "MaxEpochs",maxEpochs, ...
            "Shuffle","every-epoch", ...
            "Verbose",false);

        % 8. Train the network.
        net = trainNetwork(sequencesTrain,labelsTrain,layers,options);

        % 9. Evaluate the network. Call classify to get the predicted labels
        % for each sequence. Get the mode of the predicted labels of each
        % sequence to get the predicted labels of each file.
        predictedLabelsPerSequence = classify(net,sequencesValidation);
        trueLabels = categorical(adsValidation.Labels);
        predictedLabels = trueLabels;
        idx1 = 1;
        for ii = 1:numel(trueLabels)
            predictedLabels(ii,:) = mode(predictedLabelsPerSequence(idx1:idx1 + sequencePerFileValidation{ii} - 1,:),1);
            idx1 = idx1 + sequencePerFileValidation{ii};
        end
        trueLabelsCrossFold{i} = trueLabels; %#ok<AGROW>
        predictedLabelsCrossFold{i} = predictedLabels; %#ok<AGROW>
    end
end

Ссылки

[1] Burkhardt, F., А. Пэешк, М. Рольфес, В.Ф. Сендлмайер и Б. Вайс, "База данных немецкой эмоциональной речи". В межречи 2005 продолжений. Лиссабон, Португалия: международная речевая коммуникационная ассоциация, 2005.

Смотрите также

| | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте