This example shows how to use fitcauto to automatically try a selection of classification model types with different hyperparameter values, given training predictor and response data. The function uses Bayesian optimization to select models and their hyperparameter values, and computes the cross-validation classification error for each model. After the optimization is complete, fitcauto returns the model, trained on the entire data set, that is expected to best classify new data. Check the performance of the model on test data.
This example uses the 1994 census data stored in census1994.mat. The data set consists of demographic information from the US Census Bureau that can be used to predict whether an individual makes over $50,000 per year.
Load the sample data census1994, which contains the training data adultdata and the test data adulttest. Preview the first few rows of the training data set.
load census1994
head(adultdata)
ans=8×15 table
    age    workClass           fnlwgt        education    education_num    marital_status           occupation           relationship     race     sex       capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______
    39     State-gov           77516         Bachelors    13               Never-married            Adm-clerical         Not-in-family    White    Male      2174            0               40                United-States     <=50K
    50     Self-emp-not-inc    83311         Bachelors    13               Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               13                United-States     <=50K
    38     Private             2.1565e+05    HS-grad      9                Divorced                 Handlers-cleaners    Not-in-family    White    Male      0               0               40                United-States     <=50K
    53     Private             2.3472e+05    11th         7                Married-civ-spouse       Handlers-cleaners    Husband          Black    Male      0               0               40                United-States     <=50K
    28     Private             3.3841e+05    Bachelors    13               Married-civ-spouse       Prof-specialty       Wife             Black    Female    0               0               40                Cuba              <=50K
    37     Private             2.8458e+05    Masters      14               Married-civ-spouse       Exec-managerial      Wife             White    Female    0               0               40                United-States     <=50K
    49     Private             1.6019e+05    9th          5                Married-spouse-absent    Other-service        Not-in-family    Black    Female    0               0               16                Jamaica           <=50K
    52     Self-emp-not-inc    2.0964e+05    HS-grad      9                Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               45                United-States     >50K
Each row contains the demographic information for one adult. The last column, salary, shows whether a person has a salary less than or equal to $50,000 per year or greater than $50,000 per year.
Use fitcauto to automatically find an appropriate classifier for the data in adultdata. Specify the observation weights, and run the Bayesian optimization in parallel, which requires Parallel Computing Toolbox™. Because the timing of parallel runs is not reproducible, parallel Bayesian optimization does not necessarily yield reproducible results.
Because of the complexity of the optimization, this process can take some time, especially for larger data sets. By default, fitcauto provides a plot of the optimization and an iterative display of the optimization results. For details on how to interpret these results, see Verbose Display.
options = struct('UseParallel',true);
[mdl,results] = fitcauto(adultdata,'salary','Weights','fnlwgt', ...
    'HyperparameterOptimizationOptions',options);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
Copying objective function to workers...
Done copying objective function to workers.
Learner types to explore: ensemble, nb, tree
Total iterations (MaxObjectiveEvaluations): 90
Total time (MaxTime): Inf
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 1 | 6 | Best | 0.16287 | 4.3468 | 0.16287 | 0.16287 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 2 | 5 | Accept | 0.14389 | 6.1049 | 0.14162 | 0.14287 | tree | MinLeafSize: 21 |
| 3 | 5 | Best | 0.14162 | 5.6195 | 0.14162 | 0.14287 | tree | MinLeafSize: 50 |
| 4 | 6 | Accept | 0.15626 | 74.156 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 283 |
| | | | | | | | | MinLeafSize: 7330 |
| 5 | 6 | Accept | 0.15603 | 77.293 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 295 |
| | | | | | | | | MinLeafSize: 3 |
| 6 | 6 | Accept | 0.16027 | 5.6224 | 0.14162 | 0.14842 | tree | MinLeafSize: 5 |
| 7 | 6 | Accept | 0.17343 | 8.6209 | 0.14162 | 0.15576 | tree | MinLeafSize: 2 |
| 8 | 6 | Accept | 0.15103 | 4.8867 | 0.14162 | 0.15392 | tree | MinLeafSize: 8 |
| 9 | 6 | Accept | 0.17642 | 1.1808 | 0.14162 | 0.15449 | tree | MinLeafSize: 1663 |
| 10 | 6 | Accept | 0.15927 | 5.0734 | 0.14162 | 0.15343 | tree | MinLeafSize: 6 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 11 | 6 | Accept | 0.17009 | 1.6504 | 0.14162 | 0.15533 | tree | MinLeafSize: 1272 |
| 12 | 6 | Accept | 0.17869 | 1.0308 | 0.14162 | 0.154 | tree | MinLeafSize: 2744 |
| 13 | 6 | Accept | 0.17961 | 116.64 | 0.14162 | 0.154 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 274.23 |
| 14 | 5 | Accept | 0.15128 | 118.36 | 0.14162 | 0.15383 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 241 |
| | | | | | | | | MinLeafSize: 23 |
| 15 | 5 | Accept | 0.15177 | 115.42 | 0.14162 | 0.15383 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 235 |
| | | | | | | | | MinLeafSize: 40 |
| 16 | 5 | Accept | 0.15116 | 115.49 | 0.14162 | 0.15326 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 235 |
| | | | | | | | | MinLeafSize: 40 |
| 17 | 6 | Accept | 0.14887 | 63.412 | 0.14162 | 0.15326 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 0.56014 |
| 18 | 6 | Accept | 0.17869 | 0.89318 | 0.14162 | 0.15219 | tree | MinLeafSize: 2712 |
| 19 | 6 | Accept | 0.17676 | 59.781 | 0.14162 | 0.15219 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 208 |
| | | | | | | | | MinLeafSize: 4208 |
| 20 | 6 | Accept | 0.15086 | 81.42 | 0.14162 | 0.15219 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 2.4778 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 21 | 6 | Accept | 0.16287 | 0.64656 | 0.14162 | 0.15219 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 22 | 6 | Accept | 0.14943 | 75.578 | 0.14162 | 0.15219 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 1.6195 |
| 23 | 6 | Accept | 0.16287 | 0.49489 | 0.14162 | 0.15219 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 24 | 6 | Accept | 0.14926 | 68.642 | 0.14162 | 0.15219 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 1.2371 |
| 25 | 6 | Accept | 0.16287 | 0.5124 | 0.14162 | 0.15219 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 26 | 6 | Accept | 0.15609 | 58.267 | 0.14162 | 0.15219 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 247 |
| | | | | | | | | MinLeafSize: 1 |
| 27 | 6 | Accept | 0.16287 | 0.93385 | 0.14162 | 0.15219 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 28 | 6 | Accept | 0.15554 | 4.3668 | 0.14162 | 0.15067 | tree | MinLeafSize: 7 |
| 29 | 6 | Accept | 0.15087 | 127.01 | 0.14162 | 0.15067 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 289 |
| | | | | | | | | MinLeafSize: 9 |
| 30 | 6 | Accept | 0.15142 | 127.39 | 0.14162 | 0.15067 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 289 |
| | | | | | | | | MinLeafSize: 9 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 31 | 6 | Accept | 0.14177 | 2.6306 | 0.14162 | 0.14707 | tree | MinLeafSize: 116 |
| 32 | 6 | Accept | 0.16287 | 1.1225 | 0.14162 | 0.14707 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 33 | 6 | Accept | 0.15737 | 56.258 | 0.14162 | 0.14707 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 233 |
| | | | | | | | | MinLeafSize: 5308 |
| 34 | 6 | Accept | 0.15158 | 97.559 | 0.14162 | 0.14707 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 214 |
| | | | | | | | | MinLeafSize: 133 |
| 35 | 6 | Accept | 0.1719 | 96.392 | 0.14162 | 0.14707 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 223 |
| | | | | | | | | MinLeafSize: 1526 |
| 36 | 6 | Accept | 0.16287 | 0.42054 | 0.14162 | 0.14707 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 37 | 6 | Accept | 0.14441 | 3.5932 | 0.14162 | 0.14598 | tree | MinLeafSize: 18 |
| 38 | 6 | Accept | 0.16287 | 0.34693 | 0.14162 | 0.14598 | nb | DistributionNames: normal |
| | | | | | | | | Width: NaN |
| 39 | 6 | Accept | 0.14432 | 3.4661 | 0.14162 | 0.145 | tree | MinLeafSize: 19 |
| 40 | 6 | Accept | 0.14291 | 2.3121 | 0.14162 | 0.14321 | tree | MinLeafSize: 231 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 41 | 6 | Accept | 0.15278 | 96.086 | 0.14162 | 0.14321 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 3.5668 |
| 42 | 6 | Accept | 0.15068 | 1.9847 | 0.14162 | 0.14348 | tree | MinLeafSize: 412 |
| 43 | 6 | Accept | 0.14705 | 2.1122 | 0.14162 | 0.14343 | tree | MinLeafSize: 305 |
| 44 | 6 | Accept | 0.14186 | 2.3835 | 0.14162 | 0.14309 | tree | MinLeafSize: 168 |
| 45 | 6 | Accept | 0.16209 | 1.9821 | 0.14162 | 0.14302 | tree | MinLeafSize: 573 |
| 46 | 5 | Accept | 0.15783 | 53.627 | 0.14135 | 0.14271 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 211 |
| | | | | | | | | MinLeafSize: 125 |
| 47 | 5 | Best | 0.14135 | 3.1329 | 0.14135 | 0.14271 | tree | MinLeafSize: 63 |
| 48 | 4 | Accept | 0.15637 | 63.578 | 0.14135 | 0.14236 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 252 |
| | | | | | | | | MinLeafSize: 485 |
| 49 | 4 | Accept | 0.1448 | 2.1012 | 0.14135 | 0.14236 | tree | MinLeafSize: 263 |
| 50 | 3 | Accept | 0.1513 | 114.35 | 0.14135 | 0.14224 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 253 |
| | | | | | | | | MinLeafSize: 13 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 51 | 3 | Accept | 0.14271 | 2.2737 | 0.14135 | 0.14224 | tree | MinLeafSize: 133 |
| 52 | 6 | Accept | 0.14349 | 1.9707 | 0.14135 | 0.14224 | tree | MinLeafSize: 199 |
| 53 | 3 | Accept | 0.15337 | 1.6887 | 0.14135 | 0.14235 | tree | MinLeafSize: 441 |
| 54 | 3 | Accept | 0.17869 | 1.049 | 0.14135 | 0.14235 | tree | MinLeafSize: 1821 |
| 55 | 3 | Accept | 0.1785 | 0.9639 | 0.14135 | 0.14235 | tree | MinLeafSize: 3523 |
| 56 | 3 | Accept | 0.18062 | 0.63917 | 0.14135 | 0.14235 | tree | MinLeafSize: 4359 |
| 57 | 6 | Accept | 0.14673 | 3.2067 | 0.14135 | 0.14207 | tree | MinLeafSize: 12 |
| 58 | 6 | Accept | 0.14238 | 2.3081 | 0.14135 | 0.14215 | tree | MinLeafSize: 177 |
| 59 | 5 | Accept | 0.16352 | 125.94 | 0.14135 | 0.1419 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 297 |
| | | | | | | | | MinLeafSize: 823 |
| 60 | 5 | Accept | 0.14162 | 2.849 | 0.14135 | 0.1419 | tree | MinLeafSize: 50 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 61 | 5 | Best | 0.14113 | 2.6499 | 0.14113 | 0.14173 | tree | MinLeafSize: 83 |
| 62 | 5 | Accept | 0.14178 | 2.9853 | 0.14113 | 0.14153 | tree | MinLeafSize: 40 |
| 63 | 5 | Accept | 0.14157 | 2.8701 | 0.14113 | 0.14153 | tree | MinLeafSize: 42 |
| 64 | 5 | Accept | 0.15886 | 1.7188 | 0.14113 | 0.14161 | tree | MinLeafSize: 532 |
| 65 | 5 | Accept | 0.14529 | 3.6593 | 0.14113 | 0.14151 | tree | MinLeafSize: 14 |
| 66 | 4 | Accept | 0.23856 | 41.472 | 0.14113 | 0.14151 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 209 |
| | | | | | | | | MinLeafSize: 8676 |
| 67 | 4 | Accept | 0.14702 | 4.0559 | 0.14113 | 0.14151 | tree | MinLeafSize: 10 |
| 68 | 4 | Best | 0.14058 | 2.8472 | 0.14058 | 0.14148 | tree | MinLeafSize: 30 |
| 69 | 4 | Accept | 0.14168 | 2.1868 | 0.14058 | 0.14143 | tree | MinLeafSize: 112 |
| 70 | 4 | Accept | 0.14072 | 2.9698 | 0.14058 | 0.14144 | tree | MinLeafSize: 28 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 71 | 4 | Accept | 0.14117 | 2.8824 | 0.14058 | 0.14114 | tree | MinLeafSize: 29 |
| 72 | 4 | Best | 0.14046 | 2.8853 | 0.14046 | 0.14112 | tree | MinLeafSize: 25 |
| 73 | 4 | Accept | 0.14184 | 2.8532 | 0.14046 | 0.14103 | tree | MinLeafSize: 24 |
| 74 | 4 | Accept | 0.14112 | 2.7998 | 0.14046 | 0.14102 | tree | MinLeafSize: 33 |
| 75 | 4 | Accept | 0.14331 | 3.0835 | 0.14046 | 0.141 | tree | MinLeafSize: 23 |
| 76 | 4 | Accept | 0.14089 | 2.9637 | 0.14046 | 0.14086 | tree | MinLeafSize: 31 |
| 77 | 4 | Accept | 0.14046 | 3.0017 | 0.14046 | 0.14083 | tree | MinLeafSize: 25 |
| 78 | 3 | Accept | 0.15093 | 91.952 | 0.14046 | 0.14085 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 222 |
| | | | | | | | | MinLeafSize: 27 |
| 79 | 3 | Accept | 0.14046 | 2.9993 | 0.14046 | 0.14085 | tree | MinLeafSize: 25 |
| 80 | 6 | Accept | 0.14046 | 2.7739 | 0.14046 | 0.14073 | tree | MinLeafSize: 25 |
|===========================================================================================================================================|
| Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value |
| | workers | result | loss | & validation (sec)| validation loss | validation loss | | |
|===========================================================================================================================================|
| 81 | 2 | Accept | 0.18178 | 101.13 | 0.14046 | 0.14068 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 868.86 |
| 82 | 2 | Accept | 0.14184 | 3.2218 | 0.14046 | 0.14068 | tree | MinLeafSize: 24 |
| 83 | 2 | Accept | 0.17807 | 0.82685 | 0.14046 | 0.14068 | tree | MinLeafSize: 3874 |
| 84 | 2 | Accept | 0.15989 | 1.8729 | 0.14046 | 0.14068 | tree | MinLeafSize: 540 |
| 85 | 2 | Accept | 0.15103 | 3.8835 | 0.14046 | 0.14068 | tree | MinLeafSize: 8 |
| 86 | 6 | Accept | 0.14046 | 2.5909 | 0.14046 | 0.14067 | tree | MinLeafSize: 25 |
| 87 | 6 | Accept | 0.14331 | 3.5433 | 0.14046 | 0.14067 | tree | MinLeafSize: 23 |
| 88 | 6 | Accept | 0.23856 | 47.904 | 0.14046 | 0.14067 | ensemble | Method: Bag |
| | | | | | | | | NumLearningCycles: 258 |
| | | | | | | | | MinLeafSize: 12543 |
| 89 | 6 | Accept | 0.14914 | 59.665 | 0.14046 | 0.14067 | nb | DistributionNames: kernel |
| | | | | | | | | Width: 0.37688 |
| 90 | 6 | Accept | 0.15604 | 68.731 | 0.14046 | 0.14067 | ensemble | Method: LogitBoost |
| | | | | | | | | NumLearningCycles: 262 |
| | | | | | | | | MinLeafSize: 2 |

__________________________________________________________
Optimization completed.
Total iterations: 90
Total elapsed time: 577.1419 seconds
Total time for training and validation: 2558.1542 seconds

Best observed learner is a tree model with:
    MinLeafSize: 25
Observed validation loss: 0.14046
Time for training and validation: 2.8853 seconds

Best estimated learner (returned model) is a tree model with:
    MinLeafSize: 25
Estimated validation loss: 0.14067
Estimated time for training and validation: 2.8824 seconds

Documentation for fitcauto display
The final model returned by fitcauto corresponds to the best estimated learner. Before returning the model, the function retrains it using the entire training data (adultdata), the listed Learner (or model) type, and the displayed hyperparameter values.
Evaluate the performance of the returned model mdl on the test set adulttest by using a confusion matrix and a receiver operating characteristic (ROC) curve.
Find the predicted labels and score values for the test set.
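As a minimal sketch of how the two outputs of fitcauto can be inspected, the snippet below runs a deliberately small search on the built-in fisheriris data and reads the best estimated point from the returned optimization object. The fisheriris data is swapped in only to keep the run short and is not the census data used in this example. The restriction to tree learners, the budget of 10 evaluations, and the names toyMdl, toyResults, and opts are illustrative assumptions.

```matlab
% Small, self-contained search on fisheriris (an illustrative stand-in for adultdata).
rng('default')                      % seed the random number generator for this serial run
load fisheriris                     % provides meas (predictors) and species (labels)

% Keep the search cheap: trees only, 10 evaluations, no plots or iterative display.
opts = struct('MaxObjectiveEvaluations',10,'ShowPlots',false,'Verbose',0);
[toyMdl,toyResults] = fitcauto(meas,species,'Learners','tree', ...
    'HyperparameterOptimizationOptions',opts);

% With the default Bayesian optimizer, the second output is a BayesianOptimization object.
bestEstimated = toyResults.XAtMinEstimatedObjective   % hyperparameters with the lowest estimated loss
estimatedLoss = toyResults.MinEstimatedObjective      % estimated cross-validation loss
observedLoss  = toyResults.MinObjective               % lowest observed cross-validation loss
```

Limiting the learner types and the number of evaluations in this way trades search coverage for run time, which can be useful for a first pass on a large data set.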
[labels,scores] = predict(mdl,adulttest);
Create a confusion matrix from the test set results. The diagonal elements indicate the number of correctly classified instances of a given class. The off-diagonal elements are instances of misclassified observations.
confusionchart(adulttest.salary,labels)
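To make the reading of the chart concrete, the following sketch computes the same kind of matrix numerically with confusionmat on a handful of made-up labels and counts the diagonal and off-diagonal entries. The label values and the variable names are hypothetical and are not taken from adulttest.

```matlab
% Toy labels (hypothetical values, not taken from adulttest).
trueLabels = categorical({'<=50K';'<=50K';'>50K';'>50K';'<=50K';'>50K'});
predLabels = categorical({'<=50K';'>50K';'>50K';'>50K';'<=50K';'<=50K'});

% Rows of C are true classes, columns are predicted classes.
[C,classOrder] = confusionmat(trueLabels,predLabels);

numCorrect   = trace(C);                  % sum of the diagonal elements
numIncorrect = sum(C(:)) - numCorrect;    % sum of the off-diagonal elements
toyAccuracy  = 100*numCorrect/sum(C(:));  % unweighted percentage of correct labels
```

Here C is [2 1; 1 2], so four of the six toy observations lie on the diagonal and two are misclassified.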

Compute the classification accuracy on the test set. accuracy is the percentage of correctly classified test set observations.
accuracy = (1-loss(mdl,adulttest,'salary'))*100
accuracy = 85.1513
To plot the ROC curve for the score values that correspond to the label '<=50K', first find the column of scores that corresponds to this label. The order of the columns in scores matches the order of the classes in the trained model.
mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
Because '<=50K' is listed first, the first column of scores corresponds to this label.
Plot the ROC curve and compute the area under the curve (AUC). The ROC curve shows the true positive rate versus the false positive rate for different thresholds of the classifier output. For a perfect classifier, whose true positive rate is always 1 regardless of the threshold, AUC = 1. For a binary classifier that randomly assigns observations to classes, AUC = 0.5. A large AUC value (close to 1) indicates good classifier performance.
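The column can also be looked up by name instead of by reading the display. The sketch below assumes a class list and a score matrix shaped like mdl.ClassNames and scores; the values and the names classNames, toyScores, and scoreColumn are hypothetical.

```matlab
% Hypothetical stand-ins for mdl.ClassNames and the scores returned by predict.
classNames = categorical({'<=50K';'>50K'});
toyScores  = [0.9 0.1; 0.3 0.7; 0.6 0.4];

% Locate the score column for the label of interest by name.
scoreColumn    = find(classNames == '<=50K');
positiveScores = toyScores(:,scoreColumn);
```

Looking the column up this way avoids hard-coding the index if the class order differs in another model.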
[X,Y,~,AUC] = perfcurve(adulttest.salary,scores(:,1),'<=50K');
plot(X,Y)
title('ROC Curve')
xlabel('False Positive Rate')
ylabel('True Positive Rate')

AUC
AUC = 0.8947
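One way to build intuition for the AUC value is that, when no scores are tied, it equals the fraction of (positive, negative) pairs in which the positive observation receives the higher score. The sketch below checks this on six made-up observations; the labels, scores, and variable names are hypothetical.

```matlab
% Toy data (hypothetical): 1 marks the positive class, 0 the negative class.
toyLabels = [1;1;1;0;0;0];
toyScores = [0.9;0.6;0.4;0.7;0.3;0.2];

% AUC as reported by perfcurve.
[~,~,~,aucCurve] = perfcurve(toyLabels,toyScores,1);

% Fraction of (positive, negative) pairs ranked correctly, assuming no tied scores.
posScores = toyScores(toyLabels == 1);
negScores = toyScores(toyLabels == 0);
aucPairs  = mean(posScores > negScores.','all');
```

Both quantities equal 7/9 for this toy data, because seven of the nine pairs are ordered correctly.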
Based on the accuracy and AUC values, the classifier performs well on the test data.
BayesianOptimization | confusionchart | fitcauto | perfcurve