exponenta event banner

Сеть распознавания речевых цифр Train с использованием функций, не соответствующих памяти

В этом примере обучается сеть распознавания речевых цифр на слуховых спектрограммах, не имеющих памяти, с использованием преобразованного хранилища данных. В этом примере вы извлекаете слуховые спектрограммы из звука с помощью audioDatastore и audioFeatureExtractorи вы записываете их на диск. Затем вы используете signalDatastore для доступа к функциям во время обучения. Рабочий процесс полезен, когда обучающие функции не помещаются в память. В этом рабочем процессе элементы извлекаются только один раз, что ускоряет рабочий процесс, если выполняется итерация в проекте модели глубокого обучения.

Данные

Загрузите бесплатный набор данных речевых цифр (FSDD). FSDD состоит из 2000 записей четырёх говорящих, говорящих цифры от 0 до 9 на английском языке.

url = 'https://ssd.mathworks.com/supportfiles/audio/FSDD.zip';

downloadFolder = tempdir;
datasetFolder = fullfile(downloadFolder,'FSDD');

if ~exist(datasetFolder,'dir')
    disp('Downloading FSDD...')
    unzip(url,downloadFolder)
end
Downloading FSDD...

Создание audioDatastore указывает на набор данных.

pathToRecordingsFolder = fullfile(datasetFolder,'recordings');
location = pathToRecordingsFolder;
ads = audioDatastore(location);

Вспомогательная функция, helperGenerateLabelsсоздает категориальный массив меток из файлов FSDD. Исходный код для helpergenLabels перечислены в приложении. Просмотрите классы и количество примеров в каждом классе.

ads.Labels = helpergenLabels(ads);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
summary(ads.Labels)
     0      200 
     1      200 
     2      200 
     3      200 
     4      200 
     5      200 
     6      200 
     7      200 
     8      200 
     9      200 

Разделение FSDD на учебные и тестовые наборы. Выделите 80% данных обучающему набору и сохраните 20% для тестового набора. Обучающий набор используется для обучения модели, а тестовый набор - для проверки обучаемой модели.

rng default
ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
    Label    Count
    _____    _____

      0       160 
      1       160 
      2       160 
      3       160 
      4       160 
      5       160 
      6       160 
      7       160 
      8       160 
      9       160 

countEachLabel(adsTest)
ans=10×2 table
    Label    Count
    _____    _____

      0       40  
      1       40  
      2       40  
      3       40  
      4       40  
      5       40  
      6       40  
      7       40  
      8       40  
      9       40  

Сокращение набора данных для обучения

Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset в значение false. Для быстрого выполнения этого примера установите reduceDataset к true.

reduceDataset = "false";
if reduceDataset == "true"
    adsTrain = splitEachLabel(adsTrain,2);
    adsTest = splitEachLabel(adsTest,2);
end

Настройка извлечения слуховой спектрограммы

CNN принимает спектрограммы мел-частоты.

Определите параметры, используемые для извлечения мел-частотных спектрограмм. Используйте окна 220 мс с 10 мс прыжками между окнами. Используйте 2048-точечный DFT и 40 диапазонов частот.

fs = 8000;
frameDuration = 0.22;
hopDuration = 0.01;
params.segmentLength = 8192;
segmentDuration = params.segmentLength*(1/fs);
params.numHops = ceil((segmentDuration-frameDuration)/hopDuration);
params.numBands = 40;
frameLength = round(frameDuration*fs);
hopLength = round(hopDuration*fs);
fftLength = 2048;

Создание audioFeatureExtractor Сущность изобретения заключается в том, что на основе входных аудиосигналов вычисляют мель-частотные спектрограммы.

afe = audioFeatureExtractor('melSpectrum',true,'SampleRate',fs);
afe.Window = hamming(frameLength,'periodic');
afe.OverlapLength = frameLength-hopLength;
afe.FFTLength = fftLength;

Задайте параметры для спектрограммы mel-frequency.

setExtractorParams(afe,'melSpectrum','NumBands',params.numBands,'FrequencyRange',[50 fs/2],'WindowNormalization',true);

Создайте преобразованное хранилище данных, которое вычисляет частотные спектрограммы из аудиоданных. Вспомогательная функция, getSpeechSpectrogram (см. приложение), стандартизирует длину записи и нормализует амплитуду звукового входа. getSpeechSpectrogram использует audioFeatureExtractor объект afe для получения логарифмических спектрограмм.

adsSpecTrain = transform(adsTrain,@(x)getSpeechSpectrogram(x,afe,params));

Запись слуховых спектрограмм на диск

Использовать writeall для записи слуховых спектрограмм на диск. Набор UseParallel true для параллельного выполнения записи.

writeall(adsSpecTrain,pwd,'WriteFcn',@myCustomWriter,'UseParallel',true);

Настройка хранилища данных обучающего сигнала

Создать signalDatastore указывает на отсутствие памяти. Пользовательская функция считывания возвращает пару спектрограмма/метка.

sds = signalDatastore('recordings','ReadFcn',@myCustomRead);

Данные проверки

Набор данных проверки помещается в память, и функции проверки предварительно вычисляются с помощью вспомогательной функции. getValidationSpeechSpectrograms (см. приложение).

XTest = getValidationSpeechSpectrograms(adsTest,afe,params);

Получите метки проверки.

YTest = adsTest.Labels;

Определение архитектуры CNN

Создайте небольшой CNN в виде массива слоев. Используйте слои сверточной и пакетной нормализации, а затем выполните понижающую выборку карт элементов с использованием максимального количества слоев пула. Чтобы уменьшить возможность запоминания сетью специфических особенностей обучающих данных, добавьте небольшое количество отсева на вход последнего полностью подключенного уровня.

sz = size(XTest);
specSize = sz(1:2);
imageSize = [specSize 1];

numClasses = numel(categories(YTest));

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer(imageSize,'Normalization','none')

    convolution2dLayer(5,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2)

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer('Classes',categories(YTest));
    ];

Установите гиперпараметры для использования при обучении сети. Используйте размер мини-партии 50 и скорость обучения 1e-4. Укажите 'adam'оптимизация. Чтобы использовать параллельный пул для чтения преобразованного хранилища данных, установите DispatchInBackground кому true. Дополнительные сведения см. в разделе trainingOptions (инструментарий глубокого обучения).

miniBatchSize = 50;
options = trainingOptions('adam', ...
    'InitialLearnRate',1e-4, ...
    'MaxEpochs',30, ...
    'LearnRateSchedule',"piecewise",...
    'LearnRateDropFactor',.1,...
    'LearnRateDropPeriod',15,...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XTest, YTest},...
    'ValidationFrequency',ceil(numel(adsTrain.Files)/miniBatchSize),...
    'ExecutionEnvironment','gpu',...
    'DispatchInBackground', true);

Обучение сети путем передачи хранилища данных обучения в trainNetwork.

trainedNet = trainNetwork(sds,layers,options);

Используйте обученную сеть для прогнозирования цифровых меток для тестового набора.

[Ypredicted,probs] = classify(trainedNet,XTest);
cnnAccuracy = sum(Ypredicted==YTest)/numel(YTest)*100
cnnAccuracy = 96.5000

Обобщите производительность обученной сети на тестовом наборе с помощью таблицы путаницы. Отображение точности и отзыва для каждого класса с помощью сводок столбцов и строк. В таблице внизу таблицы путаницы показаны значения точности. В таблице справа от таблицы путаницы показаны значения отзыва.

figure('Units','normalized','Position',[0.2 0.2 1.5 1.5]);
ccDCNN = confusionchart(YTest,Ypredicted);
ccDCNN.Title = 'Confusion Chart for DCNN';
ccDCNN.ColumnSummary = 'column-normalized';
ccDCNN.RowSummary = 'row-normalized';

Приложение: Вспомогательные функции

function Labels = helpergenLabels(ads)
% This function is only for use in this example. It may be changed or
% removed in a future release.
files = ads.Files;
tmp = cell(numel(files),1);
expression = "[0-9]+_";
parfor nf = 1:numel(ads.Files)
    idx = regexp(files{nf},expression);
    tmp{nf} = files{nf}(idx);
end
Labels = categorical(tmp);
end

%------------------------------------------------------------
function X = getValidationSpeechSpectrograms(ads,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getValidationSpeechSpectrograms(ads,afe) computes speech spectrograms for
% the files in the datastore ads using the audioFeatureExtractor afe.

numFiles = length(ads.Files);
X = zeros([params.numBands,params.numHops,1,numFiles],'single');

for i = 1:numFiles
    x = read(ads);    
    spec = getSpeechSpectrogram(x,afe,params);    
    X(:,:,1,i) = spec;
    
end

end

%--------------------------------------------------------------------------
function X = getSpeechSpectrogram(x,afe,params)
% This function is only for use in this example. It may changed or be
% removed in a future release.
%
% getSpeechSpectrogram(x,afe) computes a speech spectrogram for the signal
% x using the audioFeatureExtractor afe.

X = zeros([params.numBands,params.numHops],'single');

x = normalizeAndResize(single(x),params);

spec = extract(afe,x).';

% If the spectrogram is less wide than numHops, then put spectrogram in
% the middle of X.
w = size(spec,2);
left = floor((params.numHops-w)/2)+1;
ind = left:left+w-1;
X(:,ind) = log10(spec + 1e-6);

end
%--------------------------------------------------------------------------
function x = normalizeAndResize(x,params)
% This function is only for use in this example. It may change or be
% removed in a future release.

L = params.segmentLength;
N = numel(x);
if N > L
    x = x(1:L);
elseif N < L
    pad = L-N;
    prepad = floor(pad/2);
    postpad = ceil(pad/2);
    x = [zeros(prepad,1) ; x ; zeros(postpad,1)];
end
x = x./max(abs(x));
end
%--------------------------------------------------------------------------
function myCustomWriter(spec,writeInfo,~)
% This function is only for use in this example. It may change or be
% removed in a future release.
% Define custom writing function to write auditory spectrogram/label pair
% to MAT files.
filename = strrep(writeInfo.SuggestedOutputName, '.wav','.mat');
label = writeInfo.ReadInfo.Label;
save(filename,'label','spec');
end
%--------------------------------------------------------------------------
function [data,info] = myCustomRead(filename)
% This function is only for use in this example. It may change or be
% removed in a future release.
load(filename,'spec','label');
data = {spec,label};
info.SampleRate = 8000;
end