This example shows how to use fitcauto to automatically try a selection of classification model types with different hyperparameter values, given training predictor and response data. The function uses Bayesian optimization to select models and their hyperparameter values, and computes the cross-validation classification error for each model. After the optimization is complete, fitcauto returns the model, trained on the entire data set, that is expected to best classify new data. Check the model performance on test data.
This example uses the 1994 census data stored in census1994.mat. The data set consists of demographic information from the US Census Bureau that can be used to predict whether an individual makes over $50,000 per year.
Load the sample data census1994, which contains the training data adultdata and the test data adulttest. Preview the first few rows of the training data set.
load census1994
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ________________ __________ _________ _____________ _____________________ _________________ _____________ _____ ______ ____________ ____________ ______________ ______________ ______
39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
38 Private 2.1565e+05 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
53 Private 2.3472e+05 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
28 Private 3.3841e+05 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
37 Private 2.8458e+05 Masters 14 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States <=50K
49 Private 1.6019e+05 9th 5 Married-spouse-absent Other-service Not-in-family Black Female 0 0 16 Jamaica <=50K
52 Self-emp-not-inc 2.0964e+05 HS-grad 9 Married-civ-spouse Exec-managerial Husband White Male 0 0 45 United-States >50K
Each row contains the demographic information for one adult. The last column, salary, shows whether a person has a salary less than or equal to $50,000 per year or greater than $50,000 per year.
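Before fitting any models, it can help to check how the observations are split between the two salary classes, because the test accuracy reported later is easier to judge against the share of the most frequent class. This is a minimal sketch, assuming salary is a categorical variable (as the mdl.ClassNames output later in this example suggests). tabulate and countcats are existing MATLAB functions; classCounts and classFractions are illustrative names. Note that these counts ignore the fnlwgt observation weights.

```matlab
% Print the count and percentage of training observations in each salary class
tabulate(adultdata.salary)

% Unweighted fraction of training observations in each class
% (countcats does not count undefined values, if any)
classCounts = countcats(adultdata.salary);
classFractions = classCounts/sum(classCounts)
```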
Use fitcauto to automatically find an appropriate classifier for the data in adultdata. Use the fnlwgt variable as observation weights, and specify to run the Bayesian optimization in parallel, which requires Parallel Computing Toolbox™. Because of the nonreproducibility of parallel timing, parallel Bayesian optimization does not necessarily yield reproducible results.
Because of the complexity of the optimization, this process can take some time, especially for larger data sets. By default, fitcauto provides a plot of the optimization and an iterative display of the optimization results. For more information on how to interpret these results, see Verbose Display.
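If Parallel Computing Toolbox is not available, or if repeatable results matter more than speed, a smaller serial search is one option. The sketch below is an illustrative variant, not the configuration used in this example: it fixes the random seed, evaluates only 30 candidate models, restricts the search to decision trees and naive Bayes models through the Learners name-value argument, and turns off the plot and iterative display with the ShowPlots and Verbose fields. The values 30 and {'tree','nb'} and the names serialOptions, mdlSerial, and resultsSerial are arbitrary choices for illustration.

```matlab
% Fix the random seed; a serial run with a fixed seed should be repeatable
rng('default')

% Illustrative options: serial, 30 evaluations, no plot, no iterative display
serialOptions = struct('UseParallel',false, ...
    'MaxObjectiveEvaluations',30, ...
    'ShowPlots',false, ...
    'Verbose',0);

% Search only over decision trees and naive Bayes models (arbitrary choice)
[mdlSerial,resultsSerial] = fitcauto(adultdata,'salary', ...
    'Weights','fnlwgt', ...
    'Learners',{'tree','nb'}, ...
    'HyperparameterOptimizationOptions',serialOptions);

% Learner type of the returned model and the number of evaluated models
class(mdlSerial)
resultsSerial.NumObjectiveEvaluations
```

With only 30 evaluations and two learner types, the search may stop before it finds a model as good as the one from the full 90-iteration run, so treat it as a quick check rather than a replacement.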
options = struct('UseParallel',true);
[mdl,results] = fitcauto(adultdata,'salary','Weights','fnlwgt', ...
    'HyperparameterOptimizationOptions',options);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
Copying objective function to workers...
Done copying objective function to workers.
Learner types to explore: ensemble, nb, tree
Total iterations (MaxObjectiveEvaluations): 90
Total time (MaxTime): Inf
| Iter | Active workers | Eval result | Validation loss | Time for training & validation (sec) | Observed min validation loss | Estimated min validation loss | Learner | Hyperparameter: Value |
|---|---|---|---|---|---|---|---|---|
| 1 | 6 | Best | 0.16287 | 4.3468 | 0.16287 | 0.16287 | nb | DistributionNames: normal, Width: NaN |
| 2 | 5 | Accept | 0.14389 | 6.1049 | 0.14162 | 0.14287 | tree | MinLeafSize: 21 |
| 3 | 5 | Best | 0.14162 | 5.6195 | 0.14162 | 0.14287 | tree | MinLeafSize: 50 |
| 4 | 6 | Accept | 0.15626 | 74.156 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost, NumLearningCycles: 283, MinLeafSize: 7330 |
| 5 | 6 | Accept | 0.15603 | 77.293 | 0.14162 | 0.14287 | ensemble | Method: LogitBoost, NumLearningCycles: 295, MinLeafSize: 3 |
| 6 | 6 | Accept | 0.16027 | 5.6224 | 0.14162 | 0.14842 | tree | MinLeafSize: 5 |
| 7 | 6 | Accept | 0.17343 | 8.6209 | 0.14162 | 0.15576 | tree | MinLeafSize: 2 |
| 8 | 6 | Accept | 0.15103 | 4.8867 | 0.14162 | 0.15392 | tree | MinLeafSize: 8 |
| 9 | 6 | Accept | 0.17642 | 1.1808 | 0.14162 | 0.15449 | tree | MinLeafSize: 1663 |
| 10 | 6 | Accept | 0.15927 | 5.0734 | 0.14162 | 0.15343 | tree | MinLeafSize: 6 |
| 11 | 6 | Accept | 0.17009 | 1.6504 | 0.14162 | 0.15533 | tree | MinLeafSize: 1272 |
| 12 | 6 | Accept | 0.17869 | 1.0308 | 0.14162 | 0.154 | tree | MinLeafSize: 2744 |
| 13 | 6 | Accept | 0.17961 | 116.64 | 0.14162 | 0.154 | nb | DistributionNames: kernel, Width: 274.23 |
| 14 | 5 | Accept | 0.15128 | 118.36 | 0.14162 | 0.15383 | ensemble | Method: Bag, NumLearningCycles: 241, MinLeafSize: 23 |
| 15 | 5 | Accept | 0.15177 | 115.42 | 0.14162 | 0.15383 | ensemble | Method: Bag, NumLearningCycles: 235, MinLeafSize: 40 |
| 16 | 5 | Accept | 0.15116 | 115.49 | 0.14162 | 0.15326 | ensemble | Method: Bag, NumLearningCycles: 235, MinLeafSize: 40 |
| 17 | 6 | Accept | 0.14887 | 63.412 | 0.14162 | 0.15326 | nb | DistributionNames: kernel, Width: 0.56014 |
| 18 | 6 | Accept | 0.17869 | 0.89318 | 0.14162 | 0.15219 | tree | MinLeafSize: 2712 |
| 19 | 6 | Accept | 0.17676 | 59.781 | 0.14162 | 0.15219 | ensemble | Method: Bag, NumLearningCycles: 208, MinLeafSize: 4208 |
| 20 | 6 | Accept | 0.15086 | 81.42 | 0.14162 | 0.15219 | nb | DistributionNames: kernel, Width: 2.4778 |
| 21 | 6 | Accept | 0.16287 | 0.64656 | 0.14162 | 0.15219 | nb | DistributionNames: normal, Width: NaN |
| 22 | 6 | Accept | 0.14943 | 75.578 | 0.14162 | 0.15219 | nb | DistributionNames: kernel, Width: 1.6195 |
| 23 | 6 | Accept | 0.16287 | 0.49489 | 0.14162 | 0.15219 | nb | DistributionNames: normal, Width: NaN |
| 24 | 6 | Accept | 0.14926 | 68.642 | 0.14162 | 0.15219 | nb | DistributionNames: kernel, Width: 1.2371 |
| 25 | 6 | Accept | 0.16287 | 0.5124 | 0.14162 | 0.15219 | nb | DistributionNames: normal, Width: NaN |
| 26 | 6 | Accept | 0.15609 | 58.267 | 0.14162 | 0.15219 | ensemble | Method: LogitBoost, NumLearningCycles: 247, MinLeafSize: 1 |
| 27 | 6 | Accept | 0.16287 | 0.93385 | 0.14162 | 0.15219 | nb | DistributionNames: normal, Width: NaN |
| 28 | 6 | Accept | 0.15554 | 4.3668 | 0.14162 | 0.15067 | tree | MinLeafSize: 7 |
| 29 | 6 | Accept | 0.15087 | 127.01 | 0.14162 | 0.15067 | ensemble | Method: Bag, NumLearningCycles: 289, MinLeafSize: 9 |
| 30 | 6 | Accept | 0.15142 | 127.39 | 0.14162 | 0.15067 | ensemble | Method: Bag, NumLearningCycles: 289, MinLeafSize: 9 |
| 31 | 6 | Accept | 0.14177 | 2.6306 | 0.14162 | 0.14707 | tree | MinLeafSize: 116 |
| 32 | 6 | Accept | 0.16287 | 1.1225 | 0.14162 | 0.14707 | nb | DistributionNames: normal, Width: NaN |
| 33 | 6 | Accept | 0.15737 | 56.258 | 0.14162 | 0.14707 | ensemble | Method: LogitBoost, NumLearningCycles: 233, MinLeafSize: 5308 |
| 34 | 6 | Accept | 0.15158 | 97.559 | 0.14162 | 0.14707 | ensemble | Method: Bag, NumLearningCycles: 214, MinLeafSize: 133 |
| 35 | 6 | Accept | 0.1719 | 96.392 | 0.14162 | 0.14707 | ensemble | Method: Bag, NumLearningCycles: 223, MinLeafSize: 1526 |
| 36 | 6 | Accept | 0.16287 | 0.42054 | 0.14162 | 0.14707 | nb | DistributionNames: normal, Width: NaN |
| 37 | 6 | Accept | 0.14441 | 3.5932 | 0.14162 | 0.14598 | tree | MinLeafSize: 18 |
| 38 | 6 | Accept | 0.16287 | 0.34693 | 0.14162 | 0.14598 | nb | DistributionNames: normal, Width: NaN |
| 39 | 6 | Accept | 0.14432 | 3.4661 | 0.14162 | 0.145 | tree | MinLeafSize: 19 |
| 40 | 6 | Accept | 0.14291 | 2.3121 | 0.14162 | 0.14321 | tree | MinLeafSize: 231 |
| 41 | 6 | Accept | 0.15278 | 96.086 | 0.14162 | 0.14321 | nb | DistributionNames: kernel, Width: 3.5668 |
| 42 | 6 | Accept | 0.15068 | 1.9847 | 0.14162 | 0.14348 | tree | MinLeafSize: 412 |
| 43 | 6 | Accept | 0.14705 | 2.1122 | 0.14162 | 0.14343 | tree | MinLeafSize: 305 |
| 44 | 6 | Accept | 0.14186 | 2.3835 | 0.14162 | 0.14309 | tree | MinLeafSize: 168 |
| 45 | 6 | Accept | 0.16209 | 1.9821 | 0.14162 | 0.14302 | tree | MinLeafSize: 573 |
| 46 | 5 | Accept | 0.15783 | 53.627 | 0.14135 | 0.14271 | ensemble | Method: LogitBoost, NumLearningCycles: 211, MinLeafSize: 125 |
| 47 | 5 | Best | 0.14135 | 3.1329 | 0.14135 | 0.14271 | tree | MinLeafSize: 63 |
| 48 | 4 | Accept | 0.15637 | 63.578 | 0.14135 | 0.14236 | ensemble | Method: LogitBoost, NumLearningCycles: 252, MinLeafSize: 485 |
| 49 | 4 | Accept | 0.1448 | 2.1012 | 0.14135 | 0.14236 | tree | MinLeafSize: 263 |
| 50 | 3 | Accept | 0.1513 | 114.35 | 0.14135 | 0.14224 | ensemble | Method: Bag, NumLearningCycles: 253, MinLeafSize: 13 |
| 51 | 3 | Accept | 0.14271 | 2.2737 | 0.14135 | 0.14224 | tree | MinLeafSize: 133 |
| 52 | 6 | Accept | 0.14349 | 1.9707 | 0.14135 | 0.14224 | tree | MinLeafSize: 199 |
| 53 | 3 | Accept | 0.15337 | 1.6887 | 0.14135 | 0.14235 | tree | MinLeafSize: 441 |
| 54 | 3 | Accept | 0.17869 | 1.049 | 0.14135 | 0.14235 | tree | MinLeafSize: 1821 |
| 55 | 3 | Accept | 0.1785 | 0.9639 | 0.14135 | 0.14235 | tree | MinLeafSize: 3523 |
| 56 | 3 | Accept | 0.18062 | 0.63917 | 0.14135 | 0.14235 | tree | MinLeafSize: 4359 |
| 57 | 6 | Accept | 0.14673 | 3.2067 | 0.14135 | 0.14207 | tree | MinLeafSize: 12 |
| 58 | 6 | Accept | 0.14238 | 2.3081 | 0.14135 | 0.14215 | tree | MinLeafSize: 177 |
| 59 | 5 | Accept | 0.16352 | 125.94 | 0.14135 | 0.1419 | ensemble | Method: Bag, NumLearningCycles: 297, MinLeafSize: 823 |
| 60 | 5 | Accept | 0.14162 | 2.849 | 0.14135 | 0.1419 | tree | MinLeafSize: 50 |
| 61 | 5 | Best | 0.14113 | 2.6499 | 0.14113 | 0.14173 | tree | MinLeafSize: 83 |
| 62 | 5 | Accept | 0.14178 | 2.9853 | 0.14113 | 0.14153 | tree | MinLeafSize: 40 |
| 63 | 5 | Accept | 0.14157 | 2.8701 | 0.14113 | 0.14153 | tree | MinLeafSize: 42 |
| 64 | 5 | Accept | 0.15886 | 1.7188 | 0.14113 | 0.14161 | tree | MinLeafSize: 532 |
| 65 | 5 | Accept | 0.14529 | 3.6593 | 0.14113 | 0.14151 | tree | MinLeafSize: 14 |
| 66 | 4 | Accept | 0.23856 | 41.472 | 0.14113 | 0.14151 | ensemble | Method: Bag, NumLearningCycles: 209, MinLeafSize: 8676 |
| 67 | 4 | Accept | 0.14702 | 4.0559 | 0.14113 | 0.14151 | tree | MinLeafSize: 10 |
| 68 | 4 | Best | 0.14058 | 2.8472 | 0.14058 | 0.14148 | tree | MinLeafSize: 30 |
| 69 | 4 | Accept | 0.14168 | 2.1868 | 0.14058 | 0.14143 | tree | MinLeafSize: 112 |
| 70 | 4 | Accept | 0.14072 | 2.9698 | 0.14058 | 0.14144 | tree | MinLeafSize: 28 |
| 71 | 4 | Accept | 0.14117 | 2.8824 | 0.14058 | 0.14114 | tree | MinLeafSize: 29 |
| 72 | 4 | Best | 0.14046 | 2.8853 | 0.14046 | 0.14112 | tree | MinLeafSize: 25 |
| 73 | 4 | Accept | 0.14184 | 2.8532 | 0.14046 | 0.14103 | tree | MinLeafSize: 24 |
| 74 | 4 | Accept | 0.14112 | 2.7998 | 0.14046 | 0.14102 | tree | MinLeafSize: 33 |
| 75 | 4 | Accept | 0.14331 | 3.0835 | 0.14046 | 0.141 | tree | MinLeafSize: 23 |
| 76 | 4 | Accept | 0.14089 | 2.9637 | 0.14046 | 0.14086 | tree | MinLeafSize: 31 |
| 77 | 4 | Accept | 0.14046 | 3.0017 | 0.14046 | 0.14083 | tree | MinLeafSize: 25 |
| 78 | 3 | Accept | 0.15093 | 91.952 | 0.14046 | 0.14085 | ensemble | Method: Bag, NumLearningCycles: 222, MinLeafSize: 27 |
| 79 | 3 | Accept | 0.14046 | 2.9993 | 0.14046 | 0.14085 | tree | MinLeafSize: 25 |
| 80 | 6 | Accept | 0.14046 | 2.7739 | 0.14046 | 0.14073 | tree | MinLeafSize: 25 |
| 81 | 2 | Accept | 0.18178 | 101.13 | 0.14046 | 0.14068 | nb | DistributionNames: kernel, Width: 868.86 |
| 82 | 2 | Accept | 0.14184 | 3.2218 | 0.14046 | 0.14068 | tree | MinLeafSize: 24 |
| 83 | 2 | Accept | 0.17807 | 0.82685 | 0.14046 | 0.14068 | tree | MinLeafSize: 3874 |
| 84 | 2 | Accept | 0.15989 | 1.8729 | 0.14046 | 0.14068 | tree | MinLeafSize: 540 |
| 85 | 2 | Accept | 0.15103 | 3.8835 | 0.14046 | 0.14068 | tree | MinLeafSize: 8 |
| 86 | 6 | Accept | 0.14046 | 2.5909 | 0.14046 | 0.14067 | tree | MinLeafSize: 25 |
| 87 | 6 | Accept | 0.14331 | 3.5433 | 0.14046 | 0.14067 | tree | MinLeafSize: 23 |
| 88 | 6 | Accept | 0.23856 | 47.904 | 0.14046 | 0.14067 | ensemble | Method: Bag, NumLearningCycles: 258, MinLeafSize: 12543 |
| 89 | 6 | Accept | 0.14914 | 59.665 | 0.14046 | 0.14067 | nb | DistributionNames: kernel, Width: 0.37688 |
| 90 | 6 | Accept | 0.15604 | 68.731 | 0.14046 | 0.14067 | ensemble | Method: LogitBoost, NumLearningCycles: 262, MinLeafSize: 2 |
__________________________________________________________
Optimization completed.
Total iterations: 90
Total elapsed time: 577.1419 seconds
Total time for training and validation: 2558.1542 seconds
Best observed learner is a tree model with:
    MinLeafSize: 25
    Observed validation loss: 0.14046
    Time for training and validation: 2.8853 seconds
Best estimated learner (returned model) is a tree model with:
    MinLeafSize: 25
    Estimated validation loss: 0.14067
    Estimated time for training and validation: 2.8824 seconds
The final model returned by fitcauto corresponds to the best estimated learner. Before returning the model, the function retrains it using the entire training data (adultdata), the listed Learner (or model) type, and the displayed hyperparameter values.
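To check these numbers programmatically, you can query the results object (a BayesianOptimization object) and re-estimate the cross-validation loss of the returned model. This is a sketch, assuming the fitcauto call above finished and that mdl is a full classification model that supports crossval (here a tree). The 5-fold setting and the names cvmdl and cvLoss are choices made for this sketch. Because the folds are drawn anew, cvLoss is expected to be close to, but not identical to, the validation loss in the display.

```matlab
% Best observed and best estimated validation losses from the optimization
results.MinObjective
results.MinEstimatedObjective

% Number of models evaluated (90 in the run above)
results.NumObjectiveEvaluations

% Class of the returned model (a tree, according to the display above)
class(mdl)

% Re-estimate the validation loss with a fresh 5-fold partition;
% crossval retrains on the training data stored in mdl
rng('default')
cvmdl = crossval(mdl,'KFold',5);
cvLoss = kfoldLoss(cvmdl)
```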
Evaluate the performance of the returned model mdl on the test set adulttest by using a confusion matrix and a receiver operating characteristic (ROC) curve.
Find the predicted labels and score values for the test set.
[labels,scores] = predict(mdl,adulttest);
Create a confusion matrix from the test set results. The diagonal elements indicate the number of correctly classified instances of a given class. The off-diagonal elements are instances of misclassified observations.
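To see how the counts in such a matrix arise, the following self-contained sketch builds a 2-by-2 confusion matrix by hand from five made-up labels (hypothetical values, not taken from adulttest). It uses accumarray with rows for the true class and columns for the predicted class, the same orientation that confusionchart uses. The variable names are illustrative.

```matlab
% Hypothetical toy labels with a fixed class order
classNames = {'<=50K','>50K'};
trueLabels = categorical({'<=50K';'<=50K';'>50K';'>50K';'<=50K'},classNames);
predLabels = categorical({'<=50K';'>50K';'>50K';'<=50K';'<=50K'},classNames);

% double() returns each label's category index (1 or 2);
% rows index the true class and columns the predicted class
C = accumarray([double(trueLabels) double(predLabels)],1,[2 2])   % [2 1; 1 1]

% Correct predictions lie on the diagonal, errors off the diagonal
numCorrect = trace(C)                  % 3
numWrong = sum(C(:)) - numCorrect      % 2
toyAccuracy = 100*numCorrect/sum(C(:)) % 60
```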
confusionchart(adulttest.salary,labels)
Compute the test set classification accuracy. The variable accuracy is the percentage of correctly classified test set observations, computed as one minus the classification error returned by loss.
accuracy = (1-loss(mdl,adulttest,'salary'))*100
accuracy = 85.1513
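As a cross-check, you can compute the accuracy directly by comparing the predicted labels with the true labels, and compare it with a simple baseline that always predicts the most frequent test label. This sketch uses labels from the predict call above; directAccuracy, testCounts, and baselineAccuracy are illustrative names. Both values are unweighted. If the test set contains missing salary values, directAccuracy counts them as errors, so it can differ slightly from the loss-based value.

```matlab
% Percentage of test observations whose predicted label equals the true label
directAccuracy = 100*mean(labels == adulttest.salary)

% Majority-class baseline: always predict the most frequent test label
testCounts = countcats(adulttest.salary);
baselineAccuracy = 100*max(testCounts)/sum(testCounts)
```

A model is only useful if its accuracy is clearly above baselineAccuracy.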
To plot the ROC curve for the score values corresponding to the label '<=50K', find the column of scores that corresponds to that label. The order of the columns in scores matches the order of the classes in the trained model.
mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
Because '<=50K' is listed first, the first column of scores corresponds to that label.
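Rather than relying on the position of '<=50K', you can look up the score column by class name, which keeps the code correct if the class order changes. This is a short sketch; posClass, posCol, and posScores are illustrative names.

```matlab
% Find the column of scores that corresponds to the '<=50K' class
posClass = '<=50K';
posCol = find(mdl.ClassNames == posClass);

% Score values for that class, one per test observation
posScores = scores(:,posCol);
```

You could then pass posScores to perfcurve in place of scores(:,1).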
Plot the ROC curve and compute the area under the curve (AUC). The ROC curve shows the true positive rate versus the false positive rate for different thresholds on the classifier output. For a perfect classifier, the true positive rate is 1 at every false positive rate, so AUC = 1. For a binary classifier that randomly assigns observations to classes, AUC = 0.5. A large AUC value (close to 1) indicates good classifier performance.
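To make the link between the ROC curve and the AUC concrete, the sketch below computes both by hand for eight made-up scores (hypothetical values, not from this data set). It sorts the observations by decreasing score, so that lowering the threshold past each score moves one more observation into the predicted-positive set, and integrates the resulting curve with trapz. It assumes there are no tied scores; perfcurve handles ties and many options that this sketch omits. The last lines check that the AUC also equals the fraction of (positive, negative) pairs in which the positive observation has the higher score.

```matlab
% Hypothetical scores for the positive class and true labels (true = positive)
toyScores = [0.9 0.8 0.7 0.6 0.55 0.4 0.3 0.2];
toyIsPos  = logical([1 1 0 1 0 1 0 0]);

% Sort by decreasing score (assumes no ties)
[~,order] = sort(toyScores,'descend');
y = toyIsPos(order);

% ROC points, starting from (0,0) for a threshold above every score
tpr = [0 cumsum(y)/sum(y)];     % true positive rate
fpr = [0 cumsum(~y)/sum(~y)];   % false positive rate

% Area under the ROC curve by the trapezoidal rule
toyAuc = trapz(fpr,tpr)         % 0.8125

% Same value as the fraction of correctly ranked (positive, negative) pairs
pairWins = toyScores(toyIsPos)' > toyScores(~toyIsPos);   % 4-by-4 logical
pairAuc = mean(pairWins(:))     % 0.8125
```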
[X,Y,~,AUC] = perfcurve(adulttest.salary,scores(:,1),'<=50K');
plot(X,Y)
title('ROC Curve')
xlabel('False Positive Rate')
ylabel('True Positive Rate')
AUC
AUC = 0.8947
Based on the accuracy (85.15%) and AUC (0.8947) values, the classifier performs well on the test data.
See Also: BayesianOptimization | confusionchart | fitcauto | perfcurve