Оптимизируйте повышенный ансамбль регрессии

В этом примере показано, как оптимизировать гиперпараметры повышенного ансамбля регрессии. Оптимизация минимизирует потерю перекрестной проверки модели.

Проблема состоит в том, чтобы смоделировать КПД в милях на галлон автомобиля, на основе его ускорения, объема двигателя, лошадиной силы и веса. Загрузите carsmall данные, которые содержат эти и другие предикторы.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;

Соответствуйте ансамблю регрессии к данным с помощью LSBoost алгоритм и использование суррогатных разделений. Оптимизируйте получившуюся модель путем варьирования количества изучения циклов, максимального количества суррогатных разделений и изучить уровня. Кроме того, позвольте оптимизации повторно делить перекрестную проверку между каждой итерацией.

Для воспроизводимости, набор случайный seed и использование 'expected-improvement-plus' функция приобретения.

rng default
Mdl = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'),...
    'OptimizeHyperparameters',{'NumLearningCycles','MaxNumSplits','LearnRate'},...
    'HyperparameterOptimizationOptions',struct('Repartition',true,...
    'AcquisitionFunctionName','expected-improvement-plus'))

|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|    1 | Best   |      3.5457 |      8.9246 |      3.5457 |      3.5457 |          383 |      0.51519 |            4 |
|    2 | Best   |      3.4903 |     0.44128 |      3.4903 |      3.4933 |           16 |      0.66503 |            7 |
|    3 | Best   |      3.1763 |     0.82734 |      3.1763 |      3.1764 |           33 |       0.2556 |           92 |
|    4 | Accept |      6.3076 |     0.36184 |      3.1763 |      3.1768 |           13 |    0.0053227 |            5 |
|    5 | Accept |      3.4071 |      2.0604 |      3.1763 |      3.1768 |           78 |      0.47249 |           99 |
|    6 | Accept |      3.7443 |     0.28997 |      3.1763 |      3.1774 |           10 |      0.14669 |           86 |
|    7 | Accept |      3.1772 |     0.29445 |      3.1763 |      3.1756 |           10 |      0.29922 |           20 |
|    8 | Best   |      3.1503 |        10.6 |      3.1503 |      3.1677 |          495 |      0.26141 |            1 |
|    9 | Accept |      3.4226 |     0.90836 |      3.1503 |       3.168 |           37 |      0.99969 |           24 |
|   10 | Accept |      3.2646 |     0.79113 |      3.1503 |       3.193 |           31 |      0.27693 |           10 |
|   11 | Accept |      3.3428 |      7.4946 |      3.1503 |      3.2275 |          314 |      0.27312 |           47 |
|   12 | Accept |      5.9344 |      6.1024 |      3.1503 |      3.2289 |          259 |    0.0010001 |           21 |
|   13 | Best   |      3.0379 |      2.9942 |      3.0379 |      3.0497 |          139 |     0.024969 |            1 |
|   14 | Accept |      3.2336 |      2.2539 |      3.0379 |      3.0571 |           97 |     0.037926 |           24 |
|   15 | Accept |      3.1774 |      10.806 |      3.0379 |      3.1059 |          439 |     0.027258 |           20 |
|   16 | Accept |      4.8699 |     0.71132 |      3.0379 |      3.0764 |           31 |     0.026817 |            2 |
|   17 | Accept |      3.2566 |      5.7023 |      3.0379 |      3.0579 |          244 |     0.058858 |            8 |
|   18 | Best   |      3.0314 |      5.4288 |      3.0314 |      3.0191 |          214 |     0.025201 |           44 |
|   19 | Accept |      3.2227 |      12.301 |      3.0314 |      3.0179 |          500 |      0.12391 |           52 |
|   20 | Accept |      3.0635 |      3.8838 |      3.0314 |      3.0321 |          180 |      0.02991 |            2 |
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|   21 | Accept |      3.2057 |      5.7857 |      3.0314 |      3.0363 |          237 |      0.01568 |           97 |
|   22 | Accept |      3.1692 |      1.8666 |      3.0314 |      3.0381 |           76 |      0.14203 |           53 |
|   23 | Accept |      3.1204 |      4.4184 |      3.0314 |      3.0561 |          178 |     0.025816 |           63 |
|   24 | Best   |      2.9564 |      3.9166 |      2.9564 |       3.003 |          185 |     0.028069 |            1 |
|   25 | Best   |      2.8344 |      5.0983 |      2.8344 |      2.9067 |          224 |     0.023927 |            1 |
|   26 | Accept |      2.9268 |      5.7116 |      2.8344 |      2.9136 |          262 |     0.021782 |            1 |
|   27 | Accept |      3.0143 |      5.3932 |      2.8344 |      2.9443 |          249 |     0.022608 |            1 |
|   28 | Accept |      6.4222 |     0.30738 |      2.8344 |      2.9451 |           10 |    0.0010031 |           25 |
|   29 | Accept |       3.513 |      10.871 |      2.8344 |      2.9458 |          499 |    0.0038235 |            1 |
|   30 | Accept |        3.01 |      10.446 |      2.8344 |      2.9461 |          499 |     0.011582 |            1 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 174.0756 seconds.
Total objective function evaluation time: 136.9923

Best observed feasible point:
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           224           0.023927          1      

Observed objective function value = 2.8344
Estimated objective function value = 2.9461
Function evaluation time = 5.0983

Best estimated feasible point (according to models):
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           224           0.023927          1      

Estimated objective function value = 2.9461
Estimated function evaluation time = 4.8629
Mdl = 
  classreg.learning.regr.RegressionEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
                           NumTrained: 224
                               Method: 'LSBoost'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: [224×1 double]
                   FitInfoDescription: {2×1 cell}
                       Regularization: []


  Properties, Methods

Сравните потерю для той из повышенной, неоптимизированной модели, и тому из ансамбля по умолчанию.

loss = kfoldLoss(crossval(Mdl,'kfold',10))
loss = 19.7084
Mdl2 = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'));
loss2 = kfoldLoss(crossval(Mdl2,'kfold',10))
loss2 = 30.9569
Mdl3 = fitrensemble(X,Y);
loss3 = kfoldLoss(crossval(Mdl3,'kfold',10))
loss3 = 38.0049

Для различного способа оптимизировать этот ансамбль, смотрите, Оптимизируют Ансамбль Регрессии Используя Перекрестную проверку.

Для просмотра документации необходимо авторизоваться на сайте