Обучите разговорную сеть распознавания цифры использование функций из памяти

Этот пример обучает разговорную сеть распознавания цифры на слуховых спектрограммах из памяти с помощью преобразованного datastore. В этом примере вы извлекаете слуховые спектрограммы из аудио с помощью audioDatastore и audioFeatureExtractor, и вы пишете им в диск. Вы затем используете signalDatastore получать доступ к функциям во время обучения. Рабочий процесс полезен, когда учебные функции не умещаются в памяти. В этом рабочем процессе вы только извлекаете функции однажды, который ускоряет ваш рабочий процесс, если вы выполняете итерации на проекте модели глубокого обучения.

Данные

Загрузите Свободный Разговорный Набор данных Цифры (FSDD). FSDD состоит из 2 000 записей четырех динамиков, говорящих числа 0 до 9 на английском языке.

url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'FSDD');

if ~exist(datasetFolder,'dir')
    disp('Downloading FSDD...')
    unzip(url,downloadFolder)
end

Создайте audioDatastore это указывает на набор данных.

pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);

Функция помощника, helperGenerateLabels, создает категориальный массив меток из файлов FSDD. Исходный код для helperGenerateLabels перечислен в приложении. Отобразите классы и количество примеров в каждом классе.

ads.Labels = helperGenerateLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 8).
summary(ads.Labels)
     0      200 
     1      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Разделите FSDD в наборы обучающих данных и наборы тестов. Выделите 80% данных к набору обучающих данных и сохраните 20% для набора тестов. Вы используете набор обучающих данных, чтобы обучить модель и набор тестов подтверждать обученную модель.

rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Уменьшайте обучающий набор данных

Чтобы обучить сеть с набором данных в целом и достигнуть максимально возможной точности, установите reduceDataset ко лжи. Чтобы запустить этот пример быстро, установите reduceDataset к истине.

reduceDataset = "false";
if reduceDataset == "true"
    adsTrain = splitEachLabel (adsTrain, 2);
    adsTest = splitEachLabel (adsTest, 2);
end

Настройте Слуховую Экстракцию Спектрограммы

CNN принимает спектрограммы mel-частоты.

Задайте параметры, используемые, чтобы извлечь спектрограммы mel-частоты. Используйте 220 MS Windows с транзитными участками на 10 мс между окнами. Используйте ДПФ с 2048 точками и 40 диапазонов частот.

fs = 8000;
frameDuration = 0.22;
hopDuration = 0.01;
params.segmentLength = 8192;
segmentDuration = params.segmentLength*(1/fs);
params.numHops = ceil((segmentDuration-frameDuration)/hopDuration);
params.numBands = 40;
frameLength = round(frameDuration*fs);
hopLength = round(hopDuration*fs);
fftLength = 2048;

Создайте audioFeatureExtractor объект вычислить спектрограммы mel-частоты из входных звуковых сигналов.

afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs);
afe.Window = hamming(frameLength,'periodic');
afe.OverlapLength = frameLength-hopLength;
afe.FFTLength = fftLength;

Установите параметры для спектрограммы mel-частоты.

setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);

Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудиоданных. Функция помощника, getSpeechSpectrogram (см. приложение), стандартизирует продолжительность записи и нормирует амплитуду аудиовхода. getSpeechSpectrogram использует audioFeatureExtractor объект afe получить основанные на журнале спектрограммы mel-частоты.

adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));

Запишите слуховые спектрограммы в диск

Используйте writeall записать слуховые спектрограммы в диск. Установите UseParallel к истине, чтобы выполнить запись параллельно.

writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);

Настройте Учебный Datastore Сигнала

Создайте signalDatastore это указывает на функции из памяти. Пользовательская функция чтения возвращает пару спектрограммы/метки.

sds = signalDatastore('recordings','ReadFcn',@myCustomRead);

Данные о валидации

Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции помощника getValidationSpeechSpectrograms (см. приложение).

XTest = getValidationSpeechSpectrograms(adsTest,afe,params);

Получите метки валидации.

YTest = adsTest.Labels;

Задайте архитектуру CNN

Создайте маленький CNN как массив слоев. Используйте сверточный и слои нормализации партии. и проредите карты функции с помощью макс. слоев объединения. Чтобы уменьшать возможность сети, запоминая определенные функции обучающих данных, добавьте небольшое количество уволенного к входу к последнему полносвязному слою.

sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTest));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize,'Normalization','none')

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTest));
    ];

Установите гиперпараметры использовать в обучении сети. Используйте мини-пакетный размер 50 и скорость обучения 1e-4. Задайте 'adamОптимизация. Чтобы использовать параллельный пул, чтобы считать преобразованный datastore, устанавливает DispatchInBackground к true. Для получения дополнительной информации смотрите trainingOptions (Deep Learning Toolbox).

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'LearnRateSchedule',"piecewise",...
    'LearnRateDropFactor',.1,...
    'LearnRateDropPeriod',15,...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XTest, YTest},...
    'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),...
    'ExecutionEnvironment','gpu',...
    'DispatchInBackground', true);

Обучите сеть путем передачи учебного datastore trainNetwork.

trainedNet = trainNetwork(sds,layers,options);

Используйте обучивший сеть, чтобы предсказать метки цифры для набора тестов.

[Ypredicted,probs] = classify(trainedNet,XTest);
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 97

Обобщите эффективность обучившего сеть на наборе тестов с графиком беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца. Таблица в нижней части графика беспорядка показывает значения точности. Таблица справа от графика беспорядка показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

Приложение: Функции помощника

function Labels = helperGenerateLabels(ads)
% This function is only for use in this example. It may be changed or
% removed in a future release.
files = ads.Files;
tmp = cell(numel(files),1);
expression = "[0-9]+_";
parfor nf = 1:numel(ads.Files)
    idx = regexp(files{nf},expression);
    tmp{nf} = files{nf}(idx);
end
Labels = categorical(tmp);
end

%------------------------------------------------------------
function X = getValidationSpeechSpectrograms(ads,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for
% the files in the datastore ads using the audioFeatureExtractor afe.

numFiles = length(ads.Files);
X = zeros([params.numBands,params.numHops,1,numFiles],'single');

for i = 1:numFiles
    x = read(ads);    
    spec = getSpeechSpectrogram(x,afe,params);    
    X(:,:,1,i) = spec;
    
end

end

%--------------------------------------------------------------------------
function X = getSpeechSpectrogram(x,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal
% x using the audioFeatureExtractor afe.

X = zeros([params.numBands,params.numHops],'single');

x = normalizeAndResize(single(x),params);

spec = extract(afe,x).';

% If the spectrogram is less wide than numHops, then put spectrogram in
% the middle of X.
w = size(spec,2);
left = floor((params.numHops-w)/2)+1;
ind = left:left+w-1;
X(:,ind) = log10(spec + 1e-6);

end
%--------------------------------------------------------------------------
function x = normalizeAndResize(x,params)
% This function is only for use in this example. It may change or be
% removed in a future release.

L = params.segmentLength;
N = numel(x);
if N > L
    x = x(1:L);
elseif N < L
    pad = L-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end
%--------------------------------------------------------------------------
function myCustomWriter(spec,writeInfo,~)
% This function is only for use in this example. It may change or be
% removed in a future release.
% Define custom writing function to write auditory spectrogram/label pair
% to MAT files.
filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat');
label = writeInfo.ReadInfo.Label;
save(filename,'label','spec');
end
%--------------------------------------------------------------------------
function [data,info] = myCustomRead(filename)
% This function is only for use in this example. It may change or be
% removed in a future release.
load(filename,'spec','label');
data = {spec,label};
info.SampleRate = 8000;
end