Передача обучения с предварительно обученными аудио сетями

В этом примере показано, как использовать передачу обучения, чтобы переобучить YAMNet, предварительно обученную сверточную нейронную сеть, чтобы классифицировать новый набор звуковых сигналов. Чтобы начать с аудио глубоким обучением с нуля, смотрите, Классифицируют Звук Используя Глубокое обучение (Audio Toolbox).

Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. Подстройка сети с передачей обучения обычно намного быстрее и легче, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных сигналов.

Audio Toolbox™ дополнительно обеспечивает classifySound Функция (Audio Toolbox), которая реализует необходимую предварительную обработку для YAMNet и удобной постобработки, чтобы интерпретировать результаты. Audio Toolbox также обеспечивает предварительно обученную сеть VGGish (vggish (Audio Toolbox)), а также vggishFeatures Функция (Audio Toolbox), которая реализует предварительную обработку и постобработку для сети VGGish.

Создайте данные

Сгенерируйте 100 сигналов белого шума, 100 коричневых шумовых сигналов и 100 розовых шумовых сигналов. Каждый сигнал представляет длительность 0,98 секунд, принимая частоту дискретизации на 16 кГц.

fs = 16e3;
duration = 0.98;
N = duration*fs;
numSignals = 100;

wNoise = 2*rand([N,numSignals]) - 1;
wLabels = repelem(categorical("white"),numSignals,1);

bNoise = filter(1,[1,-0.999],wNoise);
bNoise = bNoise./max(abs(bNoise),[],'all');
bLabels = repelem(categorical("brown"),numSignals,1);

pNoise = pinknoise([N,numSignals]);
pLabels = repelem(categorical("pink"),numSignals,1);

Разделите данные в наборы обучающих данных и наборы тестов. Обычно, набор обучающих данных состоит из большинства данных. Однако, чтобы проиллюстрировать степень передачи обучения, вы будете использовать только несколько выборок для обучения и большинство за валидацию.

K = 5;

trainAudio = [wNoise (: 1:K), bNoise (: 1:K), pNoise (: 1:K)];
trainLabels = [wLabels (1:K); bLabels (1:K); pLabels (1:K)];

validationAudio = [wNoise (: K+1:end), bNoise (: K+1:end), pNoise (: K+1:end)];
validationLabels = [wLabels (K+1:end); bLabels (K+1:end); pLabels (K+1:end)];

fprintf"Number of samples per noise color in train set = %d\n" + ...
        "Number of samples per noise color in validation set = %d\n", K, numSignals-K);
Number of samples per noise color in train set = 5
Number of samples per noise color in validation set = 95

Выделить признаки

Используйте melSpectrogram (Audio Toolbox), чтобы извлечь логарифмические-mel спектрограммы и из набора обучающих данных и из набора валидации с помощью тех же параметров в качестве модели YAMNet был обучен на.

FFTLength = 512;
numBands = 64;
frequencyRange = [125 7500];
windowLength = 0.025*fs;
overlapLength = 0.015*fs;

trainFeatures = melSpectrogram(trainAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

trainFeatures = log(trainFeatures + single(0.001));
trainFeatures = permute(trainFeatures,[2,1,4,3]);

validationFeatures = melSpectrogram(validationAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

validationFeatures = log(validationFeatures + single(0.001));
validationFeatures = permute(validationFeatures,[2,1,4,3]);

Передача обучения

Чтобы загрузить предварительно обученную сеть, вызовите yamnet (Audio Toolbox). Если модель Audio Toolbox для YAMNet не установлена, то функция обеспечивает ссылку на местоположение сетевых весов. Чтобы загрузить модель, щелкните по ссылке. Разархивируйте файл к местоположению на пути MATLAB. Модель YAMNet может классифицировать аудио в одну из 521 звуковой категории, включая белый шум и розовый шум (но не коричневый шум).

net = yamnet;
net.Layers(end).Classes
ans = 521×1 categorical
     Speech 
     Child speech, kid speaking 
     Conversation 
     Narration, monologue 
     Babbling 
     Speech synthesizer 
     Shout 
     Bellow 
     Whoop 
     Yell 
     Children shouting 
     Screaming 
     Whispering 
     Laughter 
     Baby laughter 
     Giggle 
     Snicker 
     Belly laugh 
     Chuckle, chortle 
     Crying, sobbing 
     Baby cry, infant cry 
     Whimper 
     Wail, moan 
     Sigh 
     Singing 
     Choir 
     Yodeling 
     Chant 
     Mantra 
     Child singing 
      ⋮

Подготовьте модель к передаче обучения первым преобразованием сети к layerGraph. Используйте replaceLayer заменять полносвязный слой на нетренированный полносвязный слой. Замените слой классификации на слой классификации, который классифицирует вход как "белый", "розовый", или "коричневый". Смотрите Список слоев глубокого обучения для слоев глубокого обучения, поддержанных в MATLAB®.

uniqueLabels = unique(trainLabels);
numLabels = numel(uniqueLabels);

lgraph = layerGraph(net.Layers);

lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense"));
lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));

Чтобы задать опции обучения, используйте trainingOptions.

options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});

Чтобы обучить сеть, используйте trainNetwork. Сеть достигает точности валидации 100% с помощью только 5 сигналов на шумовой тип.

trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       20.00% |       88.77% |       1.1922 |       0.6619 |          0.0010 |
|      30 |          30 |       00:00:14 |      100.00% |      100.00% |   9.1076e-06 |   5.0431e-05 |          0.0010 |
|======================================================================================================================|