exponenta event banner

Оценка характеристик классификатора нейронных сетей

Создание прямого классификатора нейронных сетей с полностью соединенными слоями с помощью fitcnet. Используйте данные проверки для ранней остановки процесса обучения, чтобы предотвратить переоснащение модели. Затем используйте объектные функции классификатора для оценки производительности модели на тестовых данных.

Загрузка и предварительная обработка образцов данных

В этом примере используются данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографической информации Бюро переписи населения США, которую можно использовать для прогнозирования того, зарабатывает ли человек более 50 000 долларов в год.

Загрузка данных образца census1994, который содержит данные обучения adultdata и данные испытаний adulttest. Просмотрите первые несколько строк набора данных обучения.

load census1994
head(adultdata)
ans=8×15 table
    age       workClass          fnlwgt      education    education_num       marital_status           occupation        relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     State-gov                77516    Bachelors         13          Never-married            Adm-clerical         Not-in-family    White    Male          2174             0                40          United-States     <=50K 
    50     Self-emp-not-inc         83311    Bachelors         13          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                13          United-States     <=50K 
    38     Private             2.1565e+05    HS-grad            9          Divorced                 Handlers-cleaners    Not-in-family    White    Male             0             0                40          United-States     <=50K 
    53     Private             2.3472e+05    11th               7          Married-civ-spouse       Handlers-cleaners    Husband          Black    Male             0             0                40          United-States     <=50K 
    28     Private             3.3841e+05    Bachelors         13          Married-civ-spouse       Prof-specialty       Wife             Black    Female           0             0                40          Cuba              <=50K 
    37     Private             2.8458e+05    Masters           14          Married-civ-spouse       Exec-managerial      Wife             White    Female           0             0                40          United-States     <=50K 
    49     Private             1.6019e+05    9th                5          Married-spouse-absent    Other-service        Not-in-family    Black    Female           0             0                16          Jamaica           <=50K 
    52     Self-emp-not-inc    2.0964e+05    HS-grad            9          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                45          United-States     >50K  

Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец, salaryпоказывает, имеет ли человек зарплату меньше или равную 50 000 долларов в год или больше 50 000 долларов в год.

Объединить education_num и education переменные в данных обучения и тестирования для создания одной упорядоченной категориальной переменной, которая показывает наивысший уровень образования, которого достиг человек.

edOrder = unique(adultdata.education_num,"stable");
edCats = unique(adultdata.education,"stable");
[~,edIdx] = sort(edOrder);

adultdata.education = categorical(adultdata.education, ...
    edCats(edIdx),"Ordinal",true);
adultdata.education_num = [];

adulttest.education = categorical(adulttest.education, ...
    edCats(edIdx),"Ordinal",true);
adulttest.education_num = [];

Учебные данные по разделам

Разделите учебные данные с помощью стратифицированного раздела удержания. Создайте отдельный набор данных проверки, чтобы остановить процесс обучения модели на ранней стадии. Зарезервируйте приблизительно 30% наблюдений для набора данных проверки и используйте остальные наблюдения для обучения классификатора нейронной сети.

rng("default") % For reproducibility of the partition
c = cvpartition(adultdata.salary,"Holdout",0.30);
trainingIndices = training(c);
validationIndices = test(c);
tblTrain = adultdata(trainingIndices,:);
tblValidation = adultdata(validationIndices,:);

Нейронная сеть поезда

Обучение классификатора нейронной сети с помощью обучающего набора. Укажите salary столбец tblTrain в качестве ответа и fnlwgt столбец в качестве весов наблюдения и стандартизировать числовые предикторы. Оцените модель в каждой итерации с помощью набора проверки. Укажите отображение информации об обучении в каждой итерации с помощью Verbose аргумент «имя-значение». По умолчанию тренировочный процесс заканчивается раньше, если потери при перекрестной энтропии проверки больше или равны минимальным потерям при перекрестной энтропии проверки, вычисленным на данный момент, шесть раз подряд. Чтобы изменить количество раз, когда допустимо, чтобы потеря проверки была больше или равна минимуму, укажите ValidationPatience аргумент «имя-значение».

Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ...
    "Standardize",true,"ValidationData",tblValidation, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|    0.297812|    0.078920|    0.703981|    0.012127|    0.296816|           0|
|           2|    0.281110|    0.054594|    0.149850|    0.012486|    0.280132|           0|
|           3|    0.252648|    0.062041|    1.004181|    0.011339|    0.247863|           0|
|           4|    0.211868|    0.023567|    0.267214|    0.011319|    0.208988|           0|
|           5|    0.207039|    0.057528|    0.320942|    0.010288|    0.206781|           0|
|           6|    0.196838|    0.022492|    0.089842|    0.011197|    0.195583|           0|
|           7|    0.186133|    0.025551|    0.295975|    0.010426|    0.184900|           0|
|           8|    0.178779|    0.023714|    0.244525|    0.010237|    0.179370|           0|
|           9|    0.174531|    0.027149|    0.306182|    0.012911|    0.178175|           0|
|          10|    0.173217|    0.013365|    0.037475|    0.011084|    0.176371|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    0.168160|    0.016506|    0.307350|    0.010016|    0.170415|           0|
|          12|    0.164460|    0.025136|    0.473227|    0.011512|    0.165902|           0|
|          13|    0.162895|    0.014983|    0.473367|    0.010969|    0.164582|           0|
|          14|    0.160791|    0.005187|    0.113760|    0.011720|    0.162947|           0|
|          15|    0.159742|    0.004035|    0.138748|    0.010260|    0.162074|           0|
|          16|    0.159290|    0.005774|    0.108266|    0.010400|    0.161728|           0|
|          17|    0.158593|    0.004977|    0.152142|    0.010603|    0.161272|           0|
|          18|    0.157437|    0.003660|    0.193303|    0.010510|    0.160299|           0|
|          19|    0.156642|    0.007722|    0.430859|    0.010069|    0.160145|           0|
|          20|    0.155954|    0.001908|    0.121039|    0.010041|    0.159066|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    0.155824|    0.001645|    0.025159|    0.010557|    0.158992|           0|
|          22|    0.155486|    0.003232|    0.119915|    0.010829|    0.158731|           0|
|          23|    0.155398|    0.006845|    0.083105|    0.031846|    0.158963|           1|
|          24|    0.155261|    0.004374|    0.065660|    0.010762|    0.158816|           2|
|          25|    0.154955|    0.002505|    0.264106|    0.011437|    0.158687|           0|
|          26|    0.154799|    0.002183|    0.040876|    0.010903|    0.158538|           0|
|          27|    0.154466|    0.002881|    0.219478|    0.012409|    0.158033|           0|
|          28|    0.154250|    0.002724|    0.196190|    0.012062|    0.157980|           0|
|          29|    0.153918|    0.002189|    0.135392|    0.009862|    0.157605|           0|
|          30|    0.153707|    0.001449|    0.111574|    0.010851|    0.157456|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          31|    0.153214|    0.002050|    0.528628|    0.010212|    0.157379|           0|
|          32|    0.152671|    0.002542|    0.488640|    0.010013|    0.156687|           0|
|          33|    0.152303|    0.004554|    0.223206|    0.010334|    0.156778|           1|
|          34|    0.152093|    0.002856|    0.121284|    0.010188|    0.156639|           0|
|          35|    0.151871|    0.003145|    0.135909|    0.010108|    0.156446|           0|
|          36|    0.151741|    0.001441|    0.225342|    0.010452|    0.156517|           1|
|          37|    0.151626|    0.002500|    0.396782|    0.010487|    0.156429|           0|
|          38|    0.151488|    0.005053|    0.148248|    0.010312|    0.156201|           0|
|          39|    0.151250|    0.002552|    0.110278|    0.009895|    0.155968|           0|
|          40|    0.151013|    0.002506|    0.123906|    0.010837|    0.155812|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          41|    0.150821|    0.002536|    0.109515|    0.010627|    0.155742|           0|
|          42|    0.150509|    0.001418|    0.223296|    0.010561|    0.155648|           0|
|          43|    0.150340|    0.003437|    0.185351|    0.010131|    0.155435|           0|
|          44|    0.150280|    0.004746|    0.115075|    0.010432|    0.155797|           1|
|          45|    0.150194|    0.002758|    0.082143|    0.010068|    0.155575|           2|
|          46|    0.150061|    0.001122|    0.094288|    0.011405|    0.155334|           0|
|          47|    0.149978|    0.001259|    0.127677|    0.010628|    0.155305|           0|
|          48|    0.149879|    0.001523|    0.107816|    0.011331|    0.155044|           0|
|          49|    0.149749|    0.004572|    0.156869|    0.009953|    0.155043|           0|
|          50|    0.149617|    0.000965|    0.186502|    0.009702|    0.155106|           1|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          51|    0.149579|    0.001302|    0.062687|    0.010743|    0.155160|           2|
|          52|    0.149519|    0.001407|    0.086000|    0.010335|    0.155216|           3|
|          53|    0.149405|    0.001243|    0.147530|    0.009753|    0.155309|           4|
|          54|    0.149203|    0.002749|    0.186920|    0.010267|    0.155337|           5|
|          55|    0.149040|    0.001217|    0.310011|    0.012444|    0.155460|           6|
|==========================================================================================|

Используйте информацию внутри TrainingHistory свойство объекта Mdl для проверки итерации, соответствующей минимальным потерям перекрестной энтропии проверки. Последняя возвращенная модель Mdl является моделью, обученной этой итерации.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 49

Оценка производительности тестового набора

Оценка эффективности работы обученного классификатора Mdl на контрольном аппарате adulttest с помощью predict, loss, margin, и edge функции объекта.

Найдите прогнозируемые метки и оценки классификации для наблюдений в тестовом наборе.

[labels,Scores] = predict(Mdl,adulttest);

Создайте матрицу путаницы из результатов набора тестов. Диагональные элементы указывают количество правильно классифицированных экземпляров данного класса. Внедиагональные элементы являются экземплярами неправильно классифицированных наблюдений.

confusionchart(adulttest.salary,labels)

Вычислите точность классификации тестового набора.

error = loss(Mdl,adulttest,"salary");
accuracy = (1-error)*100
accuracy = 85.1306

Нейросетевой классификатор правильно классифицирует приблизительно 85% наблюдений тестового набора.

Вычислите пределы классификации тестового набора для обученной нейронной сети. Отображение гистограммы полей.

Поля классификации представляют собой разницу между показателем классификации для истинного класса и показателем классификации для ложного класса. Поскольку нейросетевые классификаторы возвращают баллы, которые являются задними вероятностями, поля классификации, близкие к 1, указывают на уверенные классификации, а отрицательные значения полей указывают на неправильные классификации.

m = margin(Mdl,adulttest,"salary");
histogram(m)

Используйте край классификации или среднее значение полей классификации для оценки общей производительности классификатора.

meanMargin = edge(Mdl,adulttest,"salary")
meanMargin = 0.5983

Либо вычислите взвешенную границу классификации с помощью весов наблюдения.

weightedMeanMargin = edge(Mdl,adulttest,"salary", ...
    "Weight","fnlwgt")
weightedMeanMargin = 0.6072

Визуализируйте прогнозируемые метки и оценки классификации с помощью графиков рассеяния, на которых каждая точка соответствует наблюдению. Используйте прогнозируемые метки для задания цвета точек, а максимальные оценки - для задания прозрачности точек. Точки с меньшей прозрачностью маркируются с большей уверенностью.

Сначала найдите максимальный показатель классификации для каждого наблюдения тестового набора.

maxScores = max(Scores,[],2);

Создайте график разброса, сравнивая максимальные баллы по количеству рабочих часов в неделю и уровню образования. Поскольку образовательная переменная категорична, случайное дрожание (или пробел) точек вдоль измерения y.

Измените карту цветов таким образом, чтобы максимальные баллы, соответствующие окладам, которые меньше или равны 50 000 долл. США в год, отображались синим цветом, а максимальные баллы, соответствующие окладам, превышающим 50 000 долл. США в год, - красным цветом.

scatter(adulttest.hours_per_week,adulttest.education,[],labels, ...
    "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ...
    "YJitter","rand");
xlabel("Number of Work Hours Per Week")
ylabel("Education")

Mdl.ClassNames
ans = 2×1 categorical
     <=50K 
     >50K 

colors = lines(2)
colors = 2×3

         0    0.4470    0.7410
    0.8500    0.3250    0.0980

colormap(colors);

Цвета на сюжете рассеяния указывают на то, что в целом нейросеть прогнозирует, что у людей с более низким уровнем образования (12-й класс или ниже) зарплата меньше или равна $50 000 в год. Прозрачность некоторых точек в нижнем правом углу сюжета указывает на то, что модель менее уверена в этом прогнозе для людей, которые работают много часов в неделю (60 часов и более).

См. также

| | | | | | |