exponenta event banner

Распознавание речевых команд с помощью глубокого обучения

В этом примере показано, как обучить модель глубокого обучения, которая обнаруживает наличие речевых команд в звуке. В примере используется набор данных речевых команд [1] для обучения сверточной нейронной сети распознаванию заданного набора команд.

Чтобы обучить сеть с нуля, необходимо сначала загрузить набор данных. Если вы не хотите загружать набор данных или обучать сеть, то вы можете загрузить предварительно обученную сеть, предоставленную в этом примере, и выполнить следующие два раздела примера: Распознавать команды с предварительно обученной сетью и обнаруживать команды с помощью потокового звука из микрофона.

Распознавание команд с предварительно обученной сетью

Перед подробным изучением процесса обучения для идентификации речевых команд используется предварительно обученная сеть распознавания речи.

Загрузите предварительно обученную сеть.

load('commandNet.mat')

Сеть обучена распознаванию следующих речевых команд:

  • «да»

  • «нет»

  • вверх

  • «вниз»

  • «слева»

  • правильно

  • «вкл».

  • «выкл».

  • Стоп

  • «перейти»

Загрузите короткий речевой сигнал, когда человек говорит «стоп».

 [x,fs] = audioread('stop_command.flac');

Послушай команду.

 sound(x,fs)

Предварительно обученная сеть принимает слуховые спектрограммы в качестве входных данных. Сначала вы преобразуете речевой сигнал в слуховую спектрограмму.

Используйте функцию extractAuditoryFeature для вычисления слуховой спектрограммы. Подробные сведения об извлечении элемента будут рассмотрены в примере ниже.

auditorySpect = helperExtractAuditoryFeatures(x,fs);

Классифицируйте команду на основе ее слуховой спектрограммы.

command = classify(trainedNet,auditorySpect)
command = 

  categorical

     stop 

Сеть обучена классифицировать слова, не принадлежащие этому набору, как «неизвестные».

Теперь вы классифицируете слово («play»), которое не было включено в список команд для идентификации.

Загрузите речевой сигнал и прослушайте его.

x = audioread('play_command.flac');
sound(x,fs)

Вычислите слуховую спектрограмму.

auditorySpect = helperExtractAuditoryFeatures(x,fs);

Классифицируйте сигнал.

command = classify(trainedNet,auditorySpect)
command = 

  categorical

     unknown 

Сеть обучена классифицировать фоновый шум как «фоновый».

Создайте односекундный сигнал, состоящий из случайного шума.

x = pinknoise(16e3);

Вычислите слуховую спектрограмму.

auditorySpect = helperExtractAuditoryFeatures(x,fs);

Классифицируйте фоновый шум.

command = classify(trainedNet,auditorySpect)
command = 

  categorical

     background 

Команды обнаружения с помощью потокового звука с микрофона

Протестируйте предварительно обученную сеть обнаружения команд на потоковом аудио с микрофона. Попробуйте произнести одну из команд, например, да, нет, или остановитесь. Затем попробуйте сказать одно из неизвестных слов, таких как Марвин, Шейла, кровать, дом, кошка, птица или любое число от нуля до девяти.

Укажите частоту классификации в Гц и создайте устройство чтения аудиоустройств, которое может считывать аудио с микрофона.

classificationRate = 20;
adr = audioDeviceReader('SampleRate',fs,'SamplesPerFrame',floor(fs/classificationRate));

Инициализируйте буфер для звука. Извлеките метки классификации сети. Инициализируйте буферы по полсекунды для меток и вероятностей классификации потокового звука. Используйте эти буферы для сравнения результатов классификации в течение более длительного периода времени и по этому «соглашению» построения при обнаружении команды. Укажите пороговые значения для логики принятия решения.

audioBuffer = dsp.AsyncBuffer(fs);

labels = trainedNet.Layers(end).Classes;
YBuffer(1:classificationRate/2) = categorical("background");

probBuffer = zeros([numel(labels),classificationRate/2]);

countThreshold = ceil(classificationRate*0.2);
probThreshold = 0.7;

Создайте фигуру и определите команды, пока существует созданная фигура. Для бесконечного запуска цикла установите timeLimit кому Inf. Чтобы остановить обнаружение в реальном времени, просто закройте фигуру.

h = figure('Units','normalized','Position',[0.2 0.1 0.6 0.8]);

timeLimit = 20;

tic
while ishandle(h) && toc < timeLimit

    % Extract audio samples from the audio device and add the samples to
    % the buffer.
    x = adr();
    write(audioBuffer,x);
    y = read(audioBuffer,fs,fs-adr.SamplesPerFrame);

    spec = helperExtractAuditoryFeatures(y,fs);

    % Classify the current spectrogram, save the label to the label buffer,
    % and save the predicted probabilities to the probability buffer.
    [YPredicted,probs] = classify(trainedNet,spec,'ExecutionEnvironment','cpu');
    YBuffer = [YBuffer(2:end),YPredicted];
    probBuffer = [probBuffer(:,2:end),probs(:)];

    % Plot the current waveform and spectrogram.
    subplot(2,1,1)
    plot(y)
    axis tight
    ylim([-1,1])

    subplot(2,1,2)
    pcolor(spec')
    caxis([-4 2.6445])
    shading flat

    % Now do the actual command detection by performing a very simple
    % thresholding operation. Declare a detection and display it in the
    % figure title if all of the following hold: 1) The most common label
    % is not background. 2) At least countThreshold of the latest frame
    % labels agree. 3) The maximum probability of the predicted label is at
    % least probThreshold. Otherwise, do not declare a detection.
    [YMode,count] = mode(YBuffer);

    maxProb = max(probBuffer(labels == YMode,:));
    subplot(2,1,1)
    if YMode == "background" || count < countThreshold || maxProb < probThreshold
        title(" ")
    else
        title(string(YMode),'FontSize',20)
    end

    drawnow
end

Загрузить набор данных речевых команд

В этом примере используется набор данных речевых команд Google [1]. Загрузите набор данных и отмените обработку загруженного файла. Задайте в качестве местоположения данных значение StartToDatabase.

url = 'https://ssd.mathworks.com/supportfiles/audio/google_speech.zip';
downloadFolder = tempdir;
dataFolder = fullfile(downloadFolder,'google_speech');

if ~exist(dataFolder,'dir')
    disp('Downloading data set (1.4 GB) ...')
    unzip(url,downloadFolder)
end

Создание хранилища данных обучения

Создание audioDatastore (Audio Toolbox) указывает на набор обучающих данных.

ads = audioDatastore(fullfile(dataFolder, 'train'), ...
    'IncludeSubfolders',true, ...
    'FileExtensions','.wav', ...
    'LabelSource','foldernames')
ads = 

  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\train\bed\00176480_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\train\bed\004ae714_nohash_1.wav'
                               ... and 51085 more
                              }
                     Folders: {
                              'C:\Users\jibrahim\AppData\Local\Temp\google_speech\train'
                              }
                      Labels: [bed; bed; bed ... and 51085 more categorical]
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"

Выберите распознаваемые слова

Укажите слова, которые модель должна распознавать как команды. Пометить все слова, которые не являются командами unknown. Маркировка слов, не являющихся командами unknown создает группу слов, которая аппроксимирует распределение всех слов, отличных от команд. Сеть использует эту группу для изучения различий между командами и всеми другими словами.

Чтобы уменьшить дисбаланс классов между известными и неизвестными словами и ускорить обработку, включите в обучающий набор только часть неизвестных слов.

Использовать subset (Audio Toolbox) для создания хранилища данных, содержащего только команды и подмножество неизвестных слов. Подсчитайте количество примеров, относящихся к каждой категории.

commands = categorical(["yes","no","up","down","left","right","on","off","stop","go"]);

isCommand = ismember(ads.Labels,commands);
isUnknown = ~isCommand;

includeFraction = 0.2;
mask = rand(numel(ads.Labels),1) < includeFraction;
isUnknown = isUnknown & mask;
ads.Labels(isUnknown) = categorical("unknown");

adsTrain = subset(ads,isCommand|isUnknown);
countEachLabel(adsTrain)
ans =

  11×2 table

     Label     Count
    _______    _____

    down       1842 
    go         1861 
    left       1839 
    no         1853 
    off        1839 
    on         1864 
    right      1852 
    stop       1885 
    unknown    6483 
    up         1843 
    yes        1860 

Создание хранилища данных проверки

Создание audioDatastore (Audio Toolbox), указывающий на набор данных проверки. Выполните те же действия, что и при создании хранилища данных обучения.

ads = audioDatastore(fullfile(dataFolder, 'validation'), ...
    'IncludeSubfolders',true, ...
    'FileExtensions','.wav', ...
    'LabelSource','foldernames')

isCommand = ismember(ads.Labels,commands);
isUnknown = ~isCommand;

includeFraction = 0.2;
mask = rand(numel(ads.Labels),1) < includeFraction;
isUnknown = isUnknown & mask;
ads.Labels(isUnknown) = categorical("unknown");

adsValidation = subset(ads,isCommand|isUnknown);
countEachLabel(adsValidation)
ads = 

  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\026290a7_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_0.wav';
                              ' ...\AppData\Local\Temp\google_speech\validation\bed\060cd039_nohash_1.wav'
                               ... and 6795 more
                              }
                     Folders: {
                              'C:\Users\jibrahim\AppData\Local\Temp\google_speech\validation'
                              }
                      Labels: [bed; bed; bed ... and 6795 more categorical]
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"


ans =

  11×2 table

     Label     Count
    _______    _____

    down        264 
    go          260 
    left        247 
    no          270 
    off         256 
    on          257 
    right       256 
    stop        246 
    unknown     850 
    up          260 
    yes         261 

Чтобы обучить сеть всему набору данных и достичь максимально возможной точности, установите reduceDataset кому false. Для быстрого выполнения этого примера установите reduceDataset кому true.

reduceDataset = false;
if reduceDataset
    numUniqueLabels = numel(unique(adsTrain.Labels));
    % Reduce the dataset by a factor of 20
    adsTrain = splitEachLabel(adsTrain,round(numel(adsTrain.Files) / numUniqueLabels / 20));
    adsValidation = splitEachLabel(adsValidation,round(numel(adsValidation.Files) / numUniqueLabels / 20));
end

Вычислительные слуховые спектрограммы

Для подготовки данных для эффективной тренировки сверточной нейронной сети преобразуют речевые сигналы в слуховые спектрограммы.

Определите параметры извлечения элемента. segmentDuration - длительность каждого речевого клипа (в секундах). frameDuration - длительность каждого кадра для вычисления спектра. hopDuration - временной шаг между каждым спектром. numBands - количество фильтров в слуховой спектрограмме.

Создание audioFeatureExtractor Объект (Audio Toolbox) для извлечения элемента.

fs = 16e3; % Known sample rate of the data set.

segmentDuration = 1;
frameDuration = 0.025;
hopDuration = 0.010;

segmentSamples = round(segmentDuration*fs);
frameSamples = round(frameDuration*fs);
hopSamples = round(hopDuration*fs);
overlapSamples = frameSamples - hopSamples;

FFTLength = 512;
numBands = 50;

afe = audioFeatureExtractor( ...
    'SampleRate',fs, ...
    'FFTLength',FFTLength, ...
    'Window',hann(frameSamples,'periodic'), ...
    'OverlapLength',overlapSamples, ...
    'barkSpectrum',true);
setExtractorParams(afe,'barkSpectrum','NumBands',numBands,'WindowNormalization',false);

Чтение файла из набора данных. Обучение сверточной нейронной сети требует, чтобы входные данные были согласованного размера. Длина некоторых файлов в наборе данных составляет менее 1 секунды. Применение заполнения нулями к передней и задней частям звукового сигнала таким образом, чтобы он имел длину segmentSamples.

x = read(adsTrain);

numSamples = size(x,1);

numToPadFront = floor( (segmentSamples - numSamples)/2 );
numToPadBack = ceil( (segmentSamples - numSamples)/2 );

xPadded = [zeros(numToPadFront,1,'like',x);x;zeros(numToPadBack,1,'like',x)];

Чтобы извлечь звуковые функции, вызовите extract. Выходной сигнал представляет собой спектр Барка со временем между рядами.

features = extract(afe,xPadded);
[numHops,numFeatures] = size(features)
numHops =

    98


numFeatures =

    50

В этом примере после обработки слуховой спектрограммы применяется логарифм. Взятие журнала малых чисел может привести к ошибке округления.

Чтобы ускорить обработку, можно распределить извлечение элементов между несколькими работниками с помощью parfor.

Сначала определите количество разделов для набора данных. Если у вас нет Toolbox™ Parallel Computing, используйте один раздел.

if ~isempty(ver('parallel')) && ~reduceDataset
    pool = gcp;
    numPar = numpartitions(adsTrain,pool);
else
    numPar = 1;
end

Для каждого раздела прочитайте данные из хранилища данных, установите нулевой уровень сигнала, а затем извлеките элементы.

parfor ii = 1:numPar
    subds = partition(adsTrain,numPar,ii);
    XTrain = zeros(numHops,numBands,1,numel(subds.Files));
    for idx = 1:numel(subds.Files)
        x = read(subds);
        xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)];
        XTrain(:,:,:,idx) = extract(afe,xPadded);
    end
    XTrainC{ii} = XTrain;
end

Преобразуйте выходные данные в 4-мерный массив со слуховыми спектрограммами в четвертом измерении.

XTrain = cat(4,XTrainC{:});

[numHops,numBands,numChannels,numSpec] = size(XTrain)
numHops =

    98


numBands =

    50


numChannels =

     1


numSpec =

       25021

Масштабируйте функции с помощью питания окна, а затем возьмите журнал. Чтобы получить данные с более плавным распределением, возьмите логарифм спектрограмм, используя небольшое смещение.

epsil = 1e-6;
XTrain = log10(XTrain + epsil);

Выполните описанные выше шаги извлечения элемента для набора проверки.

if ~isempty(ver('parallel'))
    pool = gcp;
    numPar = numpartitions(adsValidation,pool);
else
    numPar = 1;
end
parfor ii = 1:numPar
    subds = partition(adsValidation,numPar,ii);
    XValidation = zeros(numHops,numBands,1,numel(subds.Files));
    for idx = 1:numel(subds.Files)
        x = read(subds);
        xPadded = [zeros(floor((segmentSamples-size(x,1))/2),1);x;zeros(ceil((segmentSamples-size(x,1))/2),1)];
        XValidation(:,:,:,idx) = extract(afe,xPadded);
    end
    XValidationC{ii} = XValidation;
end
XValidation = cat(4,XValidationC{:});
XValidation = log10(XValidation + epsil);

Изолировать метки поезда и проверки. Удалить пустые категории.

YTrain = removecats(adsTrain.Labels);
YValidation = removecats(adsValidation.Labels);

Визуализация данных

Постройте график сигналов и слуховых спектрограмм нескольких тренировочных образцов. Воспроизведение соответствующих аудиоклипов.

specMin = min(XTrain,[],'all');
specMax = max(XTrain,[],'all');
idx = randperm(numel(adsTrain.Files),3);
figure('Units','normalized','Position',[0.2 0.2 0.6 0.6]);
for i = 1:3
    [x,fs] = audioread(adsTrain.Files{idx(i)});
    subplot(2,3,i)
    plot(x)
    axis tight
    title(string(adsTrain.Labels(idx(i))))

    subplot(2,3,i+3)
    spect = (XTrain(:,:,1,idx(i))');
    pcolor(spect)
    caxis([specMin specMax])
    shading flat

    sound(x,fs)
    pause(2)
end

Добавление данных фонового шума

Сеть должна быть способна не только распознавать различные произносимые слова, но и обнаруживать, содержит ли входной сигнал тишину или фоновый шум.

Используйте аудиофайлы в _background_ папку для создания образцов одной секунды клипов фонового шума. Создайте равное количество фоновых клипов из каждого файла фонового шума. Вы также можете создать собственные записи фонового шума и добавить их в _background_ папка. Перед вычислением спектрограмм функция масштабирует каждый аудиоклип с коэффициентом, дискретизированным из логарифмического распределения в диапазоне, заданном volumeRange.

adsBkg = audioDatastore(fullfile(dataFolder, 'background'))
numBkgClips = 4000;
if reduceDataset
    numBkgClips = numBkgClips/20;
end
volumeRange = log10([1e-4,1]);

numBkgFiles = numel(adsBkg.Files);
numClipsPerFile = histcounts(1:numBkgClips,linspace(1,numBkgClips,numBkgFiles+1));
Xbkg = zeros(size(XTrain,1),size(XTrain,2),1,numBkgClips,'single');
bkgAll = readall(adsBkg);
ind = 1;

for count = 1:numBkgFiles
    bkg = bkgAll{count};
    idxStart = randi(numel(bkg)-fs,numClipsPerFile(count),1);
    idxEnd = idxStart+fs-1;
    gain = 10.^((volumeRange(2)-volumeRange(1))*rand(numClipsPerFile(count),1) + volumeRange(1));
    for j = 1:numClipsPerFile(count)

        x = bkg(idxStart(j):idxEnd(j))*gain(j);

        x = max(min(x,1),-1);

        Xbkg(:,:,:,ind) = extract(afe,x);

        if mod(ind,1000)==0
            disp("Processed " + string(ind) + " background clips out of " + string(numBkgClips))
        end
        ind = ind + 1;
    end
end
Xbkg = log10(Xbkg + epsil);
adsBkg = 

  audioDatastore with properties:

                       Files: {
                              ' ...\AppData\Local\Temp\google_speech\background\doing_the_dishes.wav';
                              ' ...\AppData\Local\Temp\google_speech\background\dude_miaowing.wav';
                              ' ...\AppData\Local\Temp\google_speech\background\exercise_bike.wav'
                               ... and 3 more
                              }
                     Folders: {
                              'C:\Users\jibrahim\AppData\Local\Temp\google_speech\background'
                              }
    AlternateFileSystemRoots: {}
              OutputDataType: 'double'
                      Labels: {}
      SupportedOutputFormats: ["wav"    "flac"    "ogg"    "mp4"    "m4a"]
         DefaultOutputFormat: "wav"

Processed 1000 background clips out of 4000
Processed 2000 background clips out of 4000
Processed 3000 background clips out of 4000
Processed 4000 background clips out of 4000

Разделение спектрограмм фонового шума между обучающими, валидационными и тестовыми наборами. Потому что _background_noise_ папка содержит только около пяти с половиной минут фонового шума, фоновые выборки в различных наборах данных сильно коррелированы. Чтобы увеличить изменение фонового шума, можно создать собственные фоновые файлы и добавить их в папку. Чтобы повысить устойчивость сети к шуму, можно также попробовать смешать фоновый шум в речевые файлы.

numTrainBkg = floor(0.85*numBkgClips);
numValidationBkg = floor(0.15*numBkgClips);

XTrain(:,:,:,end+1:end+numTrainBkg) = Xbkg(:,:,:,1:numTrainBkg);
YTrain(end+1:end+numTrainBkg) = "background";

XValidation(:,:,:,end+1:end+numValidationBkg) = Xbkg(:,:,:,numTrainBkg+1:end);
YValidation(end+1:end+numValidationBkg) = "background";

Постройте график распределения различных меток классов в наборах обучения и проверки.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5])

subplot(2,1,1)
histogram(YTrain)
title("Training Label Distribution")

subplot(2,1,2)
histogram(YValidation)
title("Validation Label Distribution")

Определение архитектуры нейронной сети

Создание простой сетевой архитектуры в виде массива слоев. Используйте уровни сверточной и пакетной нормализации, а затем «пространственно» (то есть во времени и частоте) уменьшите значения карт элементов, используя уровни максимального объединения. Добавьте окончательный уровень максимального пула, который объединяет карту входных функций глобально во времени. Это обеспечивает (приблизительную) инвариантность временной трансляции во входных спектрограммах, позволяя сети выполнять такую же классификацию независимо от точного положения речи во времени. Глобальное объединение также значительно сокращает количество параметров на конечном полностью подключенном уровне. Чтобы уменьшить возможность запоминания сетью специфических особенностей обучающих данных, добавьте небольшое количество отсева на вход последнего полностью подключенного уровня.

Сеть небольшая, так как имеет всего пять сверточных слоев с небольшим количеством фильтров. numF управляет количеством фильтров в сверточных слоях. Чтобы повысить точность сети, попробуйте увеличить ее глубину, добавив идентичные блоки уровней свертки, пакетной нормализации и ReLU. Можно также попытаться увеличить число сверточных фильтров, увеличив numF.

Используйте взвешенную потерю классификации перекрестной энтропии. weightedClassificationLayer(classWeights) создает пользовательский классификационный слой, который вычисляет потери перекрестной энтропии с наблюдениями, взвешенными по classWeights. Укажите веса классов в том же порядке, в котором они отображаются в categories(YTrain). Чтобы придать каждому классу одинаковый общий вес в потере, используйте веса класса, которые обратно пропорциональны количеству тренировочных примеров в каждом классе. При использовании оптимизатора Адама для обучения сети алгоритм обучения не зависит от общей нормализации весов класса.

classWeights = 1./countcats(YTrain);
classWeights = classWeights'/mean(classWeights);
numClasses = numel(categories(YTrain));

timePoolSize = ceil(numHops/8);

dropoutProb = 0.2;
numF = 12;
layers = [
    imageInputLayer([numHops numBands])

    convolution2dLayer(3,numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,2*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(3,'Stride',2,'Padding','same')

    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,4*numF,'Padding','same')
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer([timePoolSize,1])

    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    weightedClassificationLayer(classWeights)];

Железнодорожная сеть

Укажите параметры обучения. Используйте оптимизатор Adam с размером мини-партии 128. Тренируйтесь в течение 25 эпох и сократите уровень обучения в 10 раз после 20 эпох.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('adam', ...
    'InitialLearnRate',3e-4, ...
    'MaxEpochs',25, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20);

Обучение сети. Если у вас нет графического процессора, то обучение сети может занять время.

trainedNet = trainNetwork(XTrain,YTrain,layers,options);

Оценка обученной сети

Рассчитайте окончательную точность сети на обучающем наборе (без увеличения данных) и валидационном наборе. Сеть очень точна в этом наборе данных. Однако все данные обучения, проверки и тестирования имеют схожие распределения, которые не обязательно отражают реальные условия. Это ограничение особенно применимо к unknown категория, которая содержит высказывания лишь небольшого числа слов.

if reduceDataset
    load('commandNet.mat','trainedNet');
end
YValPred = classify(trainedNet,XValidation);
validationError = mean(YValPred ~= YValidation);
YTrainPred = classify(trainedNet,XTrain);
trainError = mean(YTrainPred ~= YTrain);
disp("Training error: " + trainError*100 + "%")
disp("Validation error: " + validationError*100 + "%")
Training error: 1.907%
Validation error: 5.5376%

Постройте график матрицы путаницы. Отображение точности и отзыва для каждого класса с помощью сводок столбцов и строк. Сортировка классов матрицы путаницы. Самая большая путаница между неизвестными словами и командами, вверх и назад, вниз и нет, и идти и нет.

figure('Units','normalized','Position',[0.2 0.2 0.5 0.5]);
cm = confusionchart(YValidation,YValPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
sortClasses(cm, [commands,"unknown","background"])

При работе с приложениями с ограниченными аппаратными ресурсами, такими как мобильные приложения, учитывайте ограничения на доступную память и вычислительные ресурсы. Вычислите общий размер сети в килобайтах и проверьте ее скорость прогнозирования при использовании ЦП. Время прогнозирования - это время для классификации одного входного изображения. При вводе нескольких изображений в сеть их можно классифицировать одновременно, что приводит к более короткому времени прогнозирования для каждого изображения. Однако при классификации потокового звука наиболее актуальным является время предсказания одного изображения.

info = whos('trainedNet');
disp("Network size: " + info.bytes/1024 + " kB")

for i = 1:100
    x = randn([numHops,numBands]);
    tic
    [YPredicted,probs] = classify(trainedNet,x,"ExecutionEnvironment",'cpu');
    time(i) = toc;
end
disp("Single-image prediction time on CPU: " + mean(time(11:end))*1000 + " ms")
Network size: 286.7402 kB
Single-image prediction time on CPU: 2.5119 ms

Ссылки

[1] Уорден П. «Речевые команды: публичный набор данных для однословного распознавания речи», 2017. Доступно в https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Набор данных речевых команд лицензирован по лицензии Creative Commons Attribution 4.0, доступна здесь: https://creativecommons.org/licenses/by/4.0/legalcode.

Ссылки

[1] Уорден П. «Речевые команды: публичный набор данных для однословного распознавания речи», 2017. Доступно в http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz. Авторское право Google 2017. Набор данных речевых команд лицензирован по лицензии Creative Commons Attribution 4.0, доступна здесь: https://creativecommons.org/licenses/by/4.0/legalcode.

См. также

| |

Связанные темы