В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных о функциях глубокого обучения.
Если у вас есть набор данных числовых функций (для примера набора числовых данных без пространственных или временных размерностей), то можно обучить нейронную сеть для глубокого обучения с помощью функции входа слоя. Для примера, показывающего, как обучить сеть для классификации изображений, смотрите Создать Простую Сеть Глубокого Обучения для Классификации.
Этот пример показывает, как обучить сеть классифицировать условие зуба передачи системы передачи, учитывая смесь числовых показаний датчика, статистики и категориальных меток.
Загрузите набор данных кожуха трансмиссии для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 числовых показаний и трех категориальных меток:
SigMean
- Среднее значение сигнала вибрации
SigMedian
- отклонение сигнала вибрации
SigRMS
- Сигнал вибрации RMS
SigVar
- отклонение сигнала вибрации
SigPeak
- Пик сигнала вибрации
SigPeak2Peak
- Пик сигнала вибрации в пик
SigSkewness
- Искажение сигнала вибрации
SigKurtosis
- Куртоз сигнала вибрации
SigCrestFactor
- Коэффициент искривления сигнала вибрации
SigMAD
- Сигнал вибрации MAD
SigRangeCumSum
- Совокупная сумма диапазона сигнала вибрации
SigCorrDimension
- размерность корреляции сигнала вибрации
SigApproxEntropy
- Сигнал вибрации аппроксимирует энтропию
SigLyapExponent
- Сигнал вибрации экспоненты Lyap
PeakFreq
- Пиковая частота.
HighFreqPower
- Высокая частотная степень
EnvPower
- степень окружения
PeakSpecKurtosis
- Пиковая частота спектрального куртоза
SensorCondition
- Условие датчика, обозначаемое как «Дрейф датчика» или «Нет дрейфу датчика»
ShaftCondition
- Условие вала, обозначаемое как «Износ вала» или «Отсутствие износа вала»
GearToothCondidtion
- условие зубьев, обозначаемое как «отказ зуба» или «отсутствие отказа зуба»;
Считайте данные обсадной колонны трансмиссии из файла CSV "transmissionCasingData.csv"
.
filename = "transmissionCasingData.csv"; tbl = readtable(filename,'TextType','String');
Преобразуйте метки для предсказания в категориальные с помощью convertvars
функция.
labelName = "GearToothCondition"; tbl = convertvars(tbl,labelName,'categorical');
Просмотрите первые несколько строк таблицы.
head(tbl)
ans=8×21 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis SensorCondition ShaftCondition GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ _______________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 "Sensor Drift" "No Shaft Wear" No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 "Sensor Drift" "No Shaft Wear" No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
Чтобы обучить сеть с помощью категориальных признаков, необходимо сначала преобразовать категориальные функции в числа. Во-первых, преобразуйте категориальные предикторы в категориальные с помощью convertvars
функция путем определения строковых массивов, содержащего имена всех категориальных входных переменных. В этом наборе данных существуют две категориальные функции с именами "SensorCondition"
и "ShaftCondition"
.
categoricalInputNames = ["SensorCondition" "ShaftCondition"]; tbl = convertvars(tbl,categoricalInputNames,'categorical');
Цикл по категориальным входным переменным. Для каждой переменной:
Преобразуйте категориальные значения в векторы с одним горячим кодированием с помощью onehotencode
функция.
Добавьте векторы с одним контактом к таблице с помощью addvars
функция. Задайте, чтобы вставить векторы после столбца, содержащего соответствующие категориальные данные.
Удалите соответствующий столбец, содержащий категориальные данные.
for i = 1:numel(categoricalInputNames) name = categoricalInputNames(i); oh = onehotencode(tbl(:,name)); tbl = addvars(tbl,oh,'After',name); tbl(:,name) = []; end
Разделите векторы на отдельные столбцы с помощью splitvars
функция.
tbl = splitvars(tbl);
Просмотрите первые несколько строк таблицы. Заметьте, что категориальные предикторы были разделены на несколько столбцов с категориальными значениями в качестве имен переменных.
head(tbl)
ans=8×23 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis No Sensor Drift Sensor Drift No Shaft Wear Shaft Wear GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ ____________ _____________ __________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 0 1 1 0 No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 0 1 1 0 No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 0 1 0 1 No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 0 1 0 1 No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 0 1 0 1 No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 0 1 0 1 No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 0 1 0 1 No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 0 1 0 1 No Tooth Fault
Просмотрите имена классов набора данных.
classNames = categories(tbl{:,labelName})
classNames = 2×1 cell
{'No Tooth Fault'}
{'Tooth Fault' }
Разделите набор данных на разделы для обучения, валидации и тестирования. Выделите 15% данных для валидации и 15% для проверки.
Просмотрите количество наблюдений в наборе данных.
numObservations = size(tbl,1)
numObservations = 208
Определите количество наблюдений для каждого раздела.
numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32
Создайте массив случайных индексов, соответствующих наблюдениям, и разбейте его с помощью размеров разбиений.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation); idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);
Разделите таблицу данных на разделы для обучения, валидации и проверки с помощью индексов.
tblTrain = tbl(idxTrain,:); tblValidation = tbl(idxValidation,:); tblTest = tbl(idxTest,:);
Определите сеть для классификации.
Задайте сеть с входным слоем функций и укажите количество функций. Кроме того, сконфигурируйте слой входа, чтобы нормализовать данные с помощью нормализации Z-балла. Затем включает в себя полностью соединенный слой с выходным размером 50, за которым следует слой нормализации партии . и слой ReLU. Для классификации задайте другой полносвязный слой с выходным размером, соответствующим количеству классов, затем слой softmax и слой классификации.
numFeatures = size(tbl,2) - 1; numClasses = numel(classNames); layers = [ featureInputLayer(numFeatures,'Normalization', 'zscore') fullyConnectedLayer(50) batchNormalizationLayer reluLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции обучения.
Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM).
Обучите с использованием мини-партий размера 16.
Перетасовывайте данные каждую эпоху.
Отслеживайте точность сети во время обучения путем определения данных валидации.
Отображение процесса обучения на графике и подавление выхода подробного командного окна.
Программное обеспечение обучает сеть по обучающим данным и вычисляет точность по данным валидации через регулярные интервалы во время обучения. Данные валидации не используются для обновления весов сети.
miniBatchSize = 16; options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'ValidationData',tblValidation, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть с помощью архитектуры, заданной layers
, обучающих данных и опций обучения. По умолчанию trainNetwork
использует графический процессор, если он доступен, в противном случае используется центральный процессор. Для обучения на графическом процессоре требуется Parallel Computing Toolbox™ и поддерживаемое устройство GPU. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Можно также задать окружение выполнения с помощью 'ExecutionEnvironment'
Аргумент пары "имя-значение" из trainingOptions
.
График процесса обучения показывает мини-потери и точность пакета, а также потери и точность валидации. Для получения дополнительной информации о графике процесса обучения смотрите Монитор Глубокого обучения Процесса обучения.
net = trainNetwork(tblTrain,labelName,layers,options);
Спрогнозируйте метки тестовых данных с помощью обученной сети и вычислите точность. Укажите тот же размер мини-пакета, что и для обучения.
YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);
Вычислите точность классификации. Точность является долей меток, которые сеть предсказывает правильно.
YTest = tblTest{:,labelName}; accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688
Просмотрите результаты в матрице неточностей.
figure confusionchart(YTest,YPred)
Deep Network Designer | featureInputLayer
| fullyConnectedLayer
| trainingOptions
| trainNetwork