exponenta event banner

Передача обучения с предварительно обученными аудиосетями

Этот пример показывает, как использовать обучение передаче для переподготовки YAMNet, предварительно обученной сверточной нейронной сети, для классификации нового набора аудиосигналов. Чтобы начать с глубокого обучения аудио с нуля, см. раздел Классификация звука с помощью глубокого обучения.

Transfer learning обычно используется в приложениях для глубокого обучения. Предварительно подготовленную сеть можно использовать в качестве отправной точки для изучения новой задачи. Точная настройка сети с обучением переносу обычно намного быстрее и проще, чем обучение сети с произвольно инициализированными весами с нуля. Вы можете быстро перенести изученные функции в новую задачу, используя меньшее количество обучающих сигналов.

Аудио Toolbox™ дополнительно обеспечивает classifySound функция, реализующая необходимую предварительную обработку для YAMNet и удобную постобработку для интерпретации результатов. Audio Toolbox также предоставляет предварительно подготовленную сеть VGGish (vggish), а также vggishFeatures функция, реализующая предварительную и постобработку для сети VGGish.

Создать данные

Создайте 100 сигналов белого шума, 100 сигналов коричневого шума и 100 сигналов розового шума. Каждый сигнал представляет длительность 0,98 секунды, предполагая частоту дискретизации 16 кГц.

fs = 16e3;
duration = 0.98;
N = duration*fs;
numSignals = 100;

wNoise = 2*rand([N,numSignals]) - 1;
wLabels = repelem(categorical("white"),numSignals,1);

bNoise = filter(1,[1,-0.999],wNoise);
bNoise = bNoise./max(abs(bNoise),[],'all');
bLabels = repelem(categorical("brown"),numSignals,1);

pNoise = pinknoise([N,numSignals]);
pLabels = repelem(categorical("pink"),numSignals,1);

Разбейте данные на обучающие и тестовые наборы. Как правило, тренировочный набор состоит из большей части данных. Тем не менее, чтобы проиллюстрировать силу трансферного обучения, вы будете использовать только несколько образцов для обучения и большинство для проверки.

K = 5;

trainAudio = [wNoise(:,1:K),bNoise(:,1:K),pNoise(:,1:K)];
trainLabels = [wLabels(1:K);bLabels(1:K);pLabels(1:K)];

validationAudio = [wNoise(:,K+1:end),bNoise(:,K+1:end),pNoise(:,K+1:end)];
validationLabels = [wLabels(K+1:end);bLabels(K+1:end);pLabels(K+1:end)];

fprintf("Number of samples per noise color in train set = %d\n" + ...
        "Number of samples per noise color in validation set = %d\n",K,numSignals-K);
Number of samples per noise color in train set = 5
Number of samples per noise color in validation set = 95

Извлечь элементы

Использовать melSpectrogram для извлечения логарифмических спектрограмм как из обучающего набора, так и из валидационного набора с использованием тех же параметров, на которых была обучена модель YAMNet.

FFTLength = 512;
numBands = 64;
frequencyRange = [125 7500];
windowLength = 0.025*fs;
overlapLength = 0.015*fs;

trainFeatures = melSpectrogram(trainAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

trainFeatures = log(trainFeatures + single(0.001));
trainFeatures = permute(trainFeatures,[2,1,4,3]);

validationFeatures = melSpectrogram(validationAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

validationFeatures = log(validationFeatures + single(0.001));
validationFeatures = permute(validationFeatures,[2,1,4,3]);

Передача обучения

Для загрузки предварительно обученной сети вызовите yamnet. Если модель Audio Toolbox для YAMNet не установлена, то функция предоставляет ссылку на расположение весов сети. Чтобы загрузить модель, щелкните ссылку. Распакуйте файл в папку по пути MATLAB. Модель YAMNet может классифицировать звук на одну из 521 категорий звука, включая белый шум и розовый шум (но не коричневый шум).

net = yamnet;
net.Layers(end).Classes
ans = 521×1 categorical
     Speech 
     Child speech, kid speaking 
     Conversation 
     Narration, monologue 
     Babbling 
     Speech synthesizer 
     Shout 
     Bellow 
     Whoop 
     Yell 
     Children shouting 
     Screaming 
     Whispering 
     Laughter 
     Baby laughter 
     Giggle 
     Snicker 
     Belly laugh 
     Chuckle, chortle 
     Crying, sobbing 
     Baby cry, infant cry 
     Whimper 
     Wail, moan 
     Sigh 
     Singing 
     Choir 
     Yodeling 
     Chant 
     Mantra 
     Child singing 
      ⋮

Подготовка модели для передачи обучения путем предварительного преобразования сети в layerGraph (инструментарий глубокого обучения). Использовать replaceLayer (Deep Learning Toolbox) для замены полностью подключенного слоя необученным полностью подключенным слоем. Замените классификационный слой классификационным слоем, который классифицирует входные данные как «белый», «розовый» или «коричневый». Список слоев глубокого обучения (панель инструментов глубокого обучения) для уровней глубокого обучения, поддерживаемых в MATLAB ®.

uniqueLabels = unique(trainLabels);
numLabels = numel(uniqueLabels);

lgraph = layerGraph(net.Layers);

lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense"));
lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));

Для определения вариантов обучения используйте trainingOptions (инструментарий глубокого обучения).

options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});

Для обучения сети используйте trainNetwork (инструментарий глубокого обучения). Сеть достигает точности проверки 100%, используя только 5 сигналов на один тип шума.

trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       20.00% |       88.77% |       1.1922 |       0.6619 |          0.0010 |
|      30 |          30 |       00:00:14 |      100.00% |      100.00% |   9.1076e-06 |   5.0431e-05 |          0.0010 |
|======================================================================================================================|