В этом примере показано, как использовать fitcauto
чтобы автоматически попробовать выбор типов классификационной модели с различными значениями гиперзначений параметров, учитывая обучающие данные предиктора и отклика. Функция использует байесовскую оптимизацию, чтобы выбрать модели и их значения гиперзначений параметров, и вычисляет ошибку классификации перекрестной валидации для каждой модели. После завершения оптимизации fitcauto
возвращает модель, обученную на целом наборе данных, которая, как ожидается, лучше всего классифицирует новые данные. Проверьте производительность модели по тестовым данным.
Этот пример использует данные переписи 1994 года, хранящиеся в census1994.mat
. Набор данных состоит из демографической информации Бюро переписи населения США, которая может использоваться для прогнозирования того, зарабатывает ли индивидуум более 50 000 долларов в год.
Загрузите выборочные данные census1994
, который содержит обучающие данные adultdata
и тестовые данные adulttest
. Предварительный просмотр первых нескольких строк обучающих данных набора.
load census1994
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Каждая строка содержит демографическую информацию для одного взрослого. Последний столбец salary
показывает, имеет ли человек зарплату менее 50 000 долл. США в год или более 50 000 долл. США в год.
Использование fitcauto
чтобы автоматически найти подходящий классификатор для данных в adultdata
. Установите веса наблюдений и задайте, чтобы запустить байесовскую оптимизацию параллельно, что требует Toolbox™ Parallel Computing. Из-за непродуктивности параллельной синхронизации параллельная байесовская оптимизация не обязательно приводит к воспроизводимым результатам.
Из-за сложности оптимизации этот процесс может занять некоторое время, особенно для больших наборов данных. По умолчанию fitcauto
предоставляет график оптимизации и итерационное отображение результатов оптимизации. Для получения дополнительной информации о том, как интерпретировать эти результаты, смотрите Подробное отображение.
options = struct('UseParallel',true); [mdl,results] = fitcauto(adultdata,'salary','Weights','fnlwgt', ... 'HyperparameterOptimizationOptions',options);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6). Copying objective function to workers... Done copying objective function to workers.
Learner types to explore: ensemble, nb, tree Total iterations (MaxObjectiveEvaluations): 90 Total time (MaxTime): Inf
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 1 | 6 | Best | 0.16287 | 4.3468 | 0.16287 | 0.16287 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 2 | 5 | Accept | 0.14389 | 6.1049 | 0.14162 | 0.14287 | tree | MinLeafSize: 21 | | 3 | 5 | Best | 0.14162 | 5.6195 | 0.14162 | 0.14287 | tree | MinLeafSize: 50 |
| 4 | 6 | Accept | 0.15626 | 74.156 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 283 | | | | | | | | | | MinLeafSize: 7330 |
| 5 | 6 | Accept | 0.15603 | 77.293 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 295 | | | | | | | | | | MinLeafSize: 3 |
| 6 | 6 | Accept | 0.16027 | 5.6224 | 0.14162 | 0.14842 | tree | MinLeafSize: 5 |
| 7 | 6 | Accept | 0.17343 | 8.6209 | 0.14162 | 0.15576 | tree | MinLeafSize: 2 |
| 8 | 6 | Accept | 0.15103 | 4.8867 | 0.14162 | 0.15392 | tree | MinLeafSize: 8 |
| 9 | 6 | Accept | 0.17642 | 1.1808 | 0.14162 | 0.15449 | tree | MinLeafSize: 1663 |
| 10 | 6 | Accept | 0.15927 | 5.0734 | 0.14162 | 0.15343 | tree | MinLeafSize: 6 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 11 | 6 | Accept | 0.17009 | 1.6504 | 0.14162 | 0.15533 | tree | MinLeafSize: 1272 |
| 12 | 6 | Accept | 0.17869 | 1.0308 | 0.14162 | 0.154 | tree | MinLeafSize: 2744 |
| 13 | 6 | Accept | 0.17961 | 116.64 | 0.14162 | 0.154 | nb | DistributionNames: kernel | | | | | | | | | | Width: 274.23 |
| 14 | 5 | Accept | 0.15128 | 118.36 | 0.14162 | 0.15383 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 241 | | | | | | | | | | MinLeafSize: 23 | | 15 | 5 | Accept | 0.15177 | 115.42 | 0.14162 | 0.15383 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 235 | | | | | | | | | | MinLeafSize: 40 |
| 16 | 5 | Accept | 0.15116 | 115.49 | 0.14162 | 0.15326 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 235 | | | | | | | | | | MinLeafSize: 40 |
| 17 | 6 | Accept | 0.14887 | 63.412 | 0.14162 | 0.15326 | nb | DistributionNames: kernel | | | | | | | | | | Width: 0.56014 |
| 18 | 6 | Accept | 0.17869 | 0.89318 | 0.14162 | 0.15219 | tree | MinLeafSize: 2712 |
| 19 | 6 | Accept | 0.17676 | 59.781 | 0.14162 | 0.15219 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 208 | | | | | | | | | | MinLeafSize: 4208 |
| 20 | 6 | Accept | 0.15086 | 81.42 | 0.14162 | 0.15219 | nb | DistributionNames: kernel | | | | | | | | | | Width: 2.4778 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 21 | 6 | Accept | 0.16287 | 0.64656 | 0.14162 | 0.15219 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 22 | 6 | Accept | 0.14943 | 75.578 | 0.14162 | 0.15219 | nb | DistributionNames: kernel | | | | | | | | | | Width: 1.6195 |
| 23 | 6 | Accept | 0.16287 | 0.49489 | 0.14162 | 0.15219 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 24 | 6 | Accept | 0.14926 | 68.642 | 0.14162 | 0.15219 | nb | DistributionNames: kernel | | | | | | | | | | Width: 1.2371 |
| 25 | 6 | Accept | 0.16287 | 0.5124 | 0.14162 | 0.15219 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 26 | 6 | Accept | 0.15609 | 58.267 | 0.14162 | 0.15219 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 247 | | | | | | | | | | MinLeafSize: 1 |
| 27 | 6 | Accept | 0.16287 | 0.93385 | 0.14162 | 0.15219 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 28 | 6 | Accept | 0.15554 | 4.3668 | 0.14162 | 0.15067 | tree | MinLeafSize: 7 |
| 29 | 6 | Accept | 0.15087 | 127.01 | 0.14162 | 0.15067 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 289 | | | | | | | | | | MinLeafSize: 9 |
| 30 | 6 | Accept | 0.15142 | 127.39 | 0.14162 | 0.15067 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 289 | | | | | | | | | | MinLeafSize: 9 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 31 | 6 | Accept | 0.14177 | 2.6306 | 0.14162 | 0.14707 | tree | MinLeafSize: 116 |
| 32 | 6 | Accept | 0.16287 | 1.1225 | 0.14162 | 0.14707 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 33 | 6 | Accept | 0.15737 | 56.258 | 0.14162 | 0.14707 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 233 | | | | | | | | | | MinLeafSize: 5308 |
| 34 | 6 | Accept | 0.15158 | 97.559 | 0.14162 | 0.14707 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 214 | | | | | | | | | | MinLeafSize: 133 |
| 35 | 6 | Accept | 0.1719 | 96.392 | 0.14162 | 0.14707 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 223 | | | | | | | | | | MinLeafSize: 1526 |
| 36 | 6 | Accept | 0.16287 | 0.42054 | 0.14162 | 0.14707 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 37 | 6 | Accept | 0.14441 | 3.5932 | 0.14162 | 0.14598 | tree | MinLeafSize: 18 |
| 38 | 6 | Accept | 0.16287 | 0.34693 | 0.14162 | 0.14598 | nb | DistributionNames: normal | | | | | | | | | | Width: NaN |
| 39 | 6 | Accept | 0.14432 | 3.4661 | 0.14162 | 0.145 | tree | MinLeafSize: 19 |
| 40 | 6 | Accept | 0.14291 | 2.3121 | 0.14162 | 0.14321 | tree | MinLeafSize: 231 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 41 | 6 | Accept | 0.15278 | 96.086 | 0.14162 | 0.14321 | nb | DistributionNames: kernel | | | | | | | | | | Width: 3.5668 |
| 42 | 6 | Accept | 0.15068 | 1.9847 | 0.14162 | 0.14348 | tree | MinLeafSize: 412 |
| 43 | 6 | Accept | 0.14705 | 2.1122 | 0.14162 | 0.14343 | tree | MinLeafSize: 305 |
| 44 | 6 | Accept | 0.14186 | 2.3835 | 0.14162 | 0.14309 | tree | MinLeafSize: 168 |
| 45 | 6 | Accept | 0.16209 | 1.9821 | 0.14162 | 0.14302 | tree | MinLeafSize: 573 |
| 46 | 5 | Accept | 0.15783 | 53.627 | 0.14135 | 0.14271 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 211 | | | | | | | | | | MinLeafSize: 125 | | 47 | 5 | Best | 0.14135 | 3.1329 | 0.14135 | 0.14271 | tree | MinLeafSize: 63 |
| 48 | 4 | Accept | 0.15637 | 63.578 | 0.14135 | 0.14236 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 252 | | | | | | | | | | MinLeafSize: 485 | | 49 | 4 | Accept | 0.1448 | 2.1012 | 0.14135 | 0.14236 | tree | MinLeafSize: 263 |
| 50 | 3 | Accept | 0.1513 | 114.35 | 0.14135 | 0.14224 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 253 | | | | | | | | | | MinLeafSize: 13 | |===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 51 | 3 | Accept | 0.14271 | 2.2737 | 0.14135 | 0.14224 | tree | MinLeafSize: 133 |
| 52 | 6 | Accept | 0.14349 | 1.9707 | 0.14135 | 0.14224 | tree | MinLeafSize: 199 |
| 53 | 3 | Accept | 0.15337 | 1.6887 | 0.14135 | 0.14235 | tree | MinLeafSize: 441 | | 54 | 3 | Accept | 0.17869 | 1.049 | 0.14135 | 0.14235 | tree | MinLeafSize: 1821 | | 55 | 3 | Accept | 0.1785 | 0.9639 | 0.14135 | 0.14235 | tree | MinLeafSize: 3523 | | 56 | 3 | Accept | 0.18062 | 0.63917 | 0.14135 | 0.14235 | tree | MinLeafSize: 4359 |
| 57 | 6 | Accept | 0.14673 | 3.2067 | 0.14135 | 0.14207 | tree | MinLeafSize: 12 |
| 58 | 6 | Accept | 0.14238 | 2.3081 | 0.14135 | 0.14215 | tree | MinLeafSize: 177 |
| 59 | 5 | Accept | 0.16352 | 125.94 | 0.14135 | 0.1419 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 297 | | | | | | | | | | MinLeafSize: 823 | | 60 | 5 | Accept | 0.14162 | 2.849 | 0.14135 | 0.1419 | tree | MinLeafSize: 50 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 61 | 5 | Best | 0.14113 | 2.6499 | 0.14113 | 0.14173 | tree | MinLeafSize: 83 |
| 62 | 5 | Accept | 0.14178 | 2.9853 | 0.14113 | 0.14153 | tree | MinLeafSize: 40 |
| 63 | 5 | Accept | 0.14157 | 2.8701 | 0.14113 | 0.14153 | tree | MinLeafSize: 42 |
| 64 | 5 | Accept | 0.15886 | 1.7188 | 0.14113 | 0.14161 | tree | MinLeafSize: 532 |
| 65 | 5 | Accept | 0.14529 | 3.6593 | 0.14113 | 0.14151 | tree | MinLeafSize: 14 |
| 66 | 4 | Accept | 0.23856 | 41.472 | 0.14113 | 0.14151 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 209 | | | | | | | | | | MinLeafSize: 8676 | | 67 | 4 | Accept | 0.14702 | 4.0559 | 0.14113 | 0.14151 | tree | MinLeafSize: 10 |
| 68 | 4 | Best | 0.14058 | 2.8472 | 0.14058 | 0.14148 | tree | MinLeafSize: 30 |
| 69 | 4 | Accept | 0.14168 | 2.1868 | 0.14058 | 0.14143 | tree | MinLeafSize: 112 |
| 70 | 4 | Accept | 0.14072 | 2.9698 | 0.14058 | 0.14144 | tree | MinLeafSize: 28 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 71 | 4 | Accept | 0.14117 | 2.8824 | 0.14058 | 0.14114 | tree | MinLeafSize: 29 |
| 72 | 4 | Best | 0.14046 | 2.8853 | 0.14046 | 0.14112 | tree | MinLeafSize: 25 |
| 73 | 4 | Accept | 0.14184 | 2.8532 | 0.14046 | 0.14103 | tree | MinLeafSize: 24 |
| 74 | 4 | Accept | 0.14112 | 2.7998 | 0.14046 | 0.14102 | tree | MinLeafSize: 33 |
| 75 | 4 | Accept | 0.14331 | 3.0835 | 0.14046 | 0.141 | tree | MinLeafSize: 23 |
| 76 | 4 | Accept | 0.14089 | 2.9637 | 0.14046 | 0.14086 | tree | MinLeafSize: 31 |
| 77 | 4 | Accept | 0.14046 | 3.0017 | 0.14046 | 0.14083 | tree | MinLeafSize: 25 |
| 78 | 3 | Accept | 0.15093 | 91.952 | 0.14046 | 0.14085 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 222 | | | | | | | | | | MinLeafSize: 27 | | 79 | 3 | Accept | 0.14046 | 2.9993 | 0.14046 | 0.14085 | tree | MinLeafSize: 25 |
| 80 | 6 | Accept | 0.14046 | 2.7739 | 0.14046 | 0.14073 | tree | MinLeafSize: 25 |
|===========================================================================================================================================| | Iter | Active | Eval | Validation | Time for training | Observed min | Estimated min | Learner | Hyperparameter: Value | | | workers | result | loss | & validation (sec)| validation loss | validation loss | | | |===========================================================================================================================================| | 81 | 2 | Accept | 0.18178 | 101.13 | 0.14046 | 0.14068 | nb | DistributionNames: kernel | | | | | | | | | | Width: 868.86 | | 82 | 2 | Accept | 0.14184 | 3.2218 | 0.14046 | 0.14068 | tree | MinLeafSize: 24 | | 83 | 2 | Accept | 0.17807 | 0.82685 | 0.14046 | 0.14068 | tree | MinLeafSize: 3874 | | 84 | 2 | Accept | 0.15989 | 1.8729 | 0.14046 | 0.14068 | tree | MinLeafSize: 540 | | 85 | 2 | Accept | 0.15103 | 3.8835 | 0.14046 | 0.14068 | tree | MinLeafSize: 8 |
| 86 | 6 | Accept | 0.14046 | 2.5909 | 0.14046 | 0.14067 | tree | MinLeafSize: 25 |
| 87 | 6 | Accept | 0.14331 | 3.5433 | 0.14046 | 0.14067 | tree | MinLeafSize: 23 |
| 88 | 6 | Accept | 0.23856 | 47.904 | 0.14046 | 0.14067 | ensemble | Method: Bag | | | | | | | | | | NumLearningCycles: 258 | | | | | | | | | | MinLeafSize: 12543 |
| 89 | 6 | Accept | 0.14914 | 59.665 | 0.14046 | 0.14067 | nb | DistributionNames: kernel | | | | | | | | | | Width: 0.37688 |
| 90 | 6 | Accept | 0.15604 | 68.731 | 0.14046 | 0.14067 | ensemble | Method: LogitBoost | | | | | | | | | | NumLearningCycles: 262 | | | | | | | | | | MinLeafSize: 2 |
__________________________________________________________ Optimization completed. Total iterations: 90 Total elapsed time: 577.1419 seconds Total time for training and validation: 2558.1542 seconds Best observed learner is a tree model with: MinLeafSize: 25 Observed validation loss: 0.14046 Time for training and validation: 2.8853 seconds Best estimated learner (returned model) is a tree model with: MinLeafSize: 25 Estimated validation loss: 0.14067 Estimated time for training and validation: 2.8824 seconds Documentation for fitcauto display
Конечная модель, возвращенная fitcauto
соответствует лучшему оцененному ученику. Перед возвращением модели функция переобучает ее, используя все обучающие данные (adultdata
), перечисленные Learner
(или модель) тип и отображенные значения гиперзначений параметров.
Оцените эффективность возвращенной модели mdl
на тестовом аппарате adulttest
при помощи матрицы неточностей и приемника кривой рабочей характеристики (ROC).
Найдите предсказанные метки и значения баллов для тестового набора.
[labels,scores] = predict(mdl,adulttest);
Создайте матрицу неточностей из результатов тестового набора. Диагональные элементы указывают количество правильно классифицированных образцов данного класса. Не-диагональные элементы являются образцами неправильно классифицированных наблюдений.
confusionchart(adulttest.salary,labels)
Вычислите точность классификации тестового набора. accuracy
- процент правильно классифицированных наблюдений тестового набора.
accuracy = (1-loss(mdl,adulttest,'salary'))*100
accuracy = 85.1513
Чтобы построить график кривой ROC для значений баллов, соответствующих значению метки '<=50K'
, найти столбец scores
что соответствует этой метке. Порядок столбцов scores
соответствует порядку классов в обученной модели.
mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
Потому что '<=50K'
первый столбец scores
соответствует этой метке.
Постройте график кривой ROC и вычислите площадь под кривой (AUC). Кривая ROC показывает истинную положительную скорость по сравнению с ложноположительной частотой для различных порогов выхода классификатора. Для идеального классификатора, истинная положительная скорость которого всегда равна 1 независимо от порога, AUC = 1. Для двоичного классификатора, который случайным образом присваивает наблюдения классам, AUC = 0,5. Большое значение AUC (близкое к 1) указывает на хорошую эффективность классификатора.
[X,Y,~,AUC] = perfcurve(adulttest.salary,scores(:,1),'<=50K'); plot(X,Y) title('ROC Curve') xlabel('False Positive Rate') ylabel('True Positive Rate')
AUC
AUC = 0.8947
Основываясь на точности и значениях AUC, классификатор хорошо работает с тестовыми данными.
BayesianOptimization
| confusionchart
| fitcauto
| perfcurve