В этом примере показано, как использовать fitcauto
автоматически попробовать выбор типов модели классификации с различными гиперзначениями параметров, учитывая учебный предиктор и данные об ответе. Функция использует Байесовую оптимизацию, чтобы выбрать модели и их гиперзначения параметров, и вычисляет ошибку классификации перекрестных проверок для каждой модели. После того, как оптимизация завершена, fitcauto
возвращает модель, обученную на целом наборе данных, который, как ожидают, лучше всего классифицирует новые данные. Проверяйте производительность модели на тестовых данных.
Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat
. Набор данных состоит из демографической информации из Бюро переписи США, которое может использоваться, чтобы предсказать, передает ли индивидуум 50 000$ в год.
Загрузите выборочные данные census1994
, который содержит обучающие данные adultdata
и тестовые данные adulttest
. Предварительно просмотрите первые несколько строк обучающего набора данных.
load census1994
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец salary
показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.
Тестовые данные adulttest
содержит два ненужных пустых класса. Удалите их при помощи removecats
функция.
adulttest.salary = removecats(adulttest.salary);
Используйте fitcauto
автоматически найти соответствующий классификатор для данных в adultdata
. Установите веса наблюдения и задайте, чтобы запустить Байесовую оптимизацию параллельно, которая требует Parallel Computing Toolbox™. Из-за невоспроизводимости синхронизации параллели, параллельная Байесова оптимизация не обязательно приводит к восстанавливаемым результатам.
Из-за сложности оптимизации этот процесс может занять время, специально для больших наборов данных. Для этого набора данных ожидайте, что оптимизация запустится параллельно на порядке нескольких минут.
options = struct('UseParallel',true); [mdl,results] = fitcauto(adultdata,'salary','Weights','fnlwgt', ... 'HyperparameterOptimizationOptions',options);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6). Copying objective function to workers... Done copying objective function to workers.
|==============================================================================================================================| | Iter | Active | Eval | Objective | Objective | BestSoFar | BestSoFar | Learner | Hyperparameter: Value | | | workers | result | | runtime | (observed) | (estim.) | | | |==============================================================================================================================| | 1 | 6 | Best | 0.23856 | 2.8026 | 0.23856 | 0.23856 | tree | MinLeafSize: 10889 |
| 2 | 5 | Best | 0.14224 | 4.5188 | 0.14224 | 0.18415 | tree | MinLeafSize: 32 | | 3 | 5 | Accept | 0.17296 | 7.8309 | 0.14224 | 0.18415 | tree | MinLeafSize: 2 |
| 4 | 5 | Accept | 0.18228 | 0.78592 | 0.14224 | 0.18684 | tree | MinLeafSize: 4990 |
| 5 | 6 | Accept | 0.14922 | 51.758 | 0.14224 | 0.18684 | nb | DistributionNames: kernel | | | | | | | | | | Width: 0.41891 |
| 6 | 6 | Accept | 0.16237 | 0.91128 | 0.14224 | 0.15553 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 7 | 6 | Accept | 0.16237 | 0.60173 | 0.14224 | 0.15825 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 8 | 6 | Accept | 0.23856 | 51.033 | 0.14224 | 0.15825 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 225 | | | | | | | | | | MinLeafSize: 3144 |
| 9 | 6 | Accept | 0.23856 | 51.936 | 0.14224 | 0.15825 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 225 | | | | | | | | | | MinLeafSize: 3144 |
| 10 | 5 | Accept | 0.15776 | 58.772 | 0.14154 | 0.15825 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 227 | | | | | | | | | | MinLeafSize: 1 | | 11 | 5 | Best | 0.14154 | 3.8388 | 0.14154 | 0.15825 | tree | MinLeafSize: 40 |
| 12 | 5 | Accept | 0.14181 | 2.3888 | 0.14154 | 0.15825 | tree | MinLeafSize: 103 |
| 13 | 6 | Accept | 0.18202 | 79.14 | 0.14154 | 0.15825 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 239 | | | | | | | | | | MinLeafSize: 161 |
| 14 | 6 | Accept | 0.14574 | 2.2797 | 0.14154 | 0.15825 | tree | MinLeafSize: 229 |
| 15 | 6 | Accept | 0.16049 | 103.63 | 0.14154 | 0.15842 | nb | DistributionNames: kernel | | | | | | | | | | Width: 9.0144 |
| 16 | 6 | Accept | 0.14921 | 50.031 | 0.14154 | 0.15653 | nb | DistributionNames: kernel | | | | | | | | | | Width: 0.40347 |
| 17 | 6 | Accept | 0.14921 | 50.812 | 0.14154 | 0.15492 | nb | DistributionNames: kernel | | | | | | | | | | Width: 0.40347 |
| 18 | 6 | Accept | 0.17777 | 64.079 | 0.14154 | 0.15492 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 207 | | | | | | | | | | MinLeafSize: 1259 |
| 19 | 6 | Accept | 0.1436 | 3.1919 | 0.14154 | 0.15492 | tree | MinLeafSize: 23 |
| 20 | 6 | Accept | 0.1562 | 73.923 | 0.14154 | 0.15492 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 288 | | | | | | | | | | MinLeafSize: 1774 |
|==============================================================================================================================| | Iter | Active | Eval | Objective | Objective | BestSoFar | BestSoFar | Learner | Hyperparameter: Value | | | workers | result | | runtime | (observed) | (estim.) | | | |==============================================================================================================================| | 21 | 6 | Accept | 0.15849 | 57.26 | 0.14154 | 0.15492 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 212 | | | | | | | | | | MinLeafSize: 376 |
| 22 | 6 | Accept | 0.23856 | 0.43942 | 0.14154 | 0.15492 | tree | MinLeafSize: 8801 |
| 23 | 6 | Best | 0.14084 | 3.3112 | 0.14084 | 0.15492 | tree | MinLeafSize: 39 |
| 24 | 6 | Accept | 0.16237 | 0.51626 | 0.14084 | 0.15679 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 25 | 5 | Accept | 0.19575 | 104.45 | 0.14084 | 0.162 | nb | DistributionNames: kernel | | | | | | | | | | Width: 8643.5 | | 26 | 5 | Accept | 0.16237 | 0.41186 | 0.14084 | 0.162 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 27 | 5 | Accept | 0.16237 | 0.36151 | 0.14084 | 0.16142 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 28 | 5 | Accept | 0.23856 | 46.881 | 0.14084 | 0.16142 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 250 | | | | | | | | | | MinLeafSize: 6630 |
| 29 | 6 | Accept | 0.18915 | 103.22 | 0.14084 | 0.1633 | nb | DistributionNames: kernel | | | | | | | | | | Width: 3883.6 |
| 30 | 6 | Accept | 0.15658 | 64.045 | 0.14084 | 0.1633 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 252 | | | | | | | | | | MinLeafSize: 4 |
| 31 | 5 | Accept | 0.1993 | 101.82 | 0.14084 | 0.1633 | nb | DistributionNames: kernel | | | | | | | | | | Width: 94267 | | 32 | 5 | Accept | 0.15658 | 64.865 | 0.14084 | 0.1633 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 252 | | | | | | | | | | MinLeafSize: 4 |
| 33 | 6 | Accept | 0.18163 | 72.481 | 0.14084 | 0.1633 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 208 | | | | | | | | | | MinLeafSize: 1 |
| 34 | 4 | Accept | 0.1769 | 75.521 | 0.14084 | 0.14895 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 223 | | | | | | | | | | MinLeafSize: 743 | | 35 | 4 | Accept | 0.23856 | 0.63417 | 0.14084 | 0.14895 | tree | MinLeafSize: 9982 | | 36 | 4 | Accept | 0.23856 | 0.43146 | 0.14084 | 0.14895 | tree | MinLeafSize: 9982 |
| 37 | 6 | Accept | 0.23856 | 41.173 | 0.14084 | 0.14895 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 235 | | | | | | | | | | MinLeafSize: 6922 |
| 38 | 6 | Accept | 0.16851 | 1.435 | 0.14084 | 0.15129 | tree | MinLeafSize: 791 |
| 39 | 6 | Accept | 0.23856 | 41.277 | 0.14084 | 0.15129 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 220 | | | | | | | | | | MinLeafSize: 4265 |
| 40 | 6 | Accept | 0.14119 | 2.2295 | 0.14084 | 0.14988 | tree | MinLeafSize: 81 |
|==============================================================================================================================| | Iter | Active | Eval | Objective | Objective | BestSoFar | BestSoFar | Learner | Hyperparameter: Value | | | workers | result | | runtime | (observed) | (estim.) | | | |==============================================================================================================================| | 41 | 6 | Accept | 0.14498 | 2.0529 | 0.14084 | 0.1464 | tree | MinLeafSize: 150 |
| 42 | 6 | Accept | 0.14478 | 1.9602 | 0.14084 | 0.14668 | tree | MinLeafSize: 147 |
| 43 | 6 | Accept | 0.15619 | 67.066 | 0.14084 | 0.14668 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 272 | | | | | | | | | | MinLeafSize: 997 |
| 44 | 4 | Accept | 0.15619 | 67.5 | 0.14084 | 0.14485 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 272 | | | | | | | | | | MinLeafSize: 997 | | 45 | 4 | Accept | 0.15619 | 67.628 | 0.14084 | 0.14485 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 272 | | | | | | | | | | MinLeafSize: 997 | | 46 | 4 | Accept | 0.14398 | 1.8842 | 0.14084 | 0.14485 | tree | MinLeafSize: 158 |
| 47 | 4 | Accept | 0.14536 | 1.6207 | 0.14084 | 0.14544 | tree | MinLeafSize: 234 |
| 48 | 4 | Accept | 0.14258 | 1.7687 | 0.14084 | 0.14503 | tree | MinLeafSize: 191 |
| 49 | 4 | Accept | 0.15204 | 1.439 | 0.14084 | 0.14459 | tree | MinLeafSize: 392 |
| 50 | 3 | Accept | 0.18177 | 94.382 | 0.14084 | 0.14455 | nb | DistributionNames: kernel | | | | | | | | | | Width: 489.63 | | 51 | 3 | Accept | 0.14258 | 1.6702 | 0.14084 | 0.14455 | tree | MinLeafSize: 191 |
| 52 | 5 | Accept | 0.14619 | 1.4677 | 0.14084 | 0.14455 | tree | MinLeafSize: 259 | | 53 | 5 | Accept | 0.15612 | 65.7 | 0.14084 | 0.14455 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 287 | | | | | | | | | | MinLeafSize: 30 |
| 54 | 3 | Accept | 0.14415 | 1.6508 | 0.14084 | 0.1439 | tree | MinLeafSize: 210 | | 55 | 3 | Accept | 0.16237 | 0.34225 | 0.14084 | 0.1439 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN | | 56 | 3 | Accept | 0.16614 | 4.2516 | 0.14084 | 0.1439 | tree | MinLeafSize: 3 |
| 57 | 6 | Accept | 0.14429 | 1.6329 | 0.14084 | 0.14371 | tree | MinLeafSize: 153 |
| 58 | 6 | Accept | 0.16237 | 0.5541 | 0.14084 | 0.14371 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 59 | 6 | Accept | 0.15964 | 1.5347 | 0.14084 | 0.14359 | tree | MinLeafSize: 523 |
| 60 | 6 | Accept | 0.14131 | 2.1059 | 0.14084 | 0.14355 | tree | MinLeafSize: 82 |
|==============================================================================================================================| | Iter | Active | Eval | Objective | Objective | BestSoFar | BestSoFar | Learner | Hyperparameter: Value | | | workers | result | | runtime | (observed) | (estim.) | | | |==============================================================================================================================| | 61 | 6 | Accept | 0.17797 | 0.86666 | 0.14084 | 0.14367 | tree | MinLeafSize: 2174 |
| 62 | 6 | Accept | 0.15641 | 66.356 | 0.14084 | 0.14367 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 266 | | | | | | | | | | MinLeafSize: 1999 |
| 63 | 6 | Accept | 0.15624 | 66.722 | 0.14084 | 0.14367 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 274 | | | | | | | | | | MinLeafSize: 2310 |
| 64 | 6 | Accept | 0.19608 | 99.5 | 0.14084 | 0.14367 | nb | DistributionNames: kernel | | | | | | | | | | Width: 9309.1 |
| 65 | 6 | Accept | 0.18014 | 96.152 | 0.14084 | 0.14367 | nb | DistributionNames: kernel | | | | | | | | | | Width: 250.76 |
| 66 | 5 | Accept | 0.18177 | 98.355 | 0.14084 | 0.14367 | nb | DistributionNames: kernel | | | | | | | | | | Width: 525.18 | | 67 | 5 | Accept | 0.15625 | 67.832 | 0.14084 | 0.14367 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 276 | | | | | | | | | | MinLeafSize: 1823 |
| 68 | 6 | Accept | 0.14279 | 2.4652 | 0.14084 | 0.14412 | tree | MinLeafSize: 57 |
| 69 | 6 | Accept | 0.14279 | 2.4317 | 0.14084 | 0.14243 | tree | MinLeafSize: 57 |
| 70 | 6 | Accept | 0.15768 | 1.4289 | 0.14084 | 0.14233 | tree | MinLeafSize: 449 |
| 71 | 6 | Accept | 0.17003 | 1.1501 | 0.14084 | 0.14227 | tree | MinLeafSize: 979 |
| 72 | 6 | Accept | 0.15644 | 66.103 | 0.14084 | 0.14227 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 270 | | | | | | | | | | MinLeafSize: 2549 |
| 73 | 6 | Accept | 0.14122 | 3.0678 | 0.14084 | 0.14194 | tree | MinLeafSize: 46 |
| 74 | 6 | Accept | 0.15605 | 74.508 | 0.14084 | 0.14194 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 298 | | | | | | | | | | MinLeafSize: 2738 |
| 75 | 5 | Accept | 0.1562 | 71.717 | 0.14084 | 0.14215 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 288 | | | | | | | | | | MinLeafSize: 2376 | | 76 | 5 | Accept | 0.14122 | 2.6641 | 0.14084 | 0.14215 | tree | MinLeafSize: 46 |
| 77 | 5 | Accept | 0.15619 | 68.39 | 0.14084 | 0.14215 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 272 | | | | | | | | | | MinLeafSize: 2521 |
| 78 | 6 | Accept | 0.15612 | 73.254 | 0.14084 | 0.14215 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 287 | | | | | | | | | | MinLeafSize: 2597 |
| 79 | 6 | Accept | 0.1422 | 2.6309 | 0.14084 | 0.14188 | tree | MinLeafSize: 52 |
| 80 | 6 | Accept | 0.14125 | 2.6448 | 0.14084 | 0.1419 | tree | MinLeafSize: 44 |
|==============================================================================================================================| | Iter | Active | Eval | Objective | Objective | BestSoFar | BestSoFar | Learner | Hyperparameter: Value | | | workers | result | | runtime | (observed) | (estim.) | | | |==============================================================================================================================| | 81 | 6 | Accept | 0.1411 | 2.7027 | 0.14084 | 0.14192 | tree | MinLeafSize: 42 |
| 82 | 6 | Accept | 0.14103 | 2.7308 | 0.14084 | 0.14193 | tree | MinLeafSize: 41 |
| 83 | 6 | Accept | 0.1561 | 75.741 | 0.14084 | 0.14193 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 299 | | | | | | | | | | MinLeafSize: 2618 |
| 84 | 6 | Accept | 0.14154 | 3.1338 | 0.14084 | 0.14151 | tree | MinLeafSize: 40 |
| 85 | 6 | Accept | 0.15605 | 75.343 | 0.14084 | 0.14151 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 298 | | | | | | | | | | MinLeafSize: 2825 |
| 86 | 5 | Accept | 0.15626 | 69.271 | 0.14084 | 0.14138 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 277 | | | | | | | | | | MinLeafSize: 2944 | | 87 | 5 | Accept | 0.14103 | 2.7855 | 0.14084 | 0.14138 | tree | MinLeafSize: 41 |
| 88 | 5 | Accept | 0.15622 | 69.815 | 0.14084 | 0.14138 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 280 | | | | | | | | | | MinLeafSize: 2868 |
| 89 | 5 | Accept | 0.15611 | 71.872 | 0.14084 | 0.14138 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 290 | | | | | | | | | | MinLeafSize: 2869 |
| 90 | 6 | Accept | 0.15618 | 68.306 | 0.14084 | 0.14138 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 282 | | | | | | | | | | MinLeafSize: 2862 |
__________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 90 reached. Total function evaluations: 90 Total elapsed time: 672.4994 seconds. Total objective function evaluation time: 3145.1896 Best observed feasible point is a tree model with: MinLeafSize: 39 Observed objective function value = 0.14084 Estimated objective function value = 0.14539 Function evaluation time = 3.3112 Best estimated feasible point (according to models) is a tree model with: MinLeafSize: 44 Estimated objective function value = 0.14138 Estimated function evaluation time = 2.666
Итоговая модель возвращена fitcauto
соответствует лучшей предполагаемой допустимой точке. Прежде, чем возвратить модель, функция переобучает его с помощью целых обучающих данных (adultdata
), перечисленный Learner
(или модель) тип и отображенные гиперзначения параметров.
Оцените производительность возвращенной модели mdl
на наборе тестов adulttest
при помощи матрицы беспорядка и кривой рабочей характеристики получателя (ROC).
Найдите предсказанные метки и значения баллов для набора тестов.
[labels,scores] = predict(mdl,adulttest);
Создайте матрицу беспорядка из результатов набора тестов. Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональными элементами являются экземпляры неправильно классифицированных наблюдений.
confusionchart(adulttest.salary,labels)
Вычислите точность классификации наборов тестов. accuracy
процент правильно классифицированных наблюдений набора тестов.
accuracy = (1-loss(mdl,adulttest,'salary'))*100
accuracy = 85.5621
Построить кривую ROC для значений баллов, соответствующих метке '<=50K'
, найдите столбец scores
это соответствует той метке. Порядок следования столбцов scores
совпадает с порядком классов в обученной модели.
mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
Поскольку '<=50K'
перечислен сначала, первый столбец scores
соответствует той метке.
Постройте кривую ROC и вычислите область под кривой (AUC). Кривая ROC показывает истинный положительный уровень по сравнению с ложным положительным уровнем для различных порогов классификатора выход. Для совершенного классификатора, истинный положительный уровень которого всегда 1 независимо от порога, AUC = 1. Для бинарного классификатора, который случайным образом присваивает наблюдения классам, AUC = 0.5. Большое значение AUC (близко к 1) указывает на хорошую производительность классификатора.
[X,Y,~,AUC] = perfcurve(adulttest.salary,scores(:,1),'<=50K'); plot(X,Y) title('ROC Curve') xlabel('False Positive Rate') ylabel('True Positive Rate')
AUC
AUC = 0.8859
На основе точности и значений AUC, классификатор выполняет хорошо на тестовых данных.
BayesianOptimization
| confusionchart
| fitcauto
| perfcurve