Обучите сеть распознавания разговорных цифр с помощью функций нехватки памяти

Этот пример обучает сеть распознавания разговорных цифр на слуховых спектрограммах за пределами памяти, используя преобразованный datastore. В этом примере вы извлекаете слуховые спектрограммы из аудио с помощью audioDatastore и audioFeatureExtractorи вы записываете их на диск. Затем вы используете signalDatastore для доступа к функциям во время обучения. Рабочий процесс полезен, когда функции обучения не помещаются в памяти. В этом рабочем процессе вы извлекаете функции только один раз, что ускоряет рабочий процесс, если вы итератируетесь по проекту модели глубокого обучения.

Данные

Загрузите бесплатный набор данных (FSDD). FSDD состоит из 2000 записей четырех ораторов, говорящих цифры от 0 до 9 на английском языке.

url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'FSDD');

if ~exist(datasetFolder,'dir')
    disp('Downloading FSDD...')
    unzip(url,downloadFolder)
end
Downloading FSDD...

Создайте audioDatastore это указывает на набор данных.

pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);

Функция помощника, helperGenerateLabelsсоздает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels приведено в приложении. Отображение классов и количества примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
summary(ads.Labels)
     0      200 
     1      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Разделите FSDD на обучающие и тестовые наборы. Выделите 80% данных наборов обучающих данных и сохраните 20% для тестового набора. Вы используете набор обучающих данных для обучения модели и тестовый набор для проверки обученной модели.

rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Сокращение набора обучающих наборов данных

Чтобы обучить сеть со набором данных в целом и достичь максимально возможной точности, установите reduceDataset на ложь. Чтобы запустить этот пример быстро, установите reduceDataset к true.

reduceDataset = "false";
if reduceDataset =  ="true"
    adsTrain = splitEachLabel (adsTrain, 2);
    adsTest = splitEachLabel (adsTest, 2);
end

Настройка извлечения слуховой спектрограммы

CNN принимает мел-частотные спектрограммы.

Задайте параметры, используемые для извлечения мел-частотных спектрограмм. Используйте 220 мс окна с 10 мс хмеля между окнами. Используйте ДПФ с 2048 точками и 40 полосы частот.

fs = 8000;
frameDuration = 0.22;
hopDuration = 0.01;
params.segmentLength = 8192;
segmentDuration = params.segmentLength*(1/fs);
params.numHops = ceil((segmentDuration-frameDuration)/hopDuration);
params.numBands = 40;
frameLength = round(frameDuration*fs);
hopLength = round(hopDuration*fs);
fftLength = 2048;

Создайте audioFeatureExtractor объект для вычисления спектрограмм mel-частоты из входных аудиосигналов.

afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs);
afe.Window = hamming(frameLength,'periodic');
afe.OverlapLength = frameLength-hopLength;
afe.FFTLength = fftLength;

Установите параметры для спектрограммы мел-частоты.

setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);

Создайте преобразованный datastore, который вычисляет спектрограммы mel-частоты из аудио данных. Функция помощника, getSpeechSpectrogram (см. приложение), стандартизирует длину записи и нормализует амплитуду аудиовхода. getSpeechSpectrogram использует audioFeatureExtractor afe объекта для получения логарифмически основанных мел-частотных спектрограмм.

adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));

Запись слуховых спектрограмм на диск

Использование writeall запись слуховых спектрограмм на диск. Задайте UseParallel true, чтобы выполнять запись параллельно.

writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);

Настройте Training Signal Datastore

Создайте signalDatastore это указывает на функции нехватки памяти. Пользовательская функция чтения возвращает пару спектрограмма/метка.

sds = signalDatastore('recordings','ReadFcn',@myCustomRead);

Данные валидации

Набор данных валидации помещается в память, и вы предварительно вычисляете функции валидации с помощью функции helper getValidationSpeechSpectrograms (см. приложение).

XTest = getValidationSpeechSpectrograms(adsTest,afe,params);

Получите метки валидации.

YTest = adsTest.Labels;

Определение архитектуры CNN

Создайте небольшой CNN как массив слоев. Используйте сверточные и пакетные слои нормализации и понижающее отображение карт признаков с помощью максимальных слоев объединения. Чтобы уменьшить возможность запоминания сетью специфических функций обучающих данных, добавьте небольшое количество отсева на вход к последнему полностью подключенному слою.

sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTest));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize,'Normalization','none')

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTest));
    ];

Установите гиперпараметры, которые будут использоваться при обучении сети. Используйте мини-пакет размером 50 и скоростью обучения 1e-4. Задайте 'adam'оптимизация. Чтобы использовать параллельный пул для чтения преобразованного datastore, установите DispatchInBackground на true. Для получения дополнительной информации смотрите trainingOptions (Deep Learning Toolbox).

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'LearnRateSchedule',"piecewise",...
    'LearnRateDropFactor',.1,...
    'LearnRateDropPeriod',15,...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XTest, YTest},...
    'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),...
    'ExecutionEnvironment','gpu',...
    'DispatchInBackground', true);

Обучите сеть, передав обучающий datastore, чтобы trainNetwork.

trainedNet = trainNetwork(sds,layers,options);

Используйте обученную сеть, чтобы предсказать метки цифр для тестового набора.

[Ypredicted,probs] = classify(trainedNet,XTest);
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.5000

Суммируйте эффективность обученной сети на тестовом наборе с графиком неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам. Таблица в нижней части графика неточностей показывает значения точности. Таблица справа от графика неточностей показывает значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

Приложение: Функции помощника

function Labels = helpergenLabels(ads)
% This function is only for use in this example. It may be changed or
% removed in a future release.
files = ads.Files;
tmp = cell(numel(files),1);
expression = "[0-9]+_";
parfor nf = 1:numel(ads.Files)
    idx = regexp(files{nf},expression);
    tmp{nf} = files{nf}(idx);
end
Labels = categorical(tmp);
end

%------------------------------------------------------------
function X = getValidationSpeechSpectrograms(ads,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for
% the files in the datastore ads using the audioFeatureExtractor afe.

numFiles = length(ads.Files);
X = zeros([params.numBands,params.numHops,1,numFiles],'single');

for i = 1:numFiles
    x = read(ads);    
    spec = getSpeechSpectrogram(x,afe,params);    
    X(:,:,1,i) = spec;
    
end

end

%--------------------------------------------------------------------------
function X = getSpeechSpectrogram(x,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal
% x using the audioFeatureExtractor afe.

X = zeros([params.numBands,params.numHops],'single');

x = normalizeAndResize(single(x),params);

spec = extract(afe,x).';

% If the spectrogram is less wide than numHops, then put spectrogram in
% the middle of X.
w = size(spec,2);
left = floor((params.numHops-w)/2)+1;
ind = left:left+w-1;
X(:,ind) = log10(spec + 1e-6);

end
%--------------------------------------------------------------------------
function x = normalizeAndResize(x,params)
% This function is only for use in this example. It may change or be
% removed in a future release.

L = params.segmentLength;
N = numel(x);
if N > L
    x = x(1:L);
elseif N < L
    pad = L-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end
%--------------------------------------------------------------------------
function myCustomWriter(spec,writeInfo,~)
% This function is only for use in this example. It may change or be
% removed in a future release.
% Define custom writing function to write auditory spectrogram/label pair
% to MAT files.
filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat');
label = writeInfo.ReadInfo.Label;
save(filename,'label','spec');
end
%--------------------------------------------------------------------------
function [data,info] = myCustomRead(filename)
% This function is only for use in this example. It may change or be
% removed in a future release.
load(filename,'spec','label');
data = {spec,label};
info.SampleRate = 8000;
end