В этом примере показано, как создать и обучить простую нейронную сеть для классификации данных о функции глубокого обучения.
Если у вас есть набор данных числовых функций (например, набор числовых данных без пространственных или измерений времени), то можно обучить нейронную сеть для глубокого обучения с помощью входного слоя функции. Для примера, показывающего, как обучить сеть для классификации изображений, смотрите, Создают Простую сеть глубокого обучения для Классификации.
В этом примере показано, как обучить сеть, чтобы классифицировать заболевание зубов механизма системы передачи, учитывая смесь числовых показаний датчика, статистики и категориальных меток.
Загрузите передачу, заключающую набор данных в корпус для обучения. Набор данных состоит из 208 синтетических показаний системы передачи, состоящей из 18 числовых показаний и трех категориальных меток:
SigMean
— Среднее значение сигнала вибрации
SigMedian
— Медиана сигнала вибрации
SigRMS
— RMS сигнала вибрации
SigVar
— Отклонение сигнала вибрации
SigPeak
— Пик сигнала вибрации
SigPeak2Peak
— Сигнал вибрации достигает максимума, чтобы достигнуть максимума
SigSkewness
— Скошенность сигнала вибрации
SigKurtosis
— Эксцесс сигнала вибрации
SigCrestFactor
— Сигнал вибрации увенчивает фактор
SigMAD
— MAD сигнала вибрации
SigRangeCumSum
— Совокупная сумма диапазона сигнала вибрации
SigCorrDimension
— Размерность корреляции сигнала вибрации
SigApproxEntropy
— Сигнал вибрации аппроксимирует энтропию
SigLyapExponent
— Сигнал вибрации экспонента Lyap
PeakFreq
— Пиковая частота.
HighFreqPower
— Высокочастотная степень
EnvPower
— Степень среды
PeakSpecKurtosis
— Пиковая частота спектрального эксцесса
SensorCondition
— Условие датчика в виде "Дрейфа Датчика" или "Никакого Дрейфа Датчика"
ShaftCondition
— Условие вала в виде "Износа Вала" или "Никакого Износа Вала"
GearToothCondition
— Заболевание зубов механизма в виде "Зубного Отказа" или "Никакого Зубного Отказа"
Считайте данные о преобразовании регистра передачи из файла CSV "transmissionCasingData.csv"
.
filename = "transmissionCasingData.csv"; tbl = readtable(filename,'TextType','String');
Преобразуйте метки для предсказания к категориальному использованию convertvars
функция.
labelName = "GearToothCondition"; tbl = convertvars(tbl,labelName,'categorical');
Просмотрите первые несколько строк таблицы.
head(tbl)
ans=8×21 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis SensorCondition ShaftCondition GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ _______________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 "Sensor Drift" "No Shaft Wear" No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 "Sensor Drift" "No Shaft Wear" No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
Чтобы обучить сеть, использующую категориальные функции, необходимо сначала преобразовать категориальные функции в числовой. Во-первых, преобразуйте категориальные предикторы в категориальное использование convertvars
функция путем определения массива строк, содержащего имена всех категориальных входных переменных. В этом наборе данных существует две категориальных функции с именами "SensorCondition"
и "ShaftCondition"
.
categoricalInputNames = ["SensorCondition" "ShaftCondition"]; tbl = convertvars(tbl,categoricalInputNames,'categorical');
Цикл по категориальным входным переменным. Для каждой переменной:
Преобразуйте категориальные значения в одногорячие закодированные векторы с помощью onehotencode
функция.
Добавьте одногорячие векторы в таблицу с помощью addvars
функция. Задайте, чтобы вставить векторы после столбца, содержащего соответствующие категориальные данные.
Удалите соответствующий столбец, содержащий категориальные данные.
for i = 1:numel(categoricalInputNames) name = categoricalInputNames(i); oh = onehotencode(tbl(:,name)); tbl = addvars(tbl,oh,'After',name); tbl(:,name) = []; end
Разделите векторы в отдельные столбцы с помощью splitvars
функция.
tbl = splitvars(tbl);
Просмотрите первые несколько строк таблицы. Заметьте, что категориальные предикторы были разделены в несколько столбцов с категориальными значениями как имена переменных.
head(tbl)
ans=8×23 table
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis No Sensor Drift Sensor Drift No Shaft Wear Shaft Wear GearToothCondition
________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ ____________ _____________ __________ __________________
-0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 0 1 1 0 No Tooth Fault
-0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 0 1 1 0 No Tooth Fault
1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 0 1 0 1 No Tooth Fault
1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 0 1 0 1 No Tooth Fault
1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 0 1 0 1 No Tooth Fault
1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 0 1 0 1 No Tooth Fault
1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 0 1 0 1 No Tooth Fault
1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 0 1 0 1 No Tooth Fault
Просмотрите имена классов набора данных.
classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
{'No Tooth Fault'}
{'Tooth Fault' }
Разделите набор данных в обучение, валидацию, и протестируйте разделы. Отложите 15% данных для валидации и 15% для тестирования.
Просмотрите количество наблюдений в наборе данных.
numObservations = size(tbl,1)
numObservations = 208
Определите количество наблюдений для каждого раздела.
numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32
Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с помощью размеров раздела.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation); idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);
Разделите таблицу данных в обучение, валидацию и разделы тестирования с помощью индексов.
tblTrain = tbl(idxTrain,:); tblValidation = tbl(idxValidation,:); tblTest = tbl(idxTest,:);
Задайте сеть для классификации.
Задайте сеть с входным слоем функции и задайте количество функций. Кроме того, сконфигурируйте входной слой, чтобы нормировать данные с помощью нормализации Z-счета. Затем включайте полносвязный слой с выходным размером 50 сопровождаемых слоем нормализации партии. и слоем ReLU. Для классификации задайте другой полносвязный слой с выходным размером, соответствующим количеству классов, сопровождаемых softmax слоем и слоем классификации.
numFeatures = size(tbl,2) - 1; numClasses = numel(classNames); layers = [ featureInputLayer(numFeatures,'Normalization', 'zscore') fullyConnectedLayer(50) batchNormalizationLayer reluLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции обучения.
Обучите сеть с помощью Адама.
Обучите мини-пакеты использования размера 16.
Переставьте данные каждая эпоха.
Контролируйте сетевую точность во время обучения путем определения данных о валидации.
Отобразите прогресс обучения в графике и подавите многословное командное окно выход.
Программное обеспечение обучает сеть на обучающих данных и вычисляет точность на данные о валидации равномерно во время обучения. Данные о валидации не используются, чтобы обновить сетевые веса.
miniBatchSize = 16; options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'ValidationData',tblValidation, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть с помощью архитектуры, заданной layers
, обучающие данные и опции обучения. По умолчанию, trainNetwork
использует графический процессор, если вы доступны, в противном случае, он использует центральный процессор. Обучение на графическом процессоре требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Можно также задать среду выполнения при помощи 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
.
График процесса обучения показывает мини-пакетную потерю и точность и потерю валидации и точность. Для получения дополнительной информации о графике процесса обучения смотрите Процесс обучения Глубокого обучения Монитора.
net = trainNetwork(tblTrain,labelName,layers,options);
Предскажите метки тестовых данных с помощью обучившего сеть и вычислите точность. Задайте тот же мини-пакетный размер, используемый для обучения.
YPred = classify(net,tblTest(:,1:end-1),'MiniBatchSize',miniBatchSize);
Вычислите точность классификации. Точность является пропорцией меток, которые сеть предсказывает правильно.
YTest = tblTest{:,labelName}; accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688
Просмотрите результаты в матрице беспорядка.
figure confusionchart(YTest,YPred)
trainNetwork
| trainingOptions
| fullyConnectedLayer
| Deep Network Designer | featureInputLayer