Создайте классификатор нейронной сети с feedforward с полносвязными слоями с помощью fitcnet
. Используйте данные валидации для раннего остановки процесса обучения, чтобы предотвратить сверхподбор кривой модели. Затем используйте функции объекта классификатора, чтобы оценить эффективность модели на тестовых данных.
Этот пример использует данные переписи 1994 года, хранящиеся в census1994.mat
. Набор данных состоит из демографической информации Бюро переписи населения США, которую можно использовать, чтобы предсказать, зарабатывает ли индивидуум более 50 000 долларов в год.
Загрузите выборочные данные census1994
, который содержит обучающие данные adultdata
и тестовые данные adulttest
. Предварительный просмотр первых нескольких строк обучающих данных набора.
load census1994
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец, salary
, показывает, имеет ли человек оклад менее 50 000 долл. США в год или более 50 000 долл. США в год.
Объедините education_num
и education
переменные как в обучающих, так и в тестовых данных для создания одной упорядоченной категориальной переменной, которая показывает самый высокий уровень образования, которого достиг человек.
edOrder = unique(adultdata.education_num,"stable"); edCats = unique(adultdata.education,"stable"); [~,edIdx] = sort(edOrder); adultdata.education = categorical(adultdata.education, ... edCats(edIdx),"Ordinal",true); adultdata.education_num = []; adulttest.education = categorical(adulttest.education, ... edCats(edIdx),"Ordinal",true); adulttest.education_num = [];
Разделите обучающие данные далее с помощью стратифицированного раздела удержания. Создайте отдельный набор данных валидации, чтобы остановить процесс обучения модели раньше. Резервируйте приблизительно 30% наблюдений для набора данных валидации и используйте остальную часть наблюдений для обучения классификатора нейронной сети.
rng("default") % For reproducibility of the partition c = cvpartition(adultdata.salary,"Holdout",0.30); trainingIndices = training(c); validationIndices = test(c); tblTrain = adultdata(trainingIndices,:); tblValidation = adultdata(validationIndices,:);
Обучите классификатор нейронной сети при помощи набора обучающих данных. Задайте salary
столбец tblTrain
как ответ и fnlwgt
столбец как веса наблюдений и стандартизируйте числовые предикторы. Оцените модель при каждой итерации с помощью набора валидации. Задайте, чтобы отобразить информацию о обучении при каждой итерации с помощью Verbose
аргумент имя-значение. По умолчанию процесс обучения заканчивается раньше, если потеря перекрестной энтропии валидации больше или равна минимальной потере перекрестной энтропии валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить количество раз, в течение которого потеря валидации может быть больше или равной минимуму, задайте ValidationPatience
аргумент имя-значение.
Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ... "Standardize",true,"ValidationData",tblValidation, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 0.297812| 0.078920| 0.703981| 0.012127| 0.296816| 0| | 2| 0.281110| 0.054594| 0.149850| 0.012486| 0.280132| 0| | 3| 0.252648| 0.062041| 1.004181| 0.011339| 0.247863| 0| | 4| 0.211868| 0.023567| 0.267214| 0.011319| 0.208988| 0| | 5| 0.207039| 0.057528| 0.320942| 0.010288| 0.206781| 0| | 6| 0.196838| 0.022492| 0.089842| 0.011197| 0.195583| 0| | 7| 0.186133| 0.025551| 0.295975| 0.010426| 0.184900| 0| | 8| 0.178779| 0.023714| 0.244525| 0.010237| 0.179370| 0| | 9| 0.174531| 0.027149| 0.306182| 0.012911| 0.178175| 0| | 10| 0.173217| 0.013365| 0.037475| 0.011084| 0.176371| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 0.168160| 0.016506| 0.307350| 0.010016| 0.170415| 0| | 12| 0.164460| 0.025136| 0.473227| 0.011512| 0.165902| 0| | 13| 0.162895| 0.014983| 0.473367| 0.010969| 0.164582| 0| | 14| 0.160791| 0.005187| 0.113760| 0.011720| 0.162947| 0| | 15| 0.159742| 0.004035| 0.138748| 0.010260| 0.162074| 0| | 16| 0.159290| 0.005774| 0.108266| 0.010400| 0.161728| 0| | 17| 0.158593| 0.004977| 0.152142| 0.010603| 0.161272| 0| | 18| 0.157437| 0.003660| 0.193303| 0.010510| 0.160299| 0| | 19| 0.156642| 0.007722| 0.430859| 0.010069| 0.160145| 0| | 20| 0.155954| 0.001908| 0.121039| 0.010041| 0.159066| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 0.155824| 0.001645| 0.025159| 0.010557| 0.158992| 0| | 22| 0.155486| 0.003232| 0.119915| 0.010829| 0.158731| 0| | 23| 0.155398| 0.006845| 0.083105| 0.031846| 0.158963| 1| | 24| 0.155261| 0.004374| 0.065660| 0.010762| 0.158816| 2| | 25| 0.154955| 0.002505| 0.264106| 0.011437| 0.158687| 0| | 26| 0.154799| 0.002183| 0.040876| 0.010903| 0.158538| 0| | 27| 0.154466| 0.002881| 0.219478| 0.012409| 0.158033| 0| | 28| 0.154250| 0.002724| 0.196190| 0.012062| 0.157980| 0| | 29| 0.153918| 0.002189| 0.135392| 0.009862| 0.157605| 0| | 30| 0.153707| 0.001449| 0.111574| 0.010851| 0.157456| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 31| 0.153214| 0.002050| 0.528628| 0.010212| 0.157379| 0| | 32| 0.152671| 0.002542| 0.488640| 0.010013| 0.156687| 0| | 33| 0.152303| 0.004554| 0.223206| 0.010334| 0.156778| 1| | 34| 0.152093| 0.002856| 0.121284| 0.010188| 0.156639| 0| | 35| 0.151871| 0.003145| 0.135909| 0.010108| 0.156446| 0| | 36| 0.151741| 0.001441| 0.225342| 0.010452| 0.156517| 1| | 37| 0.151626| 0.002500| 0.396782| 0.010487| 0.156429| 0| | 38| 0.151488| 0.005053| 0.148248| 0.010312| 0.156201| 0| | 39| 0.151250| 0.002552| 0.110278| 0.009895| 0.155968| 0| | 40| 0.151013| 0.002506| 0.123906| 0.010837| 0.155812| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 41| 0.150821| 0.002536| 0.109515| 0.010627| 0.155742| 0| | 42| 0.150509| 0.001418| 0.223296| 0.010561| 0.155648| 0| | 43| 0.150340| 0.003437| 0.185351| 0.010131| 0.155435| 0| | 44| 0.150280| 0.004746| 0.115075| 0.010432| 0.155797| 1| | 45| 0.150194| 0.002758| 0.082143| 0.010068| 0.155575| 2| | 46| 0.150061| 0.001122| 0.094288| 0.011405| 0.155334| 0| | 47| 0.149978| 0.001259| 0.127677| 0.010628| 0.155305| 0| | 48| 0.149879| 0.001523| 0.107816| 0.011331| 0.155044| 0| | 49| 0.149749| 0.004572| 0.156869| 0.009953| 0.155043| 0| | 50| 0.149617| 0.000965| 0.186502| 0.009702| 0.155106| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 51| 0.149579| 0.001302| 0.062687| 0.010743| 0.155160| 2| | 52| 0.149519| 0.001407| 0.086000| 0.010335| 0.155216| 3| | 53| 0.149405| 0.001243| 0.147530| 0.009753| 0.155309| 4| | 54| 0.149203| 0.002749| 0.186920| 0.010267| 0.155337| 5| | 55| 0.149040| 0.001217| 0.310011| 0.012444| 0.155460| 6| |==========================================================================================|
Используйте информацию внутри TrainingHistory
свойство объекта Mdl
чтобы проверить итерацию, которая соответствует минимальным потерям перекрестной энтропии валидации. Конечная возвращенная модель Mdl
- модель, обученная на этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 49
Оцените эффективность обученного классификатора Mdl
на тестовом аппарате adulttest
при помощи predict
, loss
, margin
, и edge
функции объекта.
Найдите предсказанные метки и классификационные оценки для наблюдений в тестовом наборе.
[labels,Scores] = predict(Mdl,adulttest);
Создайте матрицу неточностей из результатов тестового набора. Диагональные элементы указывают количество правильно классифицированных образцов данного класса. Не-диагональные элементы являются образцами неправильно классифицированных наблюдений.
confusionchart(adulttest.salary,labels)
Вычислите точность классификации тестового набора.
error = loss(Mdl,adulttest,"salary");
accuracy = (1-error)*100
accuracy = 85.1306
Классификатор нейронной сети правильно классифицирует приблизительно 85% наблюдений тестового набора.
Вычислите поля классификации тестового набора для обученной нейронной сети. Отображение гистограммы полей.
Классификационные поля являются различием между классификационной оценкой для истинного класса и классификационной оценкой для ложного класса. Поскольку классификаторы нейронных сетей возвращают счета, которые являются апостериорными вероятностями, классификационные поля, близкие к 1, указывают на уверенные классификации, а отрицательные значения полей указывают на неправильную классификацию.
m = margin(Mdl,adulttest,"salary");
histogram(m)
Используйте классификационное ребро или среднее значение классификационных полей для оценки общей эффективности классификатора.
meanMargin = edge(Mdl,adulttest,"salary")
meanMargin = 0.5983
Кроме того, вычислите взвешенное ребро классификации с помощью весов наблюдений.
weightedMeanMargin = edge(Mdl,adulttest,"salary", ... "Weight","fnlwgt")
weightedMeanMargin = 0.6072
Визуализируйте предсказанные метки и классификационные оценки с помощью графиков поля точек, на которых каждая точка соответствует наблюдению. Используйте предсказанные метки, чтобы задать цвет точек, и используйте максимальные счета, чтобы задать прозрачность точек. Точки с меньшей прозрачностью маркируются с большей доверие.
Сначала найдите максимальную классификационную оценку для каждого наблюдения тестового набора.
maxScores = max(Scores,[],2);
Создайте график поля точек, сравнивающий максимальные счета по количеству рабочих часов в неделю и уровню образования. Потому что переменная образования категориальна, случайным образом дрожит (или пробелирует) точки вдоль y-размерности.
Измените палитру так, чтобы максимальные счета, соответствующие окладам, которые меньше или равны 50 000 долларов США в год, отображались как синие, а максимальные счета, соответствующие окладам, превышающим 50 000 долларов США в год, отображались как красные.
scatter(adulttest.hours_per_week,adulttest.education,[],labels, ... "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ... "YJitter","rand"); xlabel("Number of Work Hours Per Week") ylabel("Education") Mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
colors = lines(2)
colors = 2×3
0 0.4470 0.7410
0.8500 0.3250 0.0980
colormap(colors);
Цвета на графике поля точек указывают, что в целом нейронная сеть предсказывает, что у людей с более низкими уровнями образования (12-й класс или ниже) зарплаты меньше или равны 50 000 долларов в год. Прозрачность некоторых точек в правом нижнем углу графика указывает, что модель менее уверена в этом предсказании для людей, которые работают много часов в неделю (60 часов и более).
ClassificationNeuralNetwork
| confusionchart
| edge
| fitcnet
| loss
| margin
| predict
| scatter