Оцените эффективность классификатора нейронной сети

Создайте классификатор нейронной сети прямого распространения с полносвязными слоями с помощью fitcnet. Используйте данные о валидации для ранней остановки учебного процесса, чтобы предотвратить сверхподбор кривой модели. Затем используйте объектные функции классификатора, чтобы оценить эффективность модели на тестовых данных.

Загрузите и предварительно обработайте выборочные данные

Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat. Набор данных состоит из демографической информации из Бюро переписи США, которое можно использовать, чтобы предсказать, передает ли индивидуум 50 000$ в год.

Загрузите выборочные данные census1994, который содержит обучающие данные adultdata и тестовые данные adulttest. Предварительно просмотрите первые несколько строк обучающего набора данных.

load census1994
head(adultdata)
ans=8×15 table
    age       workClass          fnlwgt      education    education_num       marital_status           occupation        relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     State-gov                77516    Bachelors         13          Never-married            Adm-clerical         Not-in-family    White    Male          2174             0                40          United-States     <=50K 
    50     Self-emp-not-inc         83311    Bachelors         13          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                13          United-States     <=50K 
    38     Private             2.1565e+05    HS-grad            9          Divorced                 Handlers-cleaners    Not-in-family    White    Male             0             0                40          United-States     <=50K 
    53     Private             2.3472e+05    11th               7          Married-civ-spouse       Handlers-cleaners    Husband          Black    Male             0             0                40          United-States     <=50K 
    28     Private             3.3841e+05    Bachelors         13          Married-civ-spouse       Prof-specialty       Wife             Black    Female           0             0                40          Cuba              <=50K 
    37     Private             2.8458e+05    Masters           14          Married-civ-spouse       Exec-managerial      Wife             White    Female           0             0                40          United-States     <=50K 
    49     Private             1.6019e+05    9th                5          Married-spouse-absent    Other-service        Not-in-family    Black    Female           0             0                16          Jamaica           <=50K 
    52     Self-emp-not-inc    2.0964e+05    HS-grad            9          Married-civ-spouse       Exec-managerial      Husband          White    Male             0             0                45          United-States     >50K  

Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец, salary, показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.

Объедините education_num и education переменные и в обучении и в тестовых данных, чтобы создать одну упорядоченную категориальную переменную, которая показывает высший уровень образования человек, достигли.

edOrder = unique(adultdata.education_num,"stable");
edCats = unique(adultdata.education,"stable");
[~,edIdx] = sort(edOrder);

adultdata.education = categorical(adultdata.education, ...
    edCats(edIdx),"Ordinal",true);
adultdata.education_num = [];

adulttest.education = categorical(adulttest.education, ...
    edCats(edIdx),"Ordinal",true);
adulttest.education_num = [];

Обучающие данные раздела

Разделите обучающие данные далее с помощью стратифицированного раздела затяжки. Создайте отдельный набор данных валидации, чтобы остановить учебный процесс модели рано. Зарезервируйте приблизительно 30% наблюдений для набора данных валидации и используйте остальную часть наблюдений, чтобы обучить классификатор нейронной сети.

rng("default") % For reproducibility of the partition
c = cvpartition(adultdata.salary,"Holdout",0.30);
trainingIndices = training(c);
validationIndices = test(c);
tblTrain = adultdata(trainingIndices,:);
tblValidation = adultdata(validationIndices,:);

Обучите нейронную сеть

Обучите классификатор нейронной сети при помощи набора обучающих данных. Задайте salary столбец tblTrain как ответ и fnlwgt столбец как веса наблюдения, и стандартизирует числовые предикторы. Оцените модель в каждой итерации при помощи набора валидации. Задайте, чтобы отобразить учебную информацию в каждой итерации при помощи Verbose аргумент значения имени. По умолчанию учебный процесс заканчивается рано, если потеря перекрестной энтропии валидации больше или равна минимальной потере перекрестной энтропии валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить число раз, потере валидации позволяют быть больше или быть равной минимуму, задать ValidationPatience аргумент значения имени.

Mdl = fitcnet(tblTrain,"salary","Weights","fnlwgt", ...
    "Standardize",true,"ValidationData",tblValidation, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|    0.297812|    0.078920|    0.703981|    0.012127|    0.296816|           0|
|           2|    0.281110|    0.054594|    0.149850|    0.012486|    0.280132|           0|
|           3|    0.252648|    0.062041|    1.004181|    0.011339|    0.247863|           0|
|           4|    0.211868|    0.023567|    0.267214|    0.011319|    0.208988|           0|
|           5|    0.207039|    0.057528|    0.320942|    0.010288|    0.206781|           0|
|           6|    0.196838|    0.022492|    0.089842|    0.011197|    0.195583|           0|
|           7|    0.186133|    0.025551|    0.295975|    0.010426|    0.184900|           0|
|           8|    0.178779|    0.023714|    0.244525|    0.010237|    0.179370|           0|
|           9|    0.174531|    0.027149|    0.306182|    0.012911|    0.178175|           0|
|          10|    0.173217|    0.013365|    0.037475|    0.011084|    0.176371|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    0.168160|    0.016506|    0.307350|    0.010016|    0.170415|           0|
|          12|    0.164460|    0.025136|    0.473227|    0.011512|    0.165902|           0|
|          13|    0.162895|    0.014983|    0.473367|    0.010969|    0.164582|           0|
|          14|    0.160791|    0.005187|    0.113760|    0.011720|    0.162947|           0|
|          15|    0.159742|    0.004035|    0.138748|    0.010260|    0.162074|           0|
|          16|    0.159290|    0.005774|    0.108266|    0.010400|    0.161728|           0|
|          17|    0.158593|    0.004977|    0.152142|    0.010603|    0.161272|           0|
|          18|    0.157437|    0.003660|    0.193303|    0.010510|    0.160299|           0|
|          19|    0.156642|    0.007722|    0.430859|    0.010069|    0.160145|           0|
|          20|    0.155954|    0.001908|    0.121039|    0.010041|    0.159066|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    0.155824|    0.001645|    0.025159|    0.010557|    0.158992|           0|
|          22|    0.155486|    0.003232|    0.119915|    0.010829|    0.158731|           0|
|          23|    0.155398|    0.006845|    0.083105|    0.031846|    0.158963|           1|
|          24|    0.155261|    0.004374|    0.065660|    0.010762|    0.158816|           2|
|          25|    0.154955|    0.002505|    0.264106|    0.011437|    0.158687|           0|
|          26|    0.154799|    0.002183|    0.040876|    0.010903|    0.158538|           0|
|          27|    0.154466|    0.002881|    0.219478|    0.012409|    0.158033|           0|
|          28|    0.154250|    0.002724|    0.196190|    0.012062|    0.157980|           0|
|          29|    0.153918|    0.002189|    0.135392|    0.009862|    0.157605|           0|
|          30|    0.153707|    0.001449|    0.111574|    0.010851|    0.157456|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          31|    0.153214|    0.002050|    0.528628|    0.010212|    0.157379|           0|
|          32|    0.152671|    0.002542|    0.488640|    0.010013|    0.156687|           0|
|          33|    0.152303|    0.004554|    0.223206|    0.010334|    0.156778|           1|
|          34|    0.152093|    0.002856|    0.121284|    0.010188|    0.156639|           0|
|          35|    0.151871|    0.003145|    0.135909|    0.010108|    0.156446|           0|
|          36|    0.151741|    0.001441|    0.225342|    0.010452|    0.156517|           1|
|          37|    0.151626|    0.002500|    0.396782|    0.010487|    0.156429|           0|
|          38|    0.151488|    0.005053|    0.148248|    0.010312|    0.156201|           0|
|          39|    0.151250|    0.002552|    0.110278|    0.009895|    0.155968|           0|
|          40|    0.151013|    0.002506|    0.123906|    0.010837|    0.155812|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          41|    0.150821|    0.002536|    0.109515|    0.010627|    0.155742|           0|
|          42|    0.150509|    0.001418|    0.223296|    0.010561|    0.155648|           0|
|          43|    0.150340|    0.003437|    0.185351|    0.010131|    0.155435|           0|
|          44|    0.150280|    0.004746|    0.115075|    0.010432|    0.155797|           1|
|          45|    0.150194|    0.002758|    0.082143|    0.010068|    0.155575|           2|
|          46|    0.150061|    0.001122|    0.094288|    0.011405|    0.155334|           0|
|          47|    0.149978|    0.001259|    0.127677|    0.010628|    0.155305|           0|
|          48|    0.149879|    0.001523|    0.107816|    0.011331|    0.155044|           0|
|          49|    0.149749|    0.004572|    0.156869|    0.009953|    0.155043|           0|
|          50|    0.149617|    0.000965|    0.186502|    0.009702|    0.155106|           1|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          51|    0.149579|    0.001302|    0.062687|    0.010743|    0.155160|           2|
|          52|    0.149519|    0.001407|    0.086000|    0.010335|    0.155216|           3|
|          53|    0.149405|    0.001243|    0.147530|    0.009753|    0.155309|           4|
|          54|    0.149203|    0.002749|    0.186920|    0.010267|    0.155337|           5|
|          55|    0.149040|    0.001217|    0.310011|    0.012444|    0.155460|           6|
|==========================================================================================|

Используйте информацию в TrainingHistory свойство объекта Mdl проверять итерацию, которая соответствует минимальной потере перекрестной энтропии валидации. Финал возвратил модель Mdl модель, обученная в этой итерации.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 49

Оцените эффективность набора тестов

Оцените эффективность обученного классификатора Mdl на наборе тестов adulttest при помощи predict, loss, margin, и edge функции объекта.

Найдите предсказанные метки и классификационные оценки для наблюдений в наборе тестов.

[labels,Scores] = predict(Mdl,adulttest);

Создайте матрицу беспорядка из результатов набора тестов. Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональными элементами являются экземпляры неправильно классифицированных наблюдений.

confusionchart(adulttest.salary,labels)

Вычислите точность классификации наборов тестов.

error = loss(Mdl,adulttest,"salary");
accuracy = (1-error)*100
accuracy = 85.1306

Классификатор нейронной сети правильно классифицирует приблизительно 85% наблюдений набора тестов.

Вычислите поля классификации наборов тестов для обученной нейронной сети. Отобразите гистограмму полей.

Поля классификации являются различием между классификационной оценкой для истинного класса и классификационной оценкой для ложного класса. Поскольку классификаторы нейронной сети возвращают баллы, которые являются апостериорными вероятностями, поля классификации близко к 1 указывают на уверенные классификации, и отрицательные граничные значения указывают на misclassifications.

m = margin(Mdl,adulttest,"salary");
histogram(m)

Используйте ребро классификации или среднее значение полей классификации, чтобы оценить общую производительность классификатора.

meanMargin = edge(Mdl,adulttest,"salary")
meanMargin = 0.5983

В качестве альтернативы вычислите взвешенное ребро классификации при помощи весов наблюдения.

weightedMeanMargin = edge(Mdl,adulttest,"salary", ...
    "Weight","fnlwgt")
weightedMeanMargin = 0.6072

Визуализируйте предсказанные метки и классификационные оценки с помощью графиков рассеивания, в которых каждая точка соответствует наблюдению. Используйте предсказанные метки, чтобы выбрать цвет точек и использовать максимальные баллы, чтобы установить прозрачность точек. Точки с меньшей прозрачностью помечены большим доверием.

Во-первых, найдите максимальную классификационную оценку для каждого наблюдения набора тестов.

maxScores = max(Scores,[],2);

Создайте график рассеивания, сравнивающий максимальные баллы через номер часов работы в неделю и уровень образования. Поскольку образовательная переменная является категориальной, случайным образом дрожите (или растяните), точки по y-измерению.

Измените палитру так, чтобы максимальные баллы, соответствующие зарплатам, которые меньше чем или равны 50 000$ в год, появились столь же синие, и максимальные баллы, соответствующие зарплатам, больше, чем 50 000$ в год появляются как красный.

scatter(adulttest.hours_per_week,adulttest.education,[],labels, ...
    "filled","MarkerFaceAlpha","flat","AlphaData",maxScores, ...
    "YJitter","rand");
xlabel("Number of Work Hours Per Week")
ylabel("Education")

Mdl.ClassNames
ans = 2×1 categorical
     <=50K 
     >50K 

colors = lines(2)
colors = 2×3

         0    0.4470    0.7410
    0.8500    0.3250    0.0980

colormap(colors);

Цвета в графике рассеивания указывают, что в целом нейронная сеть предсказывает, что у людей с более низкими уровнями образования (12-й класс или ниже) есть зарплаты, меньше чем или равные 50 000$ в год. Прозрачность некоторых точек в нижнем правом углу графика указывает, что модель менее уверена в этом предсказании для людей, которые работают много часов в неделю (60 часов или больше).

Смотрите также

| | | | | | |