В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных признаков глубокого обучения.
Если имеется набор данных числовых элементов (например, набор числовых данных без пространственных или временных измерений), можно обучить сеть глубокого обучения с помощью уровня ввода элементов. Пример обучения сети классификации изображений см. в разделе Создание простой сети глубокого обучения для классификации.
Этот пример показывает, как обучить сеть классифицировать заболевание зубов механизма системы передачи, учитывая смесь числовых чтений датчика, статистики и категорических этикеток.
Загрузите набор данных корпуса трансмиссии для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 цифровых показаний и трех категориальных меток:
SigMean - Среднее значение сигнала вибрации
SigMedian - Отклонение сигнала вибрации
SigRMS - Вибросигнал СРК
SigVar - Отклонение сигнала вибрации
SigPeak - Пик сигнала вибрации
SigPeak2Peak - Пик сигнала вибрации до пика
SigSkewness - Перекос сигнала вибрации
SigKurtosis - Вибросигнал куртоз
SigCrestFactor - Коэффициент гребня сигнала вибрации
SigMAD - Сигнал вибрации MAD
SigRangeCumSum - Суммарная сумма диапазона вибрационных сигналов
SigCorrDimension - Размер корреляции сигнала вибрации
SigApproxEntropy - Приблизительная энтропия сигнала вибрации
SigLyapExponent - Показатель Lyap сигнала вибрации
PeakFreq - Пиковая частота.
HighFreqPower - Высокочастотная мощность
EnvPower - Мощность окружающей среды
PeakSpecKurtosis - Пиковая частота спектрального куртоза
SensorCondition - Состояние датчика, указанное как «Sensor Drift» или «No Sensor Drift»
ShaftCondition - Состояние вала, указанное как «Износ вала» или «Без износа вала»
GearToothCondidtion - состояние зубьев зубьев, указанное как «Неисправность зубьев» или «Отсутствие неисправности зубьев»;
Считывание данных корпуса коробки передач из CSV-файла "transmissionCasingData.csv".
filename = "transmissionCasingData.csv"; tbl = readtable(filename,'TextType','String');
Преобразовать метки для прогнозирования в категориальные с помощью convertvars функция.
labelName = "GearToothCondition"; tbl = convertvars(tbl,labelName,'categorical');
Просмотрите первые несколько строк таблицы.
head(tbl)
ans=8×21 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis SensorCondition ShaftCondition GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ _______________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 "Sensor Drift" "No Shaft Wear" No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 "Sensor Drift" "No Shaft Wear" No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
Чтобы обучить сеть с помощью категориальных элементов, необходимо сначала преобразовать категориальные элементы в числовые. Во-первых, преобразуйте категориальные предикторы в категориальные, используя convertvars путем указания строкового массива, содержащего имена всех категориальных входных переменных. В этом наборе данных имеются две категориальные функции с именами "SensorCondition" и "ShaftCondition".
categoricalInputNames = ["SensorCondition" "ShaftCondition"]; tbl = convertvars(tbl,categoricalInputNames,'categorical');
Закольцовывать категориальные входные переменные. Для каждой переменной:
Преобразование категориальных значений в одноступенчатые кодированные векторы с помощью onehotencode функция.
Добавьте векторы с одним горячим сигналом в таблицу с помощью addvars функция. Укажите, следует ли вставлять векторы после столбца, содержащего соответствующие категориальные данные.
Удалите соответствующий столбец, содержащий категориальные данные.
for i = 1:numel(categoricalInputNames) name = categoricalInputNames(i); oh = onehotencode(tbl(:,name)); tbl = addvars(tbl,oh,'After',name); tbl(:,name) = []; end
Разбейте векторы на отдельные столбцы с помощью splitvars функция.
tbl = splitvars(tbl);
Просмотрите первые несколько строк таблицы. Обратите внимание, что категориальные предикторы были разделены на несколько столбцов с категориальными значениями в качестве имен переменных.
head(tbl)
ans=8×23 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis No Sensor Drift Sensor Drift No Shaft Wear Shaft Wear GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ ____________ _____________ __________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 0 1 1 0 No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 0 1 1 0 No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 0 1 0 1 No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 0 1 0 1 No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 0 1 0 1 No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 0 1 0 1 No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 0 1 0 1 No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 0 1 0 1 No Tooth Fault
Просмотр имен классов набора данных.
classNames = categories(tbl{:,labelName})classNames = 2×1 cell
{'No Tooth Fault'}
{'Tooth Fault' }
Разбиение набора данных на разделы обучения, проверки и тестирования. Отложите 15% данных для проверки и 15% для тестирования.
Просмотр количества наблюдений в наборе данных.
numObservations = size(tbl,1)
numObservations = 208
Определите количество наблюдений для каждого раздела.
numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32
Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с использованием размеров разделов.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation); idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);
Разбиение таблицы данных на разделы обучения, проверки и тестирования с использованием индексов.
tblTrain = tbl(idxTrain,:); tblValidation = tbl(idxValidation,:); tblTest = tbl(idxTest,:);
Определите сетевой график для классификации.
Определите сеть с уровнем ввода элемента и укажите количество элементов. Также настройте входной уровень для нормализации данных с помощью Z-score нормализации. Далее, включают в себя полностью соединенный уровень с размером выхода 50, за которым следуют уровень нормализации партии и уровень ReLU. Для классификации укажите другой полностью подключенный уровень с размером вывода, соответствующим количеству классов, за которым следуют уровень softmax и уровень классификации.
numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
layers = [
featureInputLayer(numFeatures,'Normalization', 'zscore')
fullyConnectedLayer(50)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];Укажите параметры обучения.
Обучение сети с помощью стохастического градиентного спуска с импульсом (SGDM).
Поезд с использованием мини-партий размера 16.
Тасуйте данные каждую эпоху.
Контроль точности сети во время обучения путем указания данных проверки.
Отображение хода обучения на графике и подавление вывода подробного окна команд.
Программное обеспечение обучает сеть данным обучения и вычисляет точность данных проверки через регулярные интервалы во время обучения. Данные проверки не используются для обновления весов сети.
miniBatchSize = 16; options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'ValidationData',tblValidation, ... 'Plots','training-progress', ... 'Verbose',false);
Обучение сети с использованием архитектуры, определенной layers, данные обучения и варианты обучения. По умолчанию trainNetwork использует графический процессор, если он доступен, в противном случае использует центральный процессор. Для обучения графическому процессору требуются параллельные вычислительные Toolbox™ и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). Можно также указать среду выполнения с помощью 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions.
График хода обучения показывает потери и точность мини-партии, а также потери и точность проверки. Дополнительные сведения о графике хода обучения см. в разделе Мониторинг хода обучения по углубленному обучению.
net = trainNetwork(tblTrain,labelName,layers,options);

Спрогнозировать метки тестовых данных с использованием обученной сети и рассчитать точность. Укажите размер мини-партии, используемый для обучения.
YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);Вычислите точность классификации. Точность - это доля меток, которую сеть предсказывает правильно.
YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)accuracy = 0.9688
Просмотрите результаты в матрице путаницы.
figure confusionchart(YTest,YPred)

Конструктор глубоких сетей | featureInputLayer | fullyConnectedLayer | trainingOptions | trainNetwork