Создайте классификатор нейронной сети прямого распространения с полносвязными слоями с помощью fitcnet
. Используйте данные о валидации для ранней остановки учебного процесса, чтобы предотвратить сверхподбор кривой модели. Затем используйте объектные функции классификатора, чтобы оценить эффективность модели на тестовых данных.
Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat
. Набор данных состоит из демографической информации из Бюро переписи США, которое можно использовать, чтобы предсказать, передает ли индивидуум 50 000$ в год.
Загрузите выборочные данные census1994
, который содержит обучающие данные adultdata
и тестовые данные adulttest
. Предварительно просмотрите первые несколько строк обучающего набора данных.
load census1994
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец, salary
, показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.
Объедините education_num
и education
переменные и в обучении и в тестовых данных, чтобы создать одну упорядоченную категориальную переменную, которая показывает высший уровень образования человек, достигли.
edOrder = unique(adultdata.education_num,"stable"); edCats = unique(adultdata.education,"stable"); [~,edIdx] = sort(edOrder); adultdata.education = categorical(adultdata.education, ... edCats(edIdx),"Ordinal",true); adultdata.education_num = []; adulttest.education = categorical(adulttest.education, ... edCats(edIdx),"Ordinal",true); adulttest.education_num = [];
Разделите обучающие данные далее с помощью стратифицированного раздела затяжки. Создайте отдельный набор данных валидации, чтобы остановить учебный процесс модели рано. Зарезервируйте приблизительно 30% наблюдений для набора данных валидации и используйте остальную часть наблюдений, чтобы обучить классификатор нейронной сети.
rng("default") % For reproducibility of the partition c = cvpartition(adultdata.salary,"Holdout",0.30); trainingIndices = training(c); validationIndices = test(c); tblTrain = adultdata(trainingIndices,:); tblValidation = adultdata(validationIndices,:);
Обучите классификатор нейронной сети при помощи набора обучающих данных. Задайте salary
столбец tblTrain
как ответ и fnlwgt
столбец как веса наблюдения, и стандартизирует числовые предикторы. Оцените модель в каждой итерации при помощи набора валидации. Задайте, чтобы отобразить учебную информацию в каждой итерации при помощи Verbose
аргумент значения имени. По умолчанию учебный процесс заканчивается рано, если потеря перекрестной энтропии валидации больше или равна минимальной потере перекрестной энтропии валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить число раз, потере валидации позволяют быть больше или быть равной минимуму, задать ValidationPatience
аргумент значения имени.
Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ... "Standardize",true,"ValidationData",tblValidation, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 0.297812| 0.078920| 0.703981| 0.012127| 0.296816| 0| | 2| 0.281110| 0.054594| 0.149850| 0.012486| 0.280132| 0| | 3| 0.252648| 0.062041| 1.004181| 0.011339| 0.247863| 0| | 4| 0.211868| 0.023567| 0.267214| 0.011319| 0.208988| 0| | 5| 0.207039| 0.057528| 0.320942| 0.010288| 0.206781| 0| | 6| 0.196838| 0.022492| 0.089842| 0.011197| 0.195583| 0| | 7| 0.186133| 0.025551| 0.295975| 0.010426| 0.184900| 0| | 8| 0.178779| 0.023714| 0.244525| 0.010237| 0.179370| 0| | 9| 0.174531| 0.027149| 0.306182| 0.012911| 0.178175| 0| | 10| 0.173217| 0.013365| 0.037475| 0.011084| 0.176371| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 0.168160| 0.016506| 0.307350| 0.010016| 0.170415| 0| | 12| 0.164460| 0.025136| 0.473227| 0.011512| 0.165902| 0| | 13| 0.162895| 0.014983| 0.473367| 0.010969| 0.164582| 0| | 14| 0.160791| 0.005187| 0.113760| 0.011720| 0.162947| 0| | 15| 0.159742| 0.004035| 0.138748| 0.010260| 0.162074| 0| | 16| 0.159290| 0.005774| 0.108266| 0.010400| 0.161728| 0| | 17| 0.158593| 0.004977| 0.152142| 0.010603| 0.161272| 0| | 18| 0.157437| 0.003660| 0.193303| 0.010510| 0.160299| 0| | 19| 0.156642| 0.007722| 0.430859| 0.010069| 0.160145| 0| | 20| 0.155954| 0.001908| 0.121039| 0.010041| 0.159066| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 0.155824| 0.001645| 0.025159| 0.010557| 0.158992| 0| | 22| 0.155486| 0.003232| 0.119915| 0.010829| 0.158731| 0| | 23| 0.155398| 0.006845| 0.083105| 0.031846| 0.158963| 1| | 24| 0.155261| 0.004374| 0.065660| 0.010762| 0.158816| 2| | 25| 0.154955| 0.002505| 0.264106| 0.011437| 0.158687| 0| | 26| 0.154799| 0.002183| 0.040876| 0.010903| 0.158538| 0| | 27| 0.154466| 0.002881| 0.219478| 0.012409| 0.158033| 0| | 28| 0.154250| 0.002724| 0.196190| 0.012062| 0.157980| 0| | 29| 0.153918| 0.002189| 0.135392| 0.009862| 0.157605| 0| | 30| 0.153707| 0.001449| 0.111574| 0.010851| 0.157456| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 31| 0.153214| 0.002050| 0.528628| 0.010212| 0.157379| 0| | 32| 0.152671| 0.002542| 0.488640| 0.010013| 0.156687| 0| | 33| 0.152303| 0.004554| 0.223206| 0.010334| 0.156778| 1| | 34| 0.152093| 0.002856| 0.121284| 0.010188| 0.156639| 0| | 35| 0.151871| 0.003145| 0.135909| 0.010108| 0.156446| 0| | 36| 0.151741| 0.001441| 0.225342| 0.010452| 0.156517| 1| | 37| 0.151626| 0.002500| 0.396782| 0.010487| 0.156429| 0| | 38| 0.151488| 0.005053| 0.148248| 0.010312| 0.156201| 0| | 39| 0.151250| 0.002552| 0.110278| 0.009895| 0.155968| 0| | 40| 0.151013| 0.002506| 0.123906| 0.010837| 0.155812| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 41| 0.150821| 0.002536| 0.109515| 0.010627| 0.155742| 0| | 42| 0.150509| 0.001418| 0.223296| 0.010561| 0.155648| 0| | 43| 0.150340| 0.003437| 0.185351| 0.010131| 0.155435| 0| | 44| 0.150280| 0.004746| 0.115075| 0.010432| 0.155797| 1| | 45| 0.150194| 0.002758| 0.082143| 0.010068| 0.155575| 2| | 46| 0.150061| 0.001122| 0.094288| 0.011405| 0.155334| 0| | 47| 0.149978| 0.001259| 0.127677| 0.010628| 0.155305| 0| | 48| 0.149879| 0.001523| 0.107816| 0.011331| 0.155044| 0| | 49| 0.149749| 0.004572| 0.156869| 0.009953| 0.155043| 0| | 50| 0.149617| 0.000965| 0.186502| 0.009702| 0.155106| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 51| 0.149579| 0.001302| 0.062687| 0.010743| 0.155160| 2| | 52| 0.149519| 0.001407| 0.086000| 0.010335| 0.155216| 3| | 53| 0.149405| 0.001243| 0.147530| 0.009753| 0.155309| 4| | 54| 0.149203| 0.002749| 0.186920| 0.010267| 0.155337| 5| | 55| 0.149040| 0.001217| 0.310011| 0.012444| 0.155460| 6| |==========================================================================================|
Используйте информацию в TrainingHistory
свойство объекта Mdl
проверять итерацию, которая соответствует минимальной потере перекрестной энтропии валидации. Финал возвратил модель Mdl
модель, обученная в этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 49
Оцените эффективность обученного классификатора Mdl
на наборе тестов adulttest
при помощи predict
, loss
, margin
, и edge
функции объекта.
Найдите предсказанные метки и классификационные оценки для наблюдений в наборе тестов.
[labels,Scores] = predict(Mdl,adulttest);
Создайте матрицу беспорядка из результатов набора тестов. Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональными элементами являются экземпляры неправильно классифицированных наблюдений.
confusionchart(adulttest.salary,labels)
Вычислите точность классификации наборов тестов.
error = loss(Mdl,adulttest,"salary");
accuracy = (1-error)*100
accuracy = 85.1306
Классификатор нейронной сети правильно классифицирует приблизительно 85% наблюдений набора тестов.
Вычислите поля классификации наборов тестов для обученной нейронной сети. Отобразите гистограмму полей.
Поля классификации являются различием между классификационной оценкой для истинного класса и классификационной оценкой для ложного класса. Поскольку классификаторы нейронной сети возвращают баллы, которые являются апостериорными вероятностями, поля классификации близко к 1 указывают на уверенные классификации, и отрицательные граничные значения указывают на misclassifications.
m = margin(Mdl,adulttest,"salary");
histogram(m)
Используйте ребро классификации или среднее значение полей классификации, чтобы оценить общую производительность классификатора.
meanMargin = edge(Mdl,adulttest,"salary")
meanMargin = 0.5983
В качестве альтернативы вычислите взвешенное ребро классификации при помощи весов наблюдения.
weightedMeanMargin = edge(Mdl,adulttest,"salary", ... "Weight","fnlwgt")
weightedMeanMargin = 0.6072
Визуализируйте предсказанные метки и классификационные оценки с помощью графиков рассеивания, в которых каждая точка соответствует наблюдению. Используйте предсказанные метки, чтобы выбрать цвет точек и использовать максимальные баллы, чтобы установить прозрачность точек. Точки с меньшей прозрачностью помечены большим доверием.
Во-первых, найдите максимальную классификационную оценку для каждого наблюдения набора тестов.
maxScores = max(Scores,[],2);
Создайте график рассеивания, сравнивающий максимальные баллы через номер часов работы в неделю и уровень образования. Поскольку образовательная переменная является категориальной, случайным образом дрожите (или растяните), точки по y-измерению.
Измените палитру так, чтобы максимальные баллы, соответствующие зарплатам, которые меньше чем или равны 50 000$ в год, появились столь же синие, и максимальные баллы, соответствующие зарплатам, больше, чем 50 000$ в год появляются как красный.
scatter(adulttest.hours_per_week,adulttest.education,[],labels, ... "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ... "YJitter","rand"); xlabel("Number of Work Hours Per Week") ylabel("Education") Mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
colors = lines(2)
colors = 2×3
0 0.4470 0.7410
0.8500 0.3250 0.0980
colormap(colors);
Цвета в графике рассеивания указывают, что в целом нейронная сеть предсказывает, что у людей с более низкими уровнями образования (12-й класс или ниже) есть зарплаты, меньше чем или равные 50 000$ в год. Прозрачность некоторых точек в нижнем правом углу графика указывает, что модель менее уверена в этом предсказании для людей, которые работают много часов в неделю (60 часов или больше).
fitcnet
| margin
| edge
| loss
| predict
| ClassificationNeuralNetwork
| confusionchart
| scatter