В этом примере показано, как использовать передачу обучения, чтобы переобучить YAMNet, предварительно обученную сверточную нейронную сеть, чтобы классифицировать новый набор звуковых сигналов. Чтобы начать с аудио глубоким обучением с нуля, смотрите, Классифицируют Звук Используя Глубокое обучение.
Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. Подстройка сети с передачей обучения обычно намного быстрее и легче, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных сигналов.
Audio Toolbox™ дополнительно обеспечивает classifySound
функция, которая реализует необходимую предварительную обработку для YAMNet и удобной постобработки, чтобы интерпретировать результаты. Audio Toolbox также обеспечивает предварительно обученную сеть VGGish (vggish
) а также vggishFeatures
функция, которая реализует предварительную обработку и постобработку для сети VGGish.
Сгенерируйте 100 белых шумовых сигналов, 100 коричневых шумовых сигналов и 100 розовых шумовых сигналов. Каждый сигнал представляет длительность 0,98 секунд, принимая частоту дискретизации на 16 кГц.
fs = 16e3; duration = 0.98; N = duration*fs; numSignals = 100; wNoise = 2*rand([N,numSignals]) - 1; wLabels = repelem(categorical("white"),numSignals,1); bNoise = filter(1,[1,-0.999],wNoise); bNoise = bNoise./max(abs(bNoise),[],'all'); bLabels = repelem(categorical("brown"),numSignals,1); pNoise = pinknoise([N,numSignals]); pLabels = repelem(categorical("pink"),numSignals,1);
Разделите данные в наборы обучающих данных и наборы тестов. Обычно, набор обучающих данных состоит из большинства данных. Однако, чтобы проиллюстрировать степень передачи обучения, вы будете использовать только несколько выборок для обучения и большинство за валидацию.
K = 5; trainAudio = [wNoise (: 1:K), bNoise (: 1:K), pNoise (: 1:K)]; trainLabels = [wLabels (1:K); bLabels (1:K); pLabels (1:K)]; validationAudio = [wNoise (: K+1:end), bNoise (: K+1:end), pNoise (: K+1:end)]; validationLabels = [wLabels (K+1:end); bLabels (K+1:end); pLabels (K+1:end)]; fprintf"Number of samples per noise color in train set = %d\n" + ... "Number of samples per noise color in validation set = %d\n", K, numSignals-K);
Number of samples per noise color in train set = 5 Number of samples per noise color in validation set = 95
Используйте melSpectrogram
чтобы извлечь логарифмические-mel спектрограммы и из набора обучающих данных и из набора валидации с помощью тех же параметров в качестве, модель YAMNet была обучена на.
FFTLength = 512; numBands = 64; frequencyRange = [125 7500]; windowLength = 0.025*fs; overlapLength = 0.015*fs; trainFeatures = melSpectrogram(trainAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); trainFeatures = log(trainFeatures + single(0.001)); trainFeatures = permute(trainFeatures,[2,1,4,3]); validationFeatures = melSpectrogram(validationAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); validationFeatures = log(validationFeatures + single(0.001)); validationFeatures = permute(validationFeatures,[2,1,4,3]);
Чтобы загрузить предварительно обученную сеть, вызовите yamnet
. Если модель Audio Toolbox для YAMNet не установлена, то функция обеспечивает ссылку на местоположение сетевых весов. Чтобы загрузить модель, щелкните по ссылке. Разархивируйте файл к местоположению на пути MATLAB. Модель YAMNet может классифицировать аудио в одну из 521 звуковой категории, включая белый шумовой и розовый шум (но не коричневый шум).
net = yamnet; net.Layers(end).Classes
ans = 521×1 categorical
Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
Baby cry, infant cry
Whimper
Wail, moan
Sigh
Singing
Choir
Yodeling
Chant
Mantra
Child singing
⋮
Подготовьте модель к передаче обучения первым преобразованием сети к layerGraph
(Deep Learning Toolbox). Используйте replaceLayer
(Deep Learning Toolbox), чтобы заменить полносвязный слой на нетренированный полносвязный слой. Замените слой классификации на слой классификации, который классифицирует вход как "белый", "розовый", или "коричневый". Смотрите Список слоев глубокого обучения (Deep Learning Toolbox) для слоев глубокого обучения, поддержанных в MATLAB®.
uniqueLabels = unique(trainLabels); numLabels = numel(uniqueLabels); lgraph = layerGraph(net.Layers); lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense")); lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));
Чтобы задать опции обучения, используйте trainingOptions
(Deep Learning Toolbox).
options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});
Чтобы обучить сеть, используйте trainNetwork
(Deep Learning Toolbox). Сеть достигает точности валидации 100% с помощью только 5 сигналов на шумовой тип.
trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:02 | 20.00% | 88.77% | 1.1922 | 0.6619 | 0.0010 | | 30 | 30 | 00:00:14 | 100.00% | 100.00% | 9.1076e-06 | 5.0431e-05 | 0.0010 | |======================================================================================================================|