Этот пример показывает, как использовать передачу обучения для переобучения YAMNet, предварительно обученной сверточной нейронной сети, чтобы классифицировать новый набор аудиосигналов. Чтобы начать с глубокого обучения с нуля, смотрите Классификацию Звука Используя Глубокое Обучение.
Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. Подстройка сети с передачей обучения обычно намного быстрее и проще, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать выученные функции в новую задачу с помощью меньшего количества обучающих сигналов.
Audio Toolbox™ дополнительно обеспечивает classifySound
функция, которая реализует необходимую предварительную обработку для YAMNet и удобную постобработку для интерпретации результатов. Audio Toolbox также предоставляет предварительно обученную сеть VGGish (vggish
) а также vggishFeatures
функция, которая реализует предварительную обработку и постобработку для сети VGGish.
Сгенерируйте 100 сигналов белого шума, 100 сигналов коричневого шума и 100 сигналов розового шума. Каждый сигнал представляет длительность 0,98 секунд, принимая частоту дискретизации 16 кГц.
fs = 16e3; duration = 0.98; N = duration*fs; numSignals = 100; wNoise = 2*rand([N,numSignals]) - 1; wLabels = repelem(categorical("white"),numSignals,1); bNoise = filter(1,[1,-0.999],wNoise); bNoise = bNoise./max(abs(bNoise),[],'all'); bLabels = repelem(categorical("brown"),numSignals,1); pNoise = pinknoise([N,numSignals]); pLabels = repelem(categorical("pink"),numSignals,1);
Разделите данные на обучающие и тестовые наборы. Обычно набор обучающих данных состоит из большинства данных. Однако, чтобы проиллюстрировать силу передачи обучения, вы будете использовать только несколько выборки для обучения и большинство для валидации.
K = 5; trainAudio = [wNoise (:, 1: K), bNoise (:, 1: K), pNoise (:,1:K)]; trainLabels = [wLabels (1: K); bLabels (1: K); pLabels (1:K)]; validationAudio = [wNoise (:, K + 1: end), bNoise (:, K + 1: end), pNoise (:, K + 1: end)]; validationLabels = [wLabels (K + 1: end); bLabels (K + 1: end); pLabels (K + 1: end)]; fprintf ("Number of samples per noise color in train set = %d\n" + ... "Number of samples per noise color in validation set = %d\n", K, numSignals-K);
Number of samples per noise color in train set = 5 Number of samples per noise color in validation set = 95
Использование melSpectrogram
для извлечения логарифмических спектрограмм как из набора обучающих данных, так и из набора валидации с использованием тех же параметров, на которых была обучена модель YAMNet.
FFTLength = 512; numBands = 64; frequencyRange = [125 7500]; windowLength = 0.025*fs; overlapLength = 0.015*fs; trainFeatures = melSpectrogram(trainAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); trainFeatures = log(trainFeatures + single(0.001)); trainFeatures = permute(trainFeatures,[2,1,4,3]); validationFeatures = melSpectrogram(validationAudio,fs, ... 'Window',hann(windowLength,'periodic'), ... 'OverlapLength',overlapLength, ... 'FFTLength',FFTLength, ... 'FrequencyRange',frequencyRange, ... 'NumBands',numBands, ... 'FilterBankNormalization','none', ... 'WindowNormalization',false, ... 'SpectrumType','magnitude', ... 'FilterBankDesignDomain','warped'); validationFeatures = log(validationFeatures + single(0.001)); validationFeatures = permute(validationFeatures,[2,1,4,3]);
Чтобы загрузить предварительно обученную сеть, вызовите yamnet
. Если модель Audio Toolbox для YAMNet не установлена, то функция предоставляет ссылку на расположение весов сети. Чтобы скачать модель, щелкните ссылку. Разархивируйте файл в местоположении по пути MATLAB. Модель YAMNet может классифицировать аудио в одну из 521 категорий звука, включая белый шум и розовый шум (но не коричневый шум).
net = yamnet; net.Layers(end).Classes
ans = 521×1 categorical
Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
Chuckle, chortle
Crying, sobbing
Baby cry, infant cry
Whimper
Wail, moan
Sigh
Singing
Choir
Yodeling
Chant
Mantra
Child singing
⋮
Подготовьте модель для передачи обучения путем предварительного преобразования сети в layerGraph
(Deep Learning Toolbox). Использование replaceLayer
(Deep Learning Toolbox), чтобы заменить полносвязной слой необученным полносвязным слоем. Замените классификационный слой классификационным слоем, который классифицирует вход как «белый», «розовый» или «коричневый». Список слоев глубокого обучения (Deep Learning Toolbox) содержит информацию о слоях глубокого обучения, поддерживаемых в MATLAB ®.
uniqueLabels = unique(trainLabels); numLabels = numel(uniqueLabels); lgraph = layerGraph(net.Layers); lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense")); lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));
Чтобы определить опции обучения, используйте trainingOptions
(Deep Learning Toolbox).
options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});
Для обучения сети используйте trainNetwork
(Deep Learning Toolbox). Сеть достигает точности валидации 100%, используя только 5 сигналов на тип шума.
trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:02 | 20.00% | 88.77% | 1.1922 | 0.6619 | 0.0010 | | 30 | 30 | 00:00:14 | 100.00% | 100.00% | 9.1076e-06 | 5.0431e-05 | 0.0010 | |======================================================================================================================|