Обучите сеть с числовыми функциями

В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных о функции глубокого обучения.

Если у вас есть набор данных числовых функций (например, набор числовых данных без пространственных или измерений времени), то можно обучить нейронную сеть для глубокого обучения с помощью входного слоя функции. Для примера, показывающего, как обучить сеть для классификации изображений, смотрите, Создают Простую сеть глубокого обучения для Классификации.

В этом примере показано, как обучить сеть, чтобы классифицировать заболевание зубов механизма системы передачи, учитывая смесь числовых показаний датчика, статистики и категориальных меток.

Загрузка данных

Загрузите передачу, заключающую набор данных в корпус для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 числовых показаний и трех категориальных меток:

  1. SigMean — Среднее значение сигнала вибрации

  2. SigMedian — Отклонение сигнала вибрации

  3. SigRMS — RMS сигнала вибрации

  4. SigVar — Отклонение сигнала вибрации

  5. SigPeak — Пик сигнала вибрации

  6. SigPeak2Peak — Сигнал вибрации достигает максимума, чтобы достигнуть максимума

  7. SigSkewness — Скошенность сигнала вибрации

  8. SigKurtosis — Эксцесс сигнала вибрации

  9. SigCrestFactor — Сигнал вибрации увенчивает фактор

  10. SigMAD — MAD сигнала вибрации

  11. SigRangeCumSum — Совокупная сумма диапазона сигнала вибрации

  12. SigCorrDimension — Размерность корреляции сигнала вибрации

  13. SigApproxEntropy — Сигнал вибрации аппроксимирует энтропию

  14. SigLyapExponent — Сигнал вибрации экспонента Lyap

  15. PeakFreq — Пиковая частота.

  16. HighFreqPower — Высокочастотная степень

  17. EnvPower — Степень среды

  18. PeakSpecKurtosis — Пиковая частота спектрального эксцесса

  19. SensorCondition — Условие датчика в виде "Дрейфа Датчика" или "Никакого Дрейфа Датчика"

  20. ShaftCondition — Условие вала в виде "Износа Вала" или "Никакого Износа Вала"

  21. GearToothCondidtion — Заболевание зубов механизма в виде "Зубного Отказа" или "Никакого Зубного Отказа"

Считайте данные о преобразовании регистра передачи из файла CSV "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

Преобразуйте метки для предсказания к категориальному использованию convertvars функция.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

Просмотрите первые несколько строк таблицы.

head(tbl)
ans=8×21 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

Чтобы обучить сеть, использующую категориальные функции, необходимо сначала преобразовать категориальные функции в числовой. Во-первых, преобразуйте категориальные предикторы в категориальное использование convertvars функция путем определения массива строк, содержащего имена всех категориальных входных переменных. В этом наборе данных существует две категориальных функции с именами "SensorCondition" и "ShaftCondition".

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

Цикл по категориальным входным переменным. Для каждой переменной:

  • Преобразуйте категориальные значения в одногорячие закодированные векторы с помощью onehotencode функция.

  • Добавьте одногорячие векторы в таблицу с помощью addvars функция. Задайте, чтобы вставить векторы после столбца, содержащего соответствующие категориальные данные.

  • Удалите соответствующий столбец, содержащий категориальные данные.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

Разделите векторы в отдельные столбцы с помощью splitvars функция.

tbl = splitvars(tbl);

Просмотрите первые несколько строк таблицы. Заметьте, что категориальные предикторы были разделены в несколько столбцов с категориальными значениями как имена переменных.

head(tbl)
ans=8×23 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

Просмотрите имена классов набора данных.

classNames = categories(tbl{:,labelName})
classNames = 2×1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Набор Разделения данных в наборы обучения и валидации

Разделите набор данных в обучение, валидацию, и протестируйте разделы. Отложите 15% данных для валидации и 15% для тестирования.

Просмотрите количество наблюдений в наборе данных.

numObservations = size(tbl,1)
numObservations = 208

Определите количество наблюдений для каждого раздела.

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с помощью размеров раздела.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

Разделите таблицу данных в обучение, валидацию и разделы тестирования с помощью индексов.

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

Архитектура сети Define

Задайте сеть для классификации.

Задайте сеть с входным слоем функции и задайте количество функций. Кроме того, сконфигурируйте входной слой, чтобы нормировать данные с помощью нормализации Z-счета. Затем включайте полносвязный слой с выходным размером 50 сопровождаемых слоем нормализации партии. и слоем ReLU. Для классификации задайте другой полносвязный слой с выходным размером, соответствующим количеству классов, сопровождаемых softmax слоем и слоем классификации.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения

Задайте опции обучения.

  • Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM).

  • Обучите мини-пакеты использования размера 16.

  • Переставьте данные каждая эпоха.

  • Контролируйте сетевую точность во время обучения путем определения данных о валидации.

  • Отобразите прогресс обучения в графике и подавите многословное командное окно выход.

Программное обеспечение обучает сеть на обучающих данных и вычисляет точность на данные о валидации равномерно во время обучения. Данные о валидации не используются, чтобы обновить сетевые веса.

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',tblValidation, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучение сети

Обучите сеть с помощью архитектуры, заданной layers, обучающие данные и опции обучения. По умолчанию, trainNetwork использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment' аргумент пары "имя-значение" trainingOptions.

График процесса обучения показывает мини-пакетную потерю и точность и потерю валидации и точность. Для получения дополнительной информации о графике процесса обучения смотрите Процесс обучения Глубокого обучения Монитора.

net = trainNetwork(tblTrain,labelName,layers,options);

Тестирование сети

Предскажите метки тестовых данных с помощью обучившего сеть и вычислите точность. Задайте тот же мини-пакетный размер, используемый для обучения.

YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);

Вычислите точность классификации. Точность является пропорцией меток, которые сеть предсказывает правильно.

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

Просмотрите результаты в матрице беспорядка.

figure
confusionchart(YTest,YPred)

Смотрите также

| | | |

Связанные примеры

Больше о