Передача обучения с предварительно обученными аудиосетями

Этот пример показывает, как использовать передачу обучения для переобучения YAMNet, предварительно обученной сверточной нейронной сети, чтобы классифицировать новый набор аудиосигналов. Чтобы начать с глубокого обучения с нуля, смотрите Классификацию Звука Используя Глубокое Обучение.

Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. Подстройка сети с передачей обучения обычно намного быстрее и проще, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать выученные функции в новую задачу с помощью меньшего количества обучающих сигналов.

Audio Toolbox™ дополнительно обеспечивает classifySound функция, которая реализует необходимую предварительную обработку для YAMNet и удобную постобработку для интерпретации результатов. Audio Toolbox также предоставляет предварительно обученную сеть VGGish (vggish) а также vggishFeatures функция, которая реализует предварительную обработку и постобработку для сети VGGish.

Создание данных

Сгенерируйте 100 сигналов белого шума, 100 сигналов коричневого шума и 100 сигналов розового шума. Каждый сигнал представляет длительность 0,98 секунд, принимая частоту дискретизации 16 кГц.

fs = 16e3;
duration = 0.98;
N = duration*fs;
numSignals = 100;

wNoise = 2*rand([N,numSignals]) - 1;
wLabels = repelem(categorical("white"),numSignals,1);

bNoise = filter(1,[1,-0.999],wNoise);
bNoise = bNoise./max(abs(bNoise),[],'all');
bLabels = repelem(categorical("brown"),numSignals,1);

pNoise = pinknoise([N,numSignals]);
pLabels = repelem(categorical("pink"),numSignals,1);

Разделите данные на обучающие и тестовые наборы. Обычно набор обучающих данных состоит из большинства данных. Однако, чтобы проиллюстрировать силу передачи обучения, вы будете использовать только несколько выборки для обучения и большинство для валидации.

K = 5;

trainAudio = [wNoise (:, 1: K), bNoise (:, 1: K), pNoise (:,1:K)];
trainLabels = [wLabels (1: K); bLabels (1: K); pLabels (1:K)];

validationAudio = [wNoise (:, K + 1: end), bNoise (:, K + 1: end), pNoise (:, K + 1: end)];
validationLabels = [wLabels (K + 1: end); bLabels (K + 1: end); pLabels (K + 1: end)];

fprintf ("Number of samples per noise color in train set = %d\n" + ...
        "Number of samples per noise color in validation set = %d\n", K, numSignals-K);
Number of samples per noise color in train set = 5
Number of samples per noise color in validation set = 95

Извлечение функций

Использование melSpectrogram для извлечения логарифмических спектрограмм как из набора обучающих данных, так и из набора валидации с использованием тех же параметров, на которых была обучена модель YAMNet.

FFTLength = 512;
numBands = 64;
frequencyRange = [125 7500];
windowLength = 0.025*fs;
overlapLength = 0.015*fs;

trainFeatures = melSpectrogram(trainAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

trainFeatures = log(trainFeatures + single(0.001));
trainFeatures = permute(trainFeatures,[2,1,4,3]);

validationFeatures = melSpectrogram(validationAudio,fs, ...
    'Window',hann(windowLength,'periodic'), ...
    'OverlapLength',overlapLength, ...
    'FFTLength',FFTLength, ...
    'FrequencyRange',frequencyRange, ...
    'NumBands',numBands, ...
    'FilterBankNormalization','none', ...
    'WindowNormalization',false, ...
    'SpectrumType','magnitude', ...
    'FilterBankDesignDomain','warped');

validationFeatures = log(validationFeatures + single(0.001));
validationFeatures = permute(validationFeatures,[2,1,4,3]);

Передача обучения

Чтобы загрузить предварительно обученную сеть, вызовите yamnet. Если модель Audio Toolbox для YAMNet не установлена, то функция предоставляет ссылку на расположение весов сети. Чтобы скачать модель, щелкните ссылку. Разархивируйте файл в местоположении по пути MATLAB. Модель YAMNet может классифицировать аудио в одну из 521 категорий звука, включая белый шум и розовый шум (но не коричневый шум).

net = yamnet;
net.Layers(end).Classes
ans = 521×1 categorical
     Speech 
     Child speech, kid speaking 
     Conversation 
     Narration, monologue 
     Babbling 
     Speech synthesizer 
     Shout 
     Bellow 
     Whoop 
     Yell 
     Children shouting 
     Screaming 
     Whispering 
     Laughter 
     Baby laughter 
     Giggle 
     Snicker 
     Belly laugh 
     Chuckle, chortle 
     Crying, sobbing 
     Baby cry, infant cry 
     Whimper 
     Wail, moan 
     Sigh 
     Singing 
     Choir 
     Yodeling 
     Chant 
     Mantra 
     Child singing 
      ⋮

Подготовьте модель для передачи обучения путем предварительного преобразования сети в layerGraph (Deep Learning Toolbox). Использование replaceLayer (Deep Learning Toolbox), чтобы заменить полносвязной слой необученным полносвязным слоем. Замените классификационный слой классификационным слоем, который классифицирует вход как «белый», «розовый» или «коричневый». Список слоев глубокого обучения (Deep Learning Toolbox) содержит информацию о слоях глубокого обучения, поддерживаемых в MATLAB ®.

uniqueLabels = unique(trainLabels);
numLabels = numel(uniqueLabels);

lgraph = layerGraph(net.Layers);

lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense"));
lgraph = replaceLayer(lgraph,"Sound",classificationLayer("Name","Sounds","Classes",uniqueLabels));

Чтобы определить опции обучения, используйте trainingOptions (Deep Learning Toolbox).

options = trainingOptions('adam','ValidationData',{single(validationFeatures),validationLabels});

Для обучения сети используйте trainNetwork (Deep Learning Toolbox). Сеть достигает точности валидации 100%, используя только 5 сигналов на тип шума.

trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
Training on single CPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       20.00% |       88.77% |       1.1922 |       0.6619 |          0.0010 |
|      30 |          30 |       00:00:14 |      100.00% |      100.00% |   9.1076e-06 |   5.0431e-05 |          0.0010 |
|======================================================================================================================|