vggish

Нейронная сеть VGGish

    Синтаксис

    Описание

    пример

    net = vggish возвращает предварительно обученную модель VGGish.

    Эта функция требует как Audio Toolbox™, так и Deep Learning Toolbox™.

    Примеры

    свернуть все

    Загрузите и разархивируйте модель Audio Toolbox™ для VGGish.

    Тип vggish в Командном окне. Если модель Audio Toolbox для VGGish не установлена, то функция предоставляет ссылку на расположение весов сети. Чтобы скачать модель, щелкните ссылку. Разархивируйте файл в местоположении по пути MATLAB.

    Также выполните эти команды, чтобы загрузить и разархивировать модель VGGish во временную директорию.

    downloadFolder = fullfile(tempdir,'VGGishDownload');
    loc = websave(downloadFolder,'https://ssd.mathworks.com/supportfiles/audio/vggish.zip');
    VGGishLocation = tempdir;
    unzip(loc,VGGishLocation)
    addpath(fullfile(VGGishLocation,'vggish'))

    Проверьте успешность установки путем ввода vggish в Командном окне. Если сеть установлена, то функция возвращает SeriesNetwork (Deep Learning Toolbox) объект.

    vggish
    ans = 
      SeriesNetwork with properties:
    
             Layers: [24×1 nnet.cnn.layer.Layer]
         InputNames: {'InputBatch'}
        OutputNames: {'regressionoutput'}
    
    

    Загрузите предварительно обученную сверточную нейронную сеть VGGish и исследуйте слои и классы.

    Использование vggish для загрузки предварительно обученной сети VGGish. Область выхода net является SeriesNetwork (Deep Learning Toolbox) объект.

    net = vggish
    net = 
      SeriesNetwork with properties:
    
             Layers: [24×1 nnet.cnn.layer.Layer]
         InputNames: {'InputBatch'}
        OutputNames: {'regressionoutput'}
    
    

    Просмотрите сетевую архитектуру с помощью Layers свойство. Сеть имеет 24 слоя. Существует девять слоев с усвояемыми весами, из которых шесть являются сверточными слоями и три являются полносвязными слоями.

    net.Layers
    ans = 
      24×1 Layer array with layers:
    
         1   'InputBatch'         Image Input         96×64×1 images
         2   'conv1'              Convolution         64 3×3×1 convolutions with stride [1  1] and padding 'same'
         3   'relu'               ReLU                ReLU
         4   'pool1'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
         5   'conv2'              Convolution         128 3×3×64 convolutions with stride [1  1] and padding 'same'
         6   'relu2'              ReLU                ReLU
         7   'pool2'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
         8   'conv3_1'            Convolution         256 3×3×128 convolutions with stride [1  1] and padding 'same'
         9   'relu3_1'            ReLU                ReLU
        10   'conv3_2'            Convolution         256 3×3×256 convolutions with stride [1  1] and padding 'same'
        11   'relu3_2'            ReLU                ReLU
        12   'pool3'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
        13   'conv4_1'            Convolution         512 3×3×256 convolutions with stride [1  1] and padding 'same'
        14   'relu4_1'            ReLU                ReLU
        15   'conv4_2'            Convolution         512 3×3×512 convolutions with stride [1  1] and padding 'same'
        16   'relu4_2'            ReLU                ReLU
        17   'pool4'              Max Pooling         2×2 max pooling with stride [2  2] and padding 'same'
        18   'fc1_1'              Fully Connected     4096 fully connected layer
        19   'relu5_1'            ReLU                ReLU
        20   'fc1_2'              Fully Connected     4096 fully connected layer
        21   'relu5_2'            ReLU                ReLU
        22   'fc2'                Fully Connected     128 fully connected layer
        23   'EmbeddingBatch'     ReLU                ReLU
        24   'regressionoutput'   Regression Output   mean-squared-error
    

    Использование analyzeNetwork (Deep Learning Toolbox), чтобы визуально исследовать сеть.

    analyzeNetwork(net)

    Сеть VGGish требует от вас предварительной обработки и извлечения функций из аудиосигналов путем преобразования их в частоту дискретизации, на которой обучалась сеть, а затем извлечения логарифмических спектрограмм. Этот пример проходит необходимую предварительную обработку и редукцию данных, чтобы соответствовать предварительной обработке и редукциям данных, используемым для обучения VGGish. The vggishFeatures функция выполняет эти шаги для вас.

    Считывайте аудиосигнал для классификации. Повторно отобразите аудиосигнал на 16 кГц и затем преобразуйте его в одинарную точность.

    [audioIn,fs0] = audioread('Ambiance-16-44p1-mono-12secs.wav');
    
    fs = 16e3;
    audioIn = resample (audioIn, fs, fs0);
    
    audioIn = single (audioIn);

    Задайте параметры mel spectrogram и затем извлеките функции с помощью melSpectrogram функция.

    FFTLength = 512;
    numBands = 64;
    frequencyRange = [125 7500];
    windowLength = 0.025*fs;
    overlapLength = 0.015*fs;
    
    melSpect = melSpectrogram(audioIn,fs, ...
        'Window',hann(windowLength,'periodic'), ...
        'OverlapLength',overlapLength, ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',frequencyRange, ...
        'NumBands',numBands, ...
        'FilterBankNormalization','none', ...
        'WindowNormalization',false, ...
        'SpectrumType','magnitude', ...
        'FilterBankDesignDomain','warped');

    Преобразуйте спектрограмму mel в шкалу журнала.

    melSpect = log(melSpect + single(0.001));

    Переориентировать mel spectrogram так, чтобы время было вдоль первой размерности как строки.

    melSpect = melSpect.';
    [numSTFTWindows,numBands] = size(melSpect)
    numSTFTWindows = 1222
    
    numBands = 64
    

    Разбейте спектрограмму на системы координат длины 96 с перекрытием 48. Расположите системы координат вдоль четвертой размерности.

    frameWindowLength = 96;
    frameOverlapLength = 48;
    
    hopLength = frameWindowLength - frameOverlapLength;
    numHops = floor((numSTFTWindows - frameWindowLength)/hopLength) + 1;
    
    frames = zeros(frameWindowLength,numBands,1,numHops,'like',melSpect);
    for hop = 1:numHops
        range = 1 + hopLength*(hop-1):hopLength*(hop - 1) + frameWindowLength;
        frames(:,:,1,hop) = melSpect(range,:);
    end

    Создайте сеть VGGish.

    net = vggish;

    Функции predict для извлечения вложений функций из изображений спектрограмм. Вложения функции возвращаются как numFrames-by-128 матрица, где numFrames - количество отдельных спектрограмм, и 128 - количество элементов в каждом векторе функций.

    features = predict(net,frames);
    
    [numFrames,numFeatures] = size(features)
    numFrames = 24
    
    numFeatures = 128
    

    Сравните визуализацию mel spectrogram и встроения функций VGGish.

    melSpectrogram(audioIn,fs, ...
        'Window',hann(windowLength,'periodic'), ...
        'OverlapLength',overlapLength, ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',frequencyRange, ...
        'NumBands',numBands, ...
        'FilterBankNormalization','none', ...
        'WindowNormalization',false, ...
        'SpectrumType','magnitude', ...
        'FilterBankDesignDomain','warped');

    surf(features,'EdgeColor','none')
    view([90,-90])
    axis([1 numFeatures 1 numFrames])
    xlabel('Feature')
    ylabel('Frame')
    title('VGGish Feature Embeddings')

    В этом примере вы переносите обучение в модели регрессии VGGish в задачу классификации аудио.

    Загрузите и разархивируйте набор данных классификации звука окружающей среды. Этот набор данных состоит из записей, помеченных как один из 10 различных классов звука (ESC-10).

    url = 'http://ssd.mathworks.com/supportfiles/audio/ESC-10.zip';
    downloadFolder = fullfile(tempdir,'ESC-10');
    datasetLocation = tempdir;
    
    if ~exist(fullfile(tempdir,'ESC-10'),'dir')
        loc = websave(downloadFolder,url);
        unzip(loc,fullfile(tempdir,'ESC-10'))
    end

    Создайте audioDatastore объект для управления данными и разделения их на train и валидации. Функции countEachLabel отображение распределения классов звука и количества уникальных меток.

    ads = audioDatastore(downloadFolder,'IncludeSubfolders',true,'LabelSource','foldernames');
    labelTable = countEachLabel(ads)
    labelTable=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw           40  
        clock_tick         40  
        crackling_fire     40  
        crying_baby        40  
        dog                40  
        helicopter         40  
        rain               40  
        rooster            38  
        sea_waves          40  
        sneezing           40  
    
    

    Определите общее количество классов.

    numClasses = size(labelTable,1);

    Функции splitEachLabel разделение набора данных на наборы обучения и валидации. Смотрите распределение меток в наборах обучения и валидации.

    [adsTrain, adsValidation] = splitEachLabel(ads,0.8);
    
    countEachLabel(adsTrain)
    ans=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw           32  
        clock_tick         32  
        crackling_fire     32  
        crying_baby        32  
        dog                32  
        helicopter         32  
        rain               32  
        rooster            30  
        sea_waves          32  
        sneezing           32  
    
    
    countEachLabel(adsValidation)
    ans=10×2 table
            Label         Count
        ______________    _____
    
        chainsaw            8  
        clock_tick          8  
        crackling_fire      8  
        crying_baby         8  
        dog                 8  
        helicopter          8  
        rain                8  
        rooster             8  
        sea_waves           8  
        sneezing            8  
    
    

    Сеть VGGish ожидает предварительной обработки аудио в логарифмические спектрограммы. Вспомогательная функция vggishPreprocess принимает audioDatastore объект и процент перекрытия между спектрограммами логарифмического меля в качестве входных данных и возвращает матрицы предикторов и откликов, подходящих в качестве входных данных для сети VGGish.

    overlapPercentage = 75;
    
    [trainFeatures, trainLabels] = vggishPreprocess (adsTrain, overlapPercentage);
    [validationFeatures, validationLabels, segmentsPerFile] = vggishPreprocess (adsValidation, overlapPercent);

    Загрузите модель VGGish и преобразуйте ее в layerGraph (Deep Learning Toolbox) объект.

    net = vggish;
    
    lgraph = layerGraph(net.Layers);

    Использование removeLayers (Deep Learning Toolbox), чтобы удалить конечный выходной слой регрессии из графика. После удаления слоя регрессии новым конечным слоем графика является слой ReLU с именем 'EmbeddingBatch'.

    lgraph = removeLayers(lgraph,'regressionoutput');
    lgraph.Layers(end)
    ans = 
      ReLULayer with properties:
    
        Name: 'EmbeddingBatch'
    
    

    Использование addLayers (Deep Learning Toolbox) для добавления fullyConnectedLayer (Deep Learning Toolbox), a softmaxLayer (Deep Learning Toolbox) и classificationLayer (Deep Learning Toolbox) к графику.

    lgraph = addLayers(lgraph,fullyConnectedLayer(numClasses,'Name','FCFinal'));
    lgraph = addLayers(lgraph,softmaxLayer('Name','softmax'));
    lgraph = addLayers(lgraph,classificationLayer('Name','classOut'));

    Использование connectLayers (Deep Learning Toolbox), чтобы добавить полносвязные, программные и классификационные слои к графику слоев.

    lgraph = connectLayers(lgraph,'EmbeddingBatch','FCFinal');
    lgraph = connectLayers(lgraph,'FCFinal','softmax');
    lgraph = connectLayers(lgraph,'softmax','classOut');

    Чтобы определить опции обучения, используйте trainingOptions (Deep Learning Toolbox).

    miniBatchSize = 128;
    options = trainingOptions('adam', ...
        'MaxEpochs',5, ...
        'MiniBatchSize',miniBatchSize, ...
        'Shuffle','every-epoch', ...
        'ValidationData',{validationFeatures,validationLabels}, ...
        'ValidationFrequency',50, ...
        'LearnRateSchedule','piecewise', ...
        'LearnRateDropFactor',0.5, ...
        'LearnRateDropPeriod',2);

    Для обучения сети используйте trainNetwork (Deep Learning Toolbox).

    [trainedNet, netInfo] = trainNetwork(trainFeatures,trainLabels,lgraph,options);
    Training on single GPU.
    |======================================================================================================================|
    |  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
    |         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
    |======================================================================================================================|
    |       1 |           1 |       00:00:00 |       10.94% |       26.03% |       2.2253 |       2.0317 |          0.0010 |
    |       2 |          50 |       00:00:05 |       93.75% |       83.75% |       0.1884 |       0.7001 |          0.0010 |
    |       3 |         100 |       00:00:10 |       96.88% |       80.07% |       0.1150 |       0.7838 |          0.0005 |
    |       4 |         150 |       00:00:15 |       92.97% |       81.99% |       0.1656 |       0.7612 |          0.0005 |
    |       5 |         200 |       00:00:20 |       92.19% |       79.04% |       0.1738 |       0.8192 |          0.0003 |
    |       5 |         210 |       00:00:21 |       95.31% |       80.15% |       0.1389 |       0.8581 |          0.0003 |
    |======================================================================================================================|
    

    Каждый аудио файла был разделен на несколько сегментов для подачи в сеть VGGish. Объедините предсказания для каждого файла в наборе валидации с помощью решения о правиле большинства.

    validationPredictions = classify(trainedNet,validationFeatures);
    
    idx = 1;
    validationPredictionsPerFile = categorical;
    for ii = 1:numel(adsValidation.Files)
        validationPredictionsPerFile(ii,1) = mode(validationPredictions(idx:idx+segmentsPerFile(ii)-1));
        idx = idx + segmentsPerFile(ii);
    end

    Использование confusionchart (Deep Learning Toolbox), чтобы оценить эффективность сети на наборе валидации.

    figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
    cm = confusionchart(adsValidation.Labels,validationPredictionsPerFile);
    cm.Title = sprintf('Confusion Matrix for Validation Data \nAccuracy = %0.2f %%',mean(validationPredictionsPerFile==adsValidation.Labels)*100);
    cm.ColumnSummary = 'column-normalized';
    cm.RowSummary = 'row-normalized';

    Вспомогательные функции

    function [predictor,response,segmentsPerFile] = vggishPreprocess(ads,overlap)
    % This function is for example purposes only and may be changed or removed
    % in a future release.
    
    % Create filter bank
    FFTLength = 512;
    numBands = 64;
    fs0 = 16e3;
    filterBank = designAuditoryFilterBank(fs0, ...
        'FrequencyScale','mel', ...
        'FFTLength',FFTLength, ...
        'FrequencyRange',[125 7500], ...
        'NumBands',numBands, ...
        'Normalization','none', ...
        'FilterBankDesignDomain','warped');
    
    % Define STFT parameters
    windowLength = 0.025 * fs0;
    hopLength = 0.01 * fs0;
    win = hann(windowLength,'periodic');
    
    % Define spectrogram segmentation parameters
    segmentDuration = 0.96; % seconds
    segmentRate = 100; % hertz
    segmentLength = segmentDuration*segmentRate; % Number of spectrums per auditory spectrograms
    segmentHopDuration = (100-overlap) * segmentDuration / 100; % Duration (s) advanced between auditory spectrograms
    segmentHopLength = round(segmentHopDuration * segmentRate); % Number of spectrums advanced between auditory spectrograms
    
    % Preallocate cell arrays for the predictors and responses
    numFiles = numel(ads.Files);
    predictor = cell(numFiles,1);
    response = predictor;
    segmentsPerFile = zeros(numFiles,1);
    
    % Extract predictors and responses for each file
    for ii = 1:numFiles
        [audioIn,info] = read(ads);
    
        x = single(resample(audioIn,fs0,info.SampleRate));
    
        Y = stft(x, ...
            'Window',win, ...
            'OverlapLength',windowLength-hopLength, ...
            'FFTLength',FFTLength, ...
            'FrequencyRange','onesided');
        Y = abs(Y);
    
        logMelSpectrogram = log(filterBank*Y + single(0.01))';
        
        % Segment log-mel spectrogram
        numHops = floor((size(Y,2)-segmentLength)/segmentHopLength) + 1;
        segmentedLogMelSpectrogram = zeros(segmentLength,numBands,1,numHops);
        for hop = 1:numHops
            segmentedLogMelSpectrogram(:,:,1,hop) = logMelSpectrogram(1+segmentHopLength*(hop-1):segmentLength+segmentHopLength*(hop-1),:);
        end
    
        predictor{ii} = segmentedLogMelSpectrogram;
        response{ii} = repelem(info.Label,numHops);
        segmentsPerFile(ii) = numHops;
    end
    
    % Concatenate predictors and responses into arrays
    predictor = cat(4,predictor{:});
    response = cat(2,response{:});
    end

    Выходные аргументы

    свернуть все

    Предварительно обученная нейронная сеть VGGish, возвращенная как SeriesNetwork (Deep Learning Toolbox) объект.

    Ссылки

    [1] Gemmeke, Jort F., Daniel P. W. Ellis, Dylan Freedman, Are Jansen, Wade Lawrence, R. Channing Moore, Manoj Plakal и Marvin Ritter. 2017. Audio Set: An Ontology and Human-Labeled Dataset for Audio Events (неопр.) (недоступная ссылка). В 2017 году IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 776-80. Новый Орлеан, LA: IEEE. https://doi.org/10.1109/ICASSP.2017.7952261.

    [2] Hershey, Shawn, Sourish Chaudhuri, Daniel P. W. Ellis, Jort F. Gemmeke, Are Jansen, R. Channing Moore, Manoj Plakal, et al. 2017. CNN Архитектур для крупномасштабной классификации аудио. В 2017 году IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 131-35. Новый Орлеан, LA: IEEE. https://doi.org/10.1109/ICASSP.2017.7952132.

    Расширенные возможности

    ..
    Введенный в R2020b