Оптимизируйте повышенный ансамбль регрессии

В этом примере показано, как оптимизировать гиперпараметры повышенного ансамбля регрессии. Оптимизация минимизирует потерю перекрестной проверки модели.

Проблема состоит в том, чтобы смоделировать КПД в милях на галлон автомобиля, на основе его ускорения, объема двигателя, лошадиной силы и веса. Загрузите carsmall данные, которые содержат эти и другие предикторы.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;

Соответствуйте ансамблю регрессии к данным с помощью LSBoost алгоритм и использование суррогатных разделений. Оптимизируйте получившуюся модель путем варьирования количества изучения циклов, максимального количества суррогатных разделений и изучить уровня. Кроме того, позвольте оптимизации повторно делить перекрестную проверку между каждой итерацией.

Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus' функция приобретения.

rng default
Mdl = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'),...
    'OptimizeHyperparameters',{'NumLearningCycles','MaxNumSplits','LearnRate'},...
    'HyperparameterOptimizationOptions',struct('Repartition',true,...
    'AcquisitionFunctionName','expected-improvement-plus'))
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|    1 | Best   |      3.5457 |      8.2128 |      3.5457 |      3.5457 |          383 |      0.51519 |            4 |
|    2 | Best   |      3.4903 |     0.41109 |      3.4903 |      3.4933 |           16 |      0.66503 |            7 |
|    3 | Best   |      3.1763 |     0.78207 |      3.1763 |      3.1764 |           33 |       0.2556 |           92 |
|    4 | Accept |      6.3076 |     0.34868 |      3.1763 |      3.1768 |           13 |    0.0053227 |            5 |
|    5 | Accept |      3.4071 |      1.7753 |      3.1763 |      3.1768 |           78 |      0.47249 |           99 |
|    6 | Accept |      3.7443 |     0.26439 |      3.1763 |      3.1774 |           10 |      0.14669 |           86 |
|    7 | Accept |      3.1772 |     0.28254 |      3.1763 |      3.1756 |           10 |      0.29922 |           20 |
|    8 | Best   |      3.1503 |      9.8543 |      3.1503 |      3.1677 |          495 |      0.26141 |            1 |
|    9 | Accept |      3.4226 |     0.91397 |      3.1503 |       3.168 |           37 |      0.99969 |           24 |
|   10 | Accept |      3.2646 |     0.72349 |      3.1503 |       3.193 |           31 |      0.27693 |           10 |
|   11 | Accept |      3.3428 |      7.0584 |      3.1503 |      3.2275 |          314 |      0.27312 |           47 |
|   12 | Accept |      5.9344 |      5.6579 |      3.1503 |      3.2289 |          259 |    0.0010001 |           21 |
|   13 | Best   |      3.0379 |      2.7762 |      3.0379 |      3.0497 |          139 |     0.024969 |            1 |
|   14 | Accept |      3.2336 |      2.1551 |      3.0379 |      3.0571 |           97 |     0.037926 |           24 |
|   15 | Accept |      3.1774 |      9.9179 |      3.0379 |      3.1059 |          439 |     0.027258 |           20 |
|   16 | Accept |      4.8699 |     0.66609 |      3.0379 |      3.0764 |           31 |     0.026817 |            2 |
|   17 | Accept |      3.2566 |      5.4116 |      3.0379 |      3.0579 |          244 |     0.058858 |            8 |
|   18 | Best   |      3.0314 |      4.8355 |      3.0314 |      3.0191 |          214 |     0.025201 |           44 |
|   19 | Accept |      3.2227 |      11.275 |      3.0314 |      3.0179 |          500 |      0.12391 |           52 |
|   20 | Accept |      3.0635 |      3.6356 |      3.0314 |      3.0321 |          180 |      0.02991 |            2 |
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|   21 | Accept |      3.2057 |      5.2988 |      3.0314 |      3.0363 |          237 |      0.01568 |           97 |
|   22 | Accept |      3.1692 |      1.7442 |      3.0314 |      3.0381 |           76 |      0.14203 |           53 |
|   23 | Accept |      3.1204 |      4.0391 |      3.0314 |      3.0561 |          178 |     0.025816 |           63 |
|   24 | Best   |      2.9564 |      3.6824 |      2.9564 |       3.003 |          185 |     0.028069 |            1 |
|   25 | Best   |      2.8344 |       4.448 |      2.8344 |      2.9067 |          224 |     0.023927 |            1 |
|   26 | Accept |      2.9268 |      5.2045 |      2.8344 |      2.9136 |          262 |     0.021782 |            1 |
|   27 | Accept |      3.0143 |      4.8904 |      2.8344 |      2.9443 |          249 |     0.022608 |            1 |
|   28 | Accept |      6.4222 |     0.26882 |      2.8344 |      2.9451 |           10 |    0.0010031 |           25 |
|   29 | Accept |       3.513 |      9.9166 |      2.8344 |      2.9458 |          499 |    0.0038235 |            1 |
|   30 | Accept |        3.01 |       9.915 |      2.8344 |      2.9461 |          499 |     0.011582 |            1 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 159.6818 seconds.
Total objective function evaluation time: 126.3662

Best observed feasible point:
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           224           0.023927          1      

Observed objective function value = 2.8344
Estimated objective function value = 2.9461
Function evaluation time = 4.448

Best estimated feasible point (according to models):
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           224           0.023927          1      

Estimated objective function value = 2.9461
Estimated function evaluation time = 4.4346
Mdl = 
  classreg.learning.regr.RegressionEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
                           NumTrained: 224
                               Method: 'LSBoost'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: [224×1 double]
                   FitInfoDescription: {2×1 cell}
                       Regularization: []


  Properties, Methods

Сравните потерю для той из повышенной, неоптимизированной модели, и тому из ансамбля по умолчанию.

loss = kfoldLoss(crossval(Mdl,'kfold',10))
loss = 19.7084
Mdl2 = fitrensemble(X,Y,...
    'Method','LSBoost',...
    'Learner',templateTree('Surrogate','on'));
loss2 = kfoldLoss(crossval(Mdl2,'kfold',10))
loss2 = 30.9569
Mdl3 = fitrensemble(X,Y);
loss3 = kfoldLoss(crossval(Mdl3,'kfold',10))
loss3 = 38.0049

Для различного способа оптимизировать этот ансамбль, смотрите, Оптимизируют Ансамбль Регрессии Используя Перекрестную проверку.

Для просмотра документации необходимо авторизоваться на сайте