Оптимизируйте повышенный ансамбль регрессии

В этом примере показано, как оптимизировать гиперпараметры повышенного ансамбля регрессии. Оптимизация минимизирует потерю перекрестной проверки модели.

Проблема состоит в том, чтобы смоделировать КПД в милях на галлон автомобиля, на основе его ускорения, объема двигателя, лошадиной силы и веса. Загрузите carsmall данные, которые содержат эти и другие предикторы.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;

Соответствуйте ансамблю регрессии к данным с помощью LSBoost алгоритм и использование суррогатных разделений. Оптимизируйте получившуюся модель путем варьирования количества изучения циклов, максимального количества суррогатных разделений и изучить уровня. Кроме того, позвольте оптимизации повторно делить перекрестную проверку между каждой итерацией.

Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus' функция приобретения.

rng default
Mdl = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'),...
    'OptimizeHyperparameters',{'NumLearningCycles','MaxNumSplits','LearnRate'},...
    'HyperparameterOptimizationOptions',struct('Repartition',true,...
    'AcquisitionFunctionName','expected-improvement-plus'))
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|    1 | Best   |      3.5891 |      8.8707 |      3.5891 |      3.5891 |          383 |      0.51519 |            4 |
|    2 | Best   |      3.4929 |      0.4762 |      3.4929 |       3.498 |           16 |      0.66503 |            7 |
|    3 | Best   |      3.1712 |     0.97701 |      3.1712 |      3.1713 |           33 |       0.2556 |           92 |
|    4 | Accept |      6.3074 |     0.38797 |      3.1712 |      3.1717 |           13 |    0.0053227 |            5 |
|    5 | Accept |      3.2808 |     0.39457 |      3.1712 |      3.1715 |           13 |      0.53319 |           99 |
|    6 | Best   |       2.974 |     0.32124 |       2.974 |      2.9768 |           10 |      0.30539 |           90 |
|    7 | Accept |      4.6086 |     0.29482 |       2.974 |      2.9757 |           10 |      0.09622 |            2 |
|    8 | Accept |      3.2302 |      0.3267 |       2.974 |      3.1035 |           10 |      0.33326 |           40 |
|    9 | Accept |      3.3755 |      3.1982 |       2.974 |       3.111 |          119 |       0.3092 |           99 |
|   10 | Accept |      2.9805 |     0.28517 |       2.974 |      3.0718 |           10 |      0.31553 |            1 |
|   11 | Accept |      3.0656 |     0.31749 |       2.974 |       3.033 |           10 |      0.31311 |           91 |
|   12 | Best   |      2.9546 |     0.30311 |      2.9546 |      2.9548 |           10 |      0.52213 |            1 |
|   13 | Accept |      3.1134 |     0.31655 |      2.9546 |      3.0514 |           10 |       0.4158 |            1 |
|   14 | Accept |       5.506 |       12.83 |      2.9546 |      3.0524 |          487 |    0.0010022 |           35 |
|   15 | Accept |       3.162 |      13.382 |      2.9546 |      3.0315 |          499 |     0.021297 |           31 |
|   16 | Accept |      5.8944 |     0.30931 |      2.9546 |      3.0558 |           10 |      0.02851 |            1 |
|   17 | Accept |      3.3265 |      13.283 |      2.9546 |      3.0028 |          499 |     0.074578 |           15 |
|   18 | Accept |      3.1752 |      13.219 |      2.9546 |      3.0574 |          494 |      0.04424 |           99 |
|   19 | Accept |      6.4219 |     0.25977 |      2.9546 |      3.0591 |           10 |    0.0010027 |           11 |
|   20 | Accept |       3.358 |      12.422 |      2.9546 |      3.0583 |          498 |    0.0043108 |           95 |
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|   21 | Accept |      2.9861 |      11.547 |      2.9546 |      3.0591 |          499 |    0.0092444 |            2 |
|   22 | Accept |      3.0766 |      11.332 |      2.9546 |      3.0581 |          500 |    0.0098187 |           37 |
|   23 | Accept |      3.0247 |      11.945 |      2.9546 |      3.0568 |          500 |     0.010251 |           81 |
|   24 | Accept |      3.1616 |      11.112 |      2.9546 |       3.057 |          500 |    0.0097239 |            3 |
|   25 | Accept |      4.1776 |      1.9193 |      2.9546 |      3.0575 |           85 |     0.015418 |            7 |
|   26 | Accept |      6.1517 |      1.6863 |      2.9546 |      3.0564 |           79 |    0.0019248 |            1 |
|   27 | Accept |      3.8779 |      3.8071 |      2.9546 |      3.0572 |          181 |      0.99678 |            2 |
|   28 | Accept |      3.2471 |      2.1952 |      2.9546 |      3.0449 |           90 |      0.06971 |            3 |
|   29 | Accept |      3.4336 |      4.3663 |      2.9546 |      3.0444 |          159 |      0.10695 |           27 |
|   30 | Accept |      3.5563 |      11.127 |      2.9546 |      3.0433 |          499 |      0.96204 |            3 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 198.6669 seconds.
Total objective function evaluation time: 153.2122

Best observed feasible point:
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           10             0.52213          1      

Observed objective function value = 2.9546
Estimated objective function value = 3.0975
Function evaluation time = 0.30311

Best estimated feasible point (according to models):
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           10             0.4158           1      

Estimated objective function value = 3.0433
Estimated function evaluation time = 0.29822
Mdl = 
  RegressionEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
                           NumTrained: 10
                               Method: 'LSBoost'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: [10×1 double]
                   FitInfoDescription: {2×1 cell}
                       Regularization: []


  Properties, Methods

Сравните потерю для той из повышенной, неоптимизированной модели, и тому из ансамбля по умолчанию.

loss = kfoldLoss(crossval(Mdl,'kfold',10))
loss = 17.8526
Mdl2 = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'));
loss2 = kfoldLoss(crossval(Mdl2,'kfold',10))
loss2 = 31.0623
Mdl3 = fitrensemble(X,Y);
loss3 = kfoldLoss(crossval(Mdl3,'kfold',10))
loss3 = 34.0239

Для различного способа оптимизировать этот ансамбль, смотрите, Оптимизируют Ансамбль Регрессии Используя Перекрестную проверку.

Для просмотра документации необходимо авторизоваться на сайте