Обучите сеть с числовыми функциями

В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных о функциях глубокого обучения.

Если у вас есть набор данных числовых функций (для примера набора числовых данных без пространственных или временных размерностей), то можно обучить нейронную сеть для глубокого обучения с помощью функции входа слоя. Для примера, показывающего, как обучить сеть для классификации изображений, смотрите Создать Простую Сеть Глубокого Обучения для Классификации.

Этот пример показывает, как обучить сеть классифицировать условие зуба передачи системы передачи, учитывая смесь числовых показаний датчика, статистики и категориальных меток.

Загрузка данных

Загрузите набор данных кожуха трансмиссии для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 числовых показаний и трех категориальных меток:

  1. SigMean - Среднее значение сигнала вибрации

  2. SigMedian - отклонение сигнала вибрации

  3. SigRMS - Сигнал вибрации RMS

  4. SigVar - отклонение сигнала вибрации

  5. SigPeak - Пик сигнала вибрации

  6. SigPeak2Peak - Пик сигнала вибрации в пик

  7. SigSkewness - Искажение сигнала вибрации

  8. SigKurtosis - Куртоз сигнала вибрации

  9. SigCrestFactor - Коэффициент искривления сигнала вибрации

  10. SigMAD - Сигнал вибрации MAD

  11. SigRangeCumSum - Совокупная сумма диапазона сигнала вибрации

  12. SigCorrDimension - размерность корреляции сигнала вибрации

  13. SigApproxEntropy - Сигнал вибрации аппроксимирует энтропию

  14. SigLyapExponent - Сигнал вибрации экспоненты Lyap

  15. PeakFreq - Пиковая частота.

  16. HighFreqPower - Высокая частотная степень

  17. EnvPower - степень окружения

  18. PeakSpecKurtosis - Пиковая частота спектрального куртоза

  19. SensorCondition - Условие датчика, обозначаемое как «Дрейф датчика» или «Нет дрейфу датчика»

  20. ShaftCondition - Условие вала, обозначаемое как «Износ вала» или «Отсутствие износа вала»

  21. GearToothCondidtion - условие зубьев, обозначаемое как «отказ зуба» или «отсутствие отказа зуба»;

Считайте данные обсадной колонны трансмиссии из файла CSV "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

Преобразуйте метки для предсказания в категориальные с помощью convertvars функция.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

Просмотрите первые несколько строк таблицы.

head(tbl)
ans=8×21 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

Чтобы обучить сеть с помощью категориальных признаков, необходимо сначала преобразовать категориальные функции в числа. Во-первых, преобразуйте категориальные предикторы в категориальные с помощью convertvars функция путем определения строковых массивов, содержащего имена всех категориальных входных переменных. В этом наборе данных существуют две категориальные функции с именами "SensorCondition" и "ShaftCondition".

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

Цикл по категориальным входным переменным. Для каждой переменной:

  • Преобразуйте категориальные значения в векторы с одним горячим кодированием с помощью onehotencode функция.

  • Добавьте векторы с одним контактом к таблице с помощью addvars функция. Задайте, чтобы вставить векторы после столбца, содержащего соответствующие категориальные данные.

  • Удалите соответствующий столбец, содержащий категориальные данные.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end

Разделите векторы на отдельные столбцы с помощью splitvars функция.

tbl = splitvars(tbl);

Просмотрите первые несколько строк таблицы. Заметьте, что категориальные предикторы были разделены на несколько столбцов с категориальными значениями в качестве имен переменных.

head(tbl)
ans=8×23 table
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

Просмотрите имена классов набора данных.

classNames = categories(tbl{:,labelName})
classNames = 2×1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Разделение данных на наборы обучения и валидации

Разделите набор данных на разделы для обучения, валидации и тестирования. Выделите 15% данных для валидации и 15% для проверки.

Просмотрите количество наблюдений в наборе данных.

numObservations = size(tbl,1)
numObservations = 208

Определите количество наблюдений для каждого раздела.

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

Создайте массив случайных индексов, соответствующих наблюдениям, и разбейте его с помощью размеров разбиений.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

Разделите таблицу данных на разделы для обучения, валидации и проверки с помощью индексов.

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

Определение сетевой архитектуры

Определите сеть для классификации.

Задайте сеть с входным слоем функций и укажите количество функций. Кроме того, сконфигурируйте слой входа, чтобы нормализовать данные с помощью нормализации Z-балла. Затем включает в себя полностью соединенный слой с выходным размером 50, за которым следует слой нормализации партии . и слой ReLU. Для классификации задайте другой полносвязный слой с выходным размером, соответствующим количеству классов, затем слой softmax и слой классификации.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Настройка опций обучения

Задайте опции обучения.

  • Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM).

  • Обучите с использованием мини-партий размера 16.

  • Перетасовывайте данные каждую эпоху.

  • Отслеживайте точность сети во время обучения путем определения данных валидации.

  • Отображение процесса обучения на графике и подавление выхода подробного командного окна.

Программное обеспечение обучает сеть по обучающим данным и вычисляет точность по данным валидации через регулярные интервалы во время обучения. Данные валидации не используются для обновления весов сети.

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',tblValidation, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть

Обучите сеть с помощью архитектуры, заданной layers, обучающих данных и опций обучения. По умолчанию trainNetwork использует графический процессор, если он доступен, в противном случае используется центральный процессор. Для обучения на графическом процессоре требуется Parallel Computing Toolbox™ и поддерживаемое устройство GPU. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Можно также задать окружение выполнения с помощью 'ExecutionEnvironment' Аргумент пары "имя-значение" из trainingOptions.

График процесса обучения показывает мини-потери и точность пакета, а также потери и точность валидации. Для получения дополнительной информации о графике процесса обучения смотрите Монитор Глубокого обучения Процесса обучения.

net = trainNetwork(tblTrain,labelName,layers,options);

Тестирование сети

Спрогнозируйте метки тестовых данных с помощью обученной сети и вычислите точность. Укажите тот же размер мини-пакета, что и для обучения.

YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);

Вычислите точность классификации. Точность является долей меток, которые сеть предсказывает правильно.

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

Просмотрите результаты в матрице неточностей.

figure
confusionchart(YTest,YPred)

См. также

| | | |

Похожие примеры

Подробнее о