Оптимизируйте повышенный ансамбль регрессии

Этот пример показывает, как оптимизировать гиперпараметры повышенного ансамбля регрессии. Оптимизация минимизирует потерю перекрестной проверки модели.

Проблема состоит в том, чтобы смоделировать эффективность в милях на галлон автомобиля, на основе его ускорения, объема двигателя, лошадиной силы и веса. Загрузите данные carsmall, которые содержат эти и другие предикторы.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;

Соответствуйте ансамблю регрессии к данным с помощью алгоритма LSBoost, и с помощью суррогатных разделений. Оптимизируйте получившуюся модель путем варьирования количества изучения циклов, максимального количества суррогатных разделений и изучить уровня. Кроме того, позвольте оптимизации повторно делить перекрестную проверку между каждой итерацией.

Для воспроизводимости, набор случайный seed и использование функция приобретения 'expected-improvement-plus'.

rng default
Mdl = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'),...
    'OptimizeHyperparameters',{'NumLearningCycles','MaxNumSplits','LearnRate'},...
    'HyperparameterOptimizationOptions',struct('Repartition',true,...
    'AcquisitionFunctionName','expected-improvement-plus'))

|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result |             | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|    1 | Best   |      3.5411 |       11.08 |      3.5411 |      3.5411 |          383 |      0.51519 |            4 |
|    2 | Best   |      3.4755 |     0.58129 |      3.4755 |       3.479 |           16 |      0.66503 |            7 |
|    3 | Best   |      3.1893 |     0.99491 |      3.1893 |      3.1893 |           33 |       0.2556 |           92 |
|    4 | Accept |      6.3077 |     0.41169 |      3.1893 |      3.1898 |           13 |    0.0053227 |            5 |
|    5 | Accept |      3.4482 |      7.2195 |      3.1893 |      3.1897 |          302 |      0.50394 |           99 |
|    6 | Accept |      4.2638 |      0.3518 |      3.1893 |      3.1897 |           10 |      0.11317 |           93 |
|    7 | Accept |      3.2449 |     0.29776 |      3.1893 |      3.1898 |           10 |      0.34912 |           93 |
|    8 | Accept |      3.4495 |     0.36121 |      3.1893 |        3.19 |           14 |      0.99651 |           98 |
|    9 | Accept |      5.8544 |      6.8997 |      3.1893 |      3.1904 |          308 |    0.0010002 |            2 |
|   10 | Accept |      3.1985 |     0.28893 |      3.1893 |      3.1876 |           10 |      0.27825 |           96 |
|   11 | Accept |      3.3339 |       10.76 |      3.1893 |      3.1886 |          447 |      0.28212 |           97 |
|   12 | Best   |      2.9764 |     0.33911 |      2.9764 |      3.1412 |           11 |      0.26217 |           98 |
|   13 | Accept |      3.1958 |     0.28967 |      2.9764 |      3.1537 |           10 |      0.26754 |           12 |
|   14 | Accept |      3.2951 |      12.284 |      2.9764 |      3.1458 |          487 |     0.022491 |           57 |
|   15 | Accept |      5.8041 |     0.29876 |      2.9764 |      3.1653 |           10 |     0.032877 |           11 |
|   16 | Accept |      3.4128 |      12.493 |      2.9764 |      3.1677 |          500 |     0.065337 |           19 |
|   17 | Accept |      3.2357 |     0.63492 |      2.9764 |      3.1653 |           24 |      0.30654 |            7 |
|   18 | Accept |      3.2848 |      3.2373 |      2.9764 |      3.1605 |          129 |      0.22496 |           91 |
|   19 | Accept |      3.1073 |     0.52902 |      2.9764 |      3.1408 |           16 |      0.29063 |           97 |
|   20 | Accept |       6.422 |     0.33719 |      2.9764 |      3.1415 |           10 |    0.0010038 |           76 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result |             | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|   21 | Accept |      3.2146 |     0.53137 |      2.9764 |       3.157 |           18 |      0.27208 |           96 |
|   22 | Accept |      3.0515 |     0.28428 |      2.9764 |      3.1365 |           10 |      0.29884 |           66 |
|   23 | Accept |      3.3721 |      12.212 |      2.9764 |      3.1357 |          500 |    0.0042631 |           84 |
|   24 | Accept |      3.1053 |      12.465 |      2.9764 |       3.136 |          499 |    0.0093964 |           13 |
|   25 | Accept |      3.1303 |      12.322 |      2.9764 |      3.1357 |          499 |    0.0092601 |           73 |
|   26 | Accept |      3.1956 |      11.799 |      2.9764 |      3.1354 |          500 |    0.0074991 |            6 |
|   27 | Accept |      3.2926 |      12.564 |      2.9764 |      3.1366 |          500 |     0.011141 |           69 |
|   28 | Accept |      4.4567 |      1.7648 |      2.9764 |      3.1372 |           74 |     0.015189 |            1 |
|   29 | Accept |      3.4466 |      4.2533 |      2.9764 |      3.1383 |          186 |      0.99992 |            4 |
|   30 | Accept |      6.1348 |      1.6981 |      2.9764 |       3.137 |           68 |    0.0023006 |           12 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 192.9234 seconds.
Total objective function evaluation time: 139.5846

Best observed feasible point:
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           11             0.26217          98     

Observed objective function value = 2.9764
Estimated objective function value = 3.137
Function evaluation time = 0.33911

Best estimated feasible point (according to models):
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           10             0.29884          66     

Estimated objective function value = 3.137
Estimated function evaluation time = 0.31123
Mdl = 
  classreg.learning.regr.RegressionEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
                           NumTrained: 10
                               Method: 'LSBoost'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: [10×1 double]
                   FitInfoDescription: {2×1 cell}
                       Regularization: []


  Properties, Methods

Сравните потерю для той из повышенной, неоптимизированной модели, и тому из ансамбля по умолчанию.

loss = kfoldLoss(crossval(Mdl,'kfold',10))
loss = 23.3445
Mdl2 = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'));
loss2 = kfoldLoss(crossval(Mdl2,'kfold',10))
loss2 = 37.0534
Mdl3 = fitrensemble(X,Y);
loss3 = kfoldLoss(crossval(Mdl3,'kfold',10))
loss3 = 38.4890

Для различного способа оптимизировать этот ансамбль, смотрите, Оптимизируют Ансамбль Регрессии Используя Перекрестную проверку.

Для просмотра документации необходимо авторизоваться на сайте