Создание прямого классификатора нейронных сетей с полностью соединенными слоями с помощью fitcnet. Используйте данные проверки для ранней остановки процесса обучения, чтобы предотвратить переоснащение модели. Затем используйте объектные функции классификатора для оценки производительности модели на тестовых данных.
В этом примере используются данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографической информации Бюро переписи населения США, которую можно использовать для прогнозирования того, зарабатывает ли человек более 50 000 долларов в год.
Загрузка данных образца census1994, который содержит данные обучения adultdata и данные испытаний adulttest. Просмотрите первые несколько строк набора данных обучения.
load census1994
head(adultdata)ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец, salaryпоказывает, имеет ли человек зарплату меньше или равную 50 000 долларов в год или больше 50 000 долларов в год.
Объединить education_num и education переменные в данных обучения и тестирования для создания одной упорядоченной категориальной переменной, которая показывает наивысший уровень образования, которого достиг человек.
edOrder = unique(adultdata.education_num,"stable"); edCats = unique(adultdata.education,"stable"); [~,edIdx] = sort(edOrder); adultdata.education = categorical(adultdata.education, ... edCats(edIdx),"Ordinal",true); adultdata.education_num = []; adulttest.education = categorical(adulttest.education, ... edCats(edIdx),"Ordinal",true); adulttest.education_num = [];
Разделите учебные данные с помощью стратифицированного раздела удержания. Создайте отдельный набор данных проверки, чтобы остановить процесс обучения модели на ранней стадии. Зарезервируйте приблизительно 30% наблюдений для набора данных проверки и используйте остальные наблюдения для обучения классификатора нейронной сети.
rng("default") % For reproducibility of the partition c = cvpartition(adultdata.salary,"Holdout",0.30); trainingIndices = training(c); validationIndices = test(c); tblTrain = adultdata(trainingIndices,:); tblValidation = adultdata(validationIndices,:);
Обучение классификатора нейронной сети с помощью обучающего набора. Укажите salary столбец tblTrain в качестве ответа и fnlwgt столбец в качестве весов наблюдения и стандартизировать числовые предикторы. Оцените модель в каждой итерации с помощью набора проверки. Укажите отображение информации об обучении в каждой итерации с помощью Verbose аргумент «имя-значение». По умолчанию тренировочный процесс заканчивается раньше, если потери при перекрестной энтропии проверки больше или равны минимальным потерям при перекрестной энтропии проверки, вычисленным на данный момент, шесть раз подряд. Чтобы изменить количество раз, когда допустимо, чтобы потеря проверки была больше или равна минимуму, укажите ValidationPatience аргумент «имя-значение».
Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ... "Standardize",true,"ValidationData",tblValidation, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 0.297812| 0.078920| 0.703981| 0.012127| 0.296816| 0| | 2| 0.281110| 0.054594| 0.149850| 0.012486| 0.280132| 0| | 3| 0.252648| 0.062041| 1.004181| 0.011339| 0.247863| 0| | 4| 0.211868| 0.023567| 0.267214| 0.011319| 0.208988| 0| | 5| 0.207039| 0.057528| 0.320942| 0.010288| 0.206781| 0| | 6| 0.196838| 0.022492| 0.089842| 0.011197| 0.195583| 0| | 7| 0.186133| 0.025551| 0.295975| 0.010426| 0.184900| 0| | 8| 0.178779| 0.023714| 0.244525| 0.010237| 0.179370| 0| | 9| 0.174531| 0.027149| 0.306182| 0.012911| 0.178175| 0| | 10| 0.173217| 0.013365| 0.037475| 0.011084| 0.176371| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 0.168160| 0.016506| 0.307350| 0.010016| 0.170415| 0| | 12| 0.164460| 0.025136| 0.473227| 0.011512| 0.165902| 0| | 13| 0.162895| 0.014983| 0.473367| 0.010969| 0.164582| 0| | 14| 0.160791| 0.005187| 0.113760| 0.011720| 0.162947| 0| | 15| 0.159742| 0.004035| 0.138748| 0.010260| 0.162074| 0| | 16| 0.159290| 0.005774| 0.108266| 0.010400| 0.161728| 0| | 17| 0.158593| 0.004977| 0.152142| 0.010603| 0.161272| 0| | 18| 0.157437| 0.003660| 0.193303| 0.010510| 0.160299| 0| | 19| 0.156642| 0.007722| 0.430859| 0.010069| 0.160145| 0| | 20| 0.155954| 0.001908| 0.121039| 0.010041| 0.159066| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 0.155824| 0.001645| 0.025159| 0.010557| 0.158992| 0| | 22| 0.155486| 0.003232| 0.119915| 0.010829| 0.158731| 0| | 23| 0.155398| 0.006845| 0.083105| 0.031846| 0.158963| 1| | 24| 0.155261| 0.004374| 0.065660| 0.010762| 0.158816| 2| | 25| 0.154955| 0.002505| 0.264106| 0.011437| 0.158687| 0| | 26| 0.154799| 0.002183| 0.040876| 0.010903| 0.158538| 0| | 27| 0.154466| 0.002881| 0.219478| 0.012409| 0.158033| 0| | 28| 0.154250| 0.002724| 0.196190| 0.012062| 0.157980| 0| | 29| 0.153918| 0.002189| 0.135392| 0.009862| 0.157605| 0| | 30| 0.153707| 0.001449| 0.111574| 0.010851| 0.157456| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 31| 0.153214| 0.002050| 0.528628| 0.010212| 0.157379| 0| | 32| 0.152671| 0.002542| 0.488640| 0.010013| 0.156687| 0| | 33| 0.152303| 0.004554| 0.223206| 0.010334| 0.156778| 1| | 34| 0.152093| 0.002856| 0.121284| 0.010188| 0.156639| 0| | 35| 0.151871| 0.003145| 0.135909| 0.010108| 0.156446| 0| | 36| 0.151741| 0.001441| 0.225342| 0.010452| 0.156517| 1| | 37| 0.151626| 0.002500| 0.396782| 0.010487| 0.156429| 0| | 38| 0.151488| 0.005053| 0.148248| 0.010312| 0.156201| 0| | 39| 0.151250| 0.002552| 0.110278| 0.009895| 0.155968| 0| | 40| 0.151013| 0.002506| 0.123906| 0.010837| 0.155812| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 41| 0.150821| 0.002536| 0.109515| 0.010627| 0.155742| 0| | 42| 0.150509| 0.001418| 0.223296| 0.010561| 0.155648| 0| | 43| 0.150340| 0.003437| 0.185351| 0.010131| 0.155435| 0| | 44| 0.150280| 0.004746| 0.115075| 0.010432| 0.155797| 1| | 45| 0.150194| 0.002758| 0.082143| 0.010068| 0.155575| 2| | 46| 0.150061| 0.001122| 0.094288| 0.011405| 0.155334| 0| | 47| 0.149978| 0.001259| 0.127677| 0.010628| 0.155305| 0| | 48| 0.149879| 0.001523| 0.107816| 0.011331| 0.155044| 0| | 49| 0.149749| 0.004572| 0.156869| 0.009953| 0.155043| 0| | 50| 0.149617| 0.000965| 0.186502| 0.009702| 0.155106| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 51| 0.149579| 0.001302| 0.062687| 0.010743| 0.155160| 2| | 52| 0.149519| 0.001407| 0.086000| 0.010335| 0.155216| 3| | 53| 0.149405| 0.001243| 0.147530| 0.009753| 0.155309| 4| | 54| 0.149203| 0.002749| 0.186920| 0.010267| 0.155337| 5| | 55| 0.149040| 0.001217| 0.310011| 0.012444| 0.155460| 6| |==========================================================================================|
Используйте информацию внутри TrainingHistory свойство объекта Mdl для проверки итерации, соответствующей минимальным потерям перекрестной энтропии проверки. Последняя возвращенная модель Mdl является моделью, обученной этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 49
Оценка эффективности работы обученного классификатора Mdl на контрольном аппарате adulttest с помощью predict, loss, margin, и edge функции объекта.
Найдите прогнозируемые метки и оценки классификации для наблюдений в тестовом наборе.
[labels,Scores] = predict(Mdl,adulttest);
Создайте матрицу путаницы из результатов набора тестов. Диагональные элементы указывают количество правильно классифицированных экземпляров данного класса. Внедиагональные элементы являются экземплярами неправильно классифицированных наблюдений.
confusionchart(adulttest.salary,labels)

Вычислите точность классификации тестового набора.
error = loss(Mdl,adulttest,"salary");
accuracy = (1-error)*100accuracy = 85.1306
Нейросетевой классификатор правильно классифицирует приблизительно 85% наблюдений тестового набора.
Вычислите пределы классификации тестового набора для обученной нейронной сети. Отображение гистограммы полей.
Поля классификации представляют собой разницу между показателем классификации для истинного класса и показателем классификации для ложного класса. Поскольку нейросетевые классификаторы возвращают баллы, которые являются задними вероятностями, поля классификации, близкие к 1, указывают на уверенные классификации, а отрицательные значения полей указывают на неправильные классификации.
m = margin(Mdl,adulttest,"salary");
histogram(m)
Используйте край классификации или среднее значение полей классификации для оценки общей производительности классификатора.
meanMargin = edge(Mdl,adulttest,"salary")meanMargin = 0.5983
Либо вычислите взвешенную границу классификации с помощью весов наблюдения.
weightedMeanMargin = edge(Mdl,adulttest,"salary", ... "Weight","fnlwgt")
weightedMeanMargin = 0.6072
Визуализируйте прогнозируемые метки и оценки классификации с помощью графиков рассеяния, на которых каждая точка соответствует наблюдению. Используйте прогнозируемые метки для задания цвета точек, а максимальные оценки - для задания прозрачности точек. Точки с меньшей прозрачностью маркируются с большей уверенностью.
Сначала найдите максимальный показатель классификации для каждого наблюдения тестового набора.
maxScores = max(Scores,[],2);
Создайте график разброса, сравнивая максимальные баллы по количеству рабочих часов в неделю и уровню образования. Поскольку образовательная переменная категорична, случайное дрожание (или пробел) точек вдоль измерения y.
Измените карту цветов таким образом, чтобы максимальные баллы, соответствующие окладам, которые меньше или равны 50 000 долл. США в год, отображались синим цветом, а максимальные баллы, соответствующие окладам, превышающим 50 000 долл. США в год, - красным цветом.
scatter(adulttest.hours_per_week,adulttest.education,[],labels, ... "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ... "YJitter","rand"); xlabel("Number of Work Hours Per Week") ylabel("Education") Mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
colors = lines(2)
colors = 2×3
0 0.4470 0.7410
0.8500 0.3250 0.0980
colormap(colors);

Цвета на сюжете рассеяния указывают на то, что в целом нейросеть прогнозирует, что у людей с более низким уровнем образования (12-й класс или ниже) зарплата меньше или равна $50 000 в год. Прозрачность некоторых точек в нижнем правом углу сюжета указывает на то, что модель менее уверена в этом прогнозе для людей, которые работают много часов в неделю (60 часов и более).
ClassificationNeuralNetwork | confusionchart | edge | fitcnet | loss | margin | predict | scatter