Этот пример показывает, как использовать обучение передаче для переподготовки YAMNet, предварительно обученной сверточной нейронной сети, для классификации нового набора аудиосигналов. Чтобы начать с глубокого обучения аудио с нуля, см. раздел Классификация звука с помощью глубокого обучения.
Transfer learning обычно используется в приложениях для глубокого обучения. Предварительно подготовленную сеть можно использовать в качестве отправной точки для изучения новой задачи. Точная настройка сети с обучением переносу обычно намного быстрее и проще, чем обучение сети с произвольно инициализированными весами с нуля. Вы можете быстро перенести изученные функции в новую задачу, используя меньшее количество обучающих сигналов.

Аудио Toolbox™ дополнительно обеспечивает classifySound функция, реализующая необходимую предварительную обработку для YAMNet и удобную постобработку для интерпретации результатов. Audio Toolbox также предоставляет предварительно подготовленную сеть VGGish (vggish), а также vggishFeatures функция, реализующая предварительную и постобработку для сети VGGish.
Создайте 100 сигналов белого шума, 100 сигналов коричневого шума и 100 сигналов розового шума. Каждый сигнал представляет длительность 0,98 секунды, предполагая частоту дискретизации 16 кГц.
fs = 16e3; duration = 0.98; N = duration*fs; numSignals = 100; wNoise = 2*rand([N,numSignals]) - 1; wLabels = repelem(categorical("white"),numSignals,1); bNoise = filter(1,[1,-0.999],wNoise); bNoise = bNoise./max(abs(bNoise),[],'all'); bLabels = repelem(categorical("brown"),numSignals,1); pNoise = pinknoise([N,numSignals]); pLabels = repelem(categorical("pink"),numSignals,1);
Разбейте данные на обучающие и тестовые наборы. Как правило, тренировочный набор состоит из большей части данных. Тем не менее, чтобы проиллюстрировать силу трансферного обучения, вы будете использовать только несколько образцов для обучения и большинство для проверки.
K =5; trainAudio = [wNoise(:,1:K),bNoise(:,1:K),pNoise(:,1:K)]; trainLabels = [wLabels(1:K);bLabels(1:K);pLabels(1:K)]; validationAudio = [wNoise(:,K+1:end),bNoise(:,K+1:end),pNoise(:,K+1:end)]; validationLabels = [wLabels(K+1:end);bLabels(K+1:end);pLabels(K+1:end)]; fprintf("Number of samples per noise color in train set = %d\n" + ... "Number of samples per noise color in validation set = %d\n",K,numSignals-K);
Number of samples per noise color in train set = 5 Number of samples per noise color in validation set = 95
Использовать melSpectrogram для извлечения логарифмических спектрограмм как из обучающего набора, так и из валидационного набора с использованием тех же параметров, на которых была обучена модель YAMNet.
FFTLength = 512; numBands = 64; frequencyRange = [125 7500]; windowLength = 0.025*fs; overlapLength = 0.015*fs; trainFeatures = melSpectrogram(trainAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); trainFeatures = log(trainFeatures + single(0.001)); trainFeatures = permute(trainFeatures,[2,1,4,3]); validationFeatures = melSpectrogram(validationAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); validationFeatures = log(validationFeatures + single(0.001)); validationFeatures = permute(validationFeatures,[2,1,4,3]);
Для загрузки предварительно обученной сети вызовите yamnet. Если модель Audio Toolbox для YAMNet не установлена, то функция предоставляет ссылку на расположение весов сети. Чтобы загрузить модель, щелкните ссылку. Распакуйте файл в папку по пути MATLAB. Модель YAMNet может классифицировать звук на одну из 521 категорий звука, включая белый шум и розовый шум (но не коричневый шум).
net = yamnet; net.Layers(end).Classes
ans = 521×1 categorical
Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
Baby cry, infant cry
Whimper
Wail, moan
Sigh
Singing
Choir
Yodeling
Chant
Mantra
Child singing
⋮
Подготовка модели для передачи обучения путем предварительного преобразования сети в layerGraph (инструментарий глубокого обучения). Использовать replaceLayer (Deep Learning Toolbox) для замены полностью подключенного слоя необученным полностью подключенным слоем. Замените классификационный слой классификационным слоем, который классифицирует входные данные как «белый», «розовый» или «коричневый». Список слоев глубокого обучения (панель инструментов глубокого обучения) для уровней глубокого обучения, поддерживаемых в MATLAB ®.
uniqueLabels = unique(trainLabels); numLabels = numel(uniqueLabels); lgraph = layerGraph(net.Layers); lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense")); lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));
Для определения вариантов обучения используйте trainingOptions (инструментарий глубокого обучения).
options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});
Для обучения сети используйте trainNetwork (инструментарий глубокого обучения). Сеть достигает точности проверки 100%, используя только 5 сигналов на один тип шума.
trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:02 | 20.00% | 88.77% | 1.1922 | 0.6619 | 0.0010 | | 30 | 30 | 00:00:14 | 100.00% | 100.00% | 9.1076e-06 | 5.0431e-05 | 0.0010 | |======================================================================================================================|