exponenta event banner

Сеть поездов с числовыми характеристиками

В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных признаков глубокого обучения.

Если имеется набор данных числовых элементов (например, набор числовых данных без пространственных или временных измерений), можно обучить сеть глубокого обучения с помощью уровня ввода элементов. Пример обучения сети классификации изображений см. в разделе Создание простой сети глубокого обучения для классификации.

Этот пример показывает, как обучить сеть классифицировать заболевание зубов механизма системы передачи, учитывая смесь числовых чтений датчика, статистики и категорических этикеток.

Загрузить данные

Загрузите набор данных корпуса трансмиссии для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 цифровых показаний и трех категориальных меток:

  1. SigMean - Среднее значение сигнала вибрации

  2. SigMedian - Отклонение сигнала вибрации

  3. SigRMS - Вибросигнал СРК

  4. SigVar - Отклонение сигнала вибрации

  5. SigPeak - Пик сигнала вибрации

  6. SigPeak2Peak - Пик сигнала вибрации до пика

  7. SigSkewness - Перекос сигнала вибрации

  8. SigKurtosis - Вибросигнал куртоз

  9. SigCrestFactor - Коэффициент гребня сигнала вибрации

  10. SigMAD - Сигнал вибрации MAD

  11. SigRangeCumSum - Суммарная сумма диапазона вибрационных сигналов

  12. SigCorrDimension - Размер корреляции сигнала вибрации

  13. SigApproxEntropy - Приблизительная энтропия сигнала вибрации

  14. SigLyapExponent - Показатель Lyap сигнала вибрации

  15. PeakFreq - Пиковая частота.

  16. HighFreqPower - Высокочастотная мощность

  17. EnvPower - Мощность окружающей среды

  18. PeakSpecKurtosis - Пиковая частота спектрального куртоза

  19. SensorCondition - Состояние датчика, указанное как «Sensor Drift» или «No Sensor Drift»

  20. ShaftCondition - Состояние вала, указанное как «Износ вала» или «Без износа вала»

  21. GearToothCondidtion - состояние зубьев зубьев, указанное как «Неисправность зубьев» или «Отсутствие неисправности зубьев»;

Считывание данных корпуса коробки передач из CSV-файла "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

Преобразовать метки для прогнозирования в категориальные с помощью convertvars функция.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

Просмотрите первые несколько строк таблицы.

head(tbl)
ans=8×21 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

Чтобы обучить сеть с помощью категориальных элементов, необходимо сначала преобразовать категориальные элементы в числовые. Во-первых, преобразуйте категориальные предикторы в категориальные, используя convertvars путем указания строкового массива, содержащего имена всех категориальных входных переменных. В этом наборе данных имеются две категориальные функции с именами "SensorCondition" и "ShaftCondition".

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

Закольцовывать категориальные входные переменные. Для каждой переменной:

  • Преобразование категориальных значений в одноступенчатые кодированные векторы с помощью onehotencode функция.

  • Добавьте векторы с одним горячим сигналом в таблицу с помощью addvars функция. Укажите, следует ли вставлять векторы после столбца, содержащего соответствующие категориальные данные.

  • Удалите соответствующий столбец, содержащий категориальные данные.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

Разбейте векторы на отдельные столбцы с помощью splitvars функция.

tbl = splitvars(tbl);

Просмотрите первые несколько строк таблицы. Обратите внимание, что категориальные предикторы были разделены на несколько столбцов с категориальными значениями в качестве имен переменных.

head(tbl)
ans=8×23 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

Просмотр имен классов набора данных.

classNames = categories(tbl{:,labelName})
classNames = 2×1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Разделение набора данных на наборы обучения и проверки

Разбиение набора данных на разделы обучения, проверки и тестирования. Отложите 15% данных для проверки и 15% для тестирования.

Просмотр количества наблюдений в наборе данных.

numObservations = size(tbl,1)
numObservations = 208

Определите количество наблюдений для каждого раздела.

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с использованием размеров разделов.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

Разбиение таблицы данных на разделы обучения, проверки и тестирования с использованием индексов.

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

Определение сетевой архитектуры

Определите сетевой график для классификации.

Определите сеть с уровнем ввода элемента и укажите количество элементов. Также настройте входной уровень для нормализации данных с помощью Z-score нормализации. Далее, включают в себя полностью соединенный уровень с размером выхода 50, за которым следуют уровень нормализации партии и уровень ReLU. Для классификации укажите другой полностью подключенный уровень с размером вывода, соответствующим количеству классов, за которым следуют уровень softmax и уровень классификации.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения

Укажите параметры обучения.

  • Обучение сети с помощью стохастического градиентного спуска с импульсом (SGDM).

  • Поезд с использованием мини-партий размера 16.

  • Тасуйте данные каждую эпоху.

  • Контроль точности сети во время обучения путем указания данных проверки.

  • Отображение хода обучения на графике и подавление вывода подробного окна команд.

Программное обеспечение обучает сеть данным обучения и вычисляет точность данных проверки через регулярные интервалы во время обучения. Данные проверки не используются для обновления весов сети.

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',tblValidation, ...
    'Plots','training-progress', ...
    'Verbose',false);

Железнодорожная сеть

Обучение сети с использованием архитектуры, определенной layers, данные обучения и варианты обучения. По умолчанию trainNetwork использует графический процессор, если он доступен, в противном случае использует центральный процессор. Для обучения графическому процессору требуются параллельные вычислительные Toolbox™ и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). Можно также указать среду выполнения с помощью 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions.

График хода обучения показывает потери и точность мини-партии, а также потери и точность проверки. Дополнительные сведения о графике хода обучения см. в разделе Мониторинг хода обучения по углубленному обучению.

net = trainNetwork(tblTrain,labelName,layers,options);

Тестовая сеть

Спрогнозировать метки тестовых данных с использованием обученной сети и рассчитать точность. Укажите размер мини-партии, используемый для обучения.

YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);

Вычислите точность классификации. Точность - это доля меток, которую сеть предсказывает правильно.

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

Просмотрите результаты в матрице путаницы.

figure
confusionchart(YTest,YPred)

См. также

| | | |

Связанные примеры

Подробнее