В этом примере показано, как использовать передачу обучения, чтобы переобучить YAMNet, предварительно обученную сверточную нейронную сеть, чтобы классифицировать новый набор звуковых сигналов. Чтобы начать с аудио глубоким обучением с нуля, смотрите, Классифицируют Звук Используя Глубокое обучение (Audio Toolbox).
Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. Подстройка сети с передачей обучения обычно намного быстрее и легче, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных сигналов.
Audio Toolbox™ дополнительно обеспечивает classifySound
Функция (Audio Toolbox), которая реализует необходимую предварительную обработку для YAMNet и удобной постобработки, чтобы интерпретировать результаты. Audio Toolbox также обеспечивает предварительно обученную сеть VGGish (vggish
(Audio Toolbox)), а также vggishFeatures
Функция (Audio Toolbox), которая реализует предварительную обработку и постобработку для сети VGGish.
Сгенерируйте 100 сигналов белого шума, 100 коричневых шумовых сигналов и 100 розовых шумовых сигналов. Каждый сигнал представляет длительность 0,98 секунд, принимая частоту дискретизации на 16 кГц.
fs = 16e3; duration = 0.98; N = duration*fs; numSignals = 100; wNoise = 2*rand([N,numSignals]) - 1; wLabels = repelem(categorical("white"),numSignals,1); bNoise = filter(1,[1,-0.999],wNoise); bNoise = bNoise./max(abs(bNoise),[],'all'); bLabels = repelem(categorical("brown"),numSignals,1); pNoise = pinknoise([N,numSignals]); pLabels = repelem(categorical("pink"),numSignals,1);
Разделите данные в наборы обучающих данных и наборы тестов. Обычно, набор обучающих данных состоит из большинства данных. Однако, чтобы проиллюстрировать степень передачи обучения, вы будете использовать только несколько выборок для обучения и большинство за валидацию.
K = 5; trainAudio = [wNoise (: 1:K), bNoise (: 1:K), pNoise (: 1:K)]; trainLabels = [wLabels (1:K); bLabels (1:K); pLabels (1:K)]; validationAudio = [wNoise (: K+1:end), bNoise (: K+1:end), pNoise (: K+1:end)]; validationLabels = [wLabels (K+1:end); bLabels (K+1:end); pLabels (K+1:end)]; fprintf"Number of samples per noise color in train set = %d\n" + ... "Number of samples per noise color in validation set = %d\n", K, numSignals-K);
Number of samples per noise color in train set = 5 Number of samples per noise color in validation set = 95
Используйте melSpectrogram
(Audio Toolbox), чтобы извлечь логарифмические-mel спектрограммы и из набора обучающих данных и из набора валидации с помощью тех же параметров в качестве модели YAMNet был обучен на.
FFTLength = 512; numBands = 64; frequencyRange = [125 7500]; windowLength = 0.025*fs; overlapLength = 0.015*fs; trainFeatures = melSpectrogram(trainAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); trainFeatures = log(trainFeatures + single(0.001)); trainFeatures = permute(trainFeatures,[2,1,4,3]); validationFeatures = melSpectrogram(validationAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); validationFeatures = log(validationFeatures + single(0.001)); validationFeatures = permute(validationFeatures,[2,1,4,3]);
Чтобы загрузить предварительно обученную сеть, вызовите yamnet
(Audio Toolbox). Если модель Audio Toolbox для YAMNet не установлена, то функция обеспечивает ссылку на местоположение сетевых весов. Чтобы загрузить модель, щелкните по ссылке. Разархивируйте файл к местоположению на пути MATLAB. Модель YAMNet может классифицировать аудио в одну из 521 звуковой категории, включая белый шум и розовый шум (но не коричневый шум).
net = yamnet; net.Layers(end).Classes
ans = 521×1 categorical
Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
Baby cry, infant cry
Whimper
Wail, moan
Sigh
Singing
Choir
Yodeling
Chant
Mantra
Child singing
⋮
Подготовьте модель к передаче обучения первым преобразованием сети к layerGraph
. Используйте replaceLayer
заменять полносвязный слой на нетренированный полносвязный слой. Замените слой классификации на слой классификации, который классифицирует вход как "белый", "розовый", или "коричневый". Смотрите Список слоев глубокого обучения для слоев глубокого обучения, поддержанных в MATLAB®.
uniqueLabels = unique(trainLabels); numLabels = numel(uniqueLabels); lgraph = layerGraph(net.Layers); lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense")); lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));
Чтобы задать опции обучения, используйте trainingOptions
.
options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});
Чтобы обучить сеть, используйте trainNetwork
. Сеть достигает точности валидации 100% с помощью только 5 сигналов на шумовой тип.
trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:02 | 20.00% | 88.77% | 1.1922 | 0.6619 | 0.0010 | | 30 | 30 | 00:00:14 | 100.00% | 100.00% | 9.1076e-06 | 5.0431e-05 | 0.0010 | |======================================================================================================================|