Оптимизируйте повышенный ансамбль регрессии

В этом примере показано, как оптимизировать гиперпараметры повышенного ансамбля регрессии. Оптимизация минимизирует потерю перекрестной проверки модели.

Проблема состоит в том, чтобы смоделировать КПД в милях на галлон автомобиля, на основе его ускорения, объема двигателя, лошадиной силы и веса. Загрузите carsmall данные, которые содержат эти и другие предикторы.

load carsmall
X = [Acceleration Displacement Horsepower Weight];
Y = MPG;

Соответствуйте ансамблю регрессии к данным с помощью LSBoost алгоритм и использование суррогатных разделений. Оптимизируйте получившуюся модель путем варьирования количества изучения циклов, максимального количества суррогатных разделений и изучить уровня. Кроме того, позвольте оптимизации повторно делить перекрестную проверку между каждой итерацией.

Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus' функция приобретения.

rng('default')
Mdl = fitrensemble(X,Y, ...
    'Method','LSBoost', ...
    'Learner',templateTree('Surrogate','on'), ...
    'OptimizeHyperparameters',{'NumLearningCycles','MaxNumSplits','LearnRate'}, ...
    'HyperparameterOptimizationOptions',struct('Repartition',true, ...
    'AcquisitionFunctionName','expected-improvement-plus'))
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|    1 | Best   |      3.5219 |      12.762 |      3.5219 |      3.5219 |          383 |      0.51519 |            4 |
|    2 | Best   |      3.4752 |      1.0824 |      3.4752 |      3.4777 |           16 |      0.66503 |            7 |
|    3 | Best   |      3.1575 |      1.5703 |      3.1575 |      3.1575 |           33 |       0.2556 |           92 |
|    4 | Accept |      6.3076 |     0.65033 |      3.1575 |      3.1579 |           13 |    0.0053227 |            5 |
|    5 | Accept |      3.4449 |      9.5959 |      3.1575 |      3.1579 |          277 |      0.45891 |           99 |
|    6 | Accept |      3.9806 |     0.44253 |      3.1575 |      3.1584 |           10 |      0.13017 |           33 |
|    7 | Best   |       3.059 |     0.48004 |       3.059 |        3.06 |           10 |      0.30126 |            3 |
|    8 | Accept |      3.1707 |     0.52678 |       3.059 |      3.1144 |           10 |      0.28991 |           15 |
|    9 | Accept |      3.0937 |     0.51983 |       3.059 |      3.1046 |           10 |      0.31488 |           13 |
|   10 | Accept |       3.196 |     0.50347 |       3.059 |      3.1233 |           10 |      0.32005 |           11 |
|   11 | Best   |      3.0495 |     0.40737 |      3.0495 |      3.1083 |           10 |      0.27882 |           85 |
|   12 | Best   |       2.946 |     0.39281 |       2.946 |      3.0774 |           10 |      0.27157 |            7 |
|   13 | Accept |      3.2026 |     0.51172 |       2.946 |      3.0995 |           10 |      0.25734 |           20 |
|   14 | Accept |       5.595 |      15.127 |       2.946 |      3.0996 |          440 |    0.0010008 |           36 |
|   15 | Accept |      3.1976 |      16.499 |       2.946 |      3.0935 |          496 |     0.027133 |           18 |
|   16 | Accept |      3.9809 |      1.2163 |       2.946 |      3.0927 |           34 |     0.041016 |           18 |
|   17 | Accept |      3.0512 |      13.382 |       2.946 |      3.0939 |          428 |     0.019766 |            3 |
|   18 | Accept |      3.4832 |      7.2531 |       2.946 |      3.0946 |          205 |      0.99989 |            8 |
|   19 | Accept |      3.3389 |      3.2927 |       2.946 |      3.0956 |           95 |     0.021453 |            2 |
|   20 | Accept |      3.2818 |      17.258 |       2.946 |      3.0979 |          494 |     0.020773 |           12 |
|====================================================================================================================|
| Iter | Eval   | Objective:  | Objective   | BestSoFar   | BestSoFar   | NumLearningC-|    LearnRate | MaxNumSplits |
|      | result | log(1+loss) | runtime     | (observed)  | (estim.)    | ycles        |              |              |
|====================================================================================================================|
|   21 | Accept |      3.4367 |      16.049 |       2.946 |      3.0962 |          480 |      0.27412 |            7 |
|   22 | Accept |      6.2247 |      0.5334 |       2.946 |      3.0995 |           10 |     0.010965 |           15 |
|   23 | Accept |      3.2847 |      6.7898 |       2.946 |      3.0991 |          181 |     0.057422 |           22 |
|   24 | Accept |       3.142 |      8.8034 |       2.946 |      3.0997 |          222 |     0.025594 |           25 |
|   25 | Accept |      3.2174 |     0.88584 |       2.946 |       3.106 |           18 |      0.32203 |           37 |
|   26 | Accept |       3.064 |      4.7416 |       2.946 |      3.1057 |          108 |      0.18554 |            1 |
|   27 | Accept |      3.4532 |      3.0553 |       2.946 |      3.1038 |           93 |      0.22441 |            3 |
|   28 | Accept |      3.1992 |       9.249 |       2.946 |      3.1038 |          252 |     0.020628 |            3 |
|   29 | Best   |      2.9432 |       0.475 |      2.9432 |      3.0766 |           10 |      0.36141 |           86 |
|   30 | Best   |       2.891 |     0.54495 |       2.891 |           3 |           10 |      0.38339 |            2 |

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 208.1748 seconds
Total objective function evaluation time: 154.602

Best observed feasible point:
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           10             0.38339          2      

Observed objective function value = 2.891
Estimated objective function value = 2.9674
Function evaluation time = 0.54495

Best estimated feasible point (according to models):
    NumLearningCycles    LearnRate    MaxNumSplits
    _________________    _________    ____________

           10             0.30126          3      

Estimated objective function value = 3
Estimated function evaluation time = 0.49592
Mdl = 
  RegressionEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                           NumTrained: 10
                               Method: 'LSBoost'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: [10x1 double]
                   FitInfoDescription: {2x1 cell}
                       Regularization: []


  Properties, Methods

Сравните потерю для той из повышенной, неоптимизированной модели, и тому из ансамбля по умолчанию.

loss = kfoldLoss(crossval(Mdl,'kfold',10))
loss = 20.6082
Mdl2 = fitrensemble(X,Y, ...
    'Method','LSBoost', ...
    'Learner',templateTree('Surrogate','on'));
loss2 = kfoldLoss(crossval(Mdl2,'kfold',10))
loss2 = 36.4539
Mdl3 = fitrensemble(X,Y);
loss3 = kfoldLoss(crossval(Mdl3,'kfold',10))
loss3 = 36.6756

Для различного способа оптимизировать этот ансамбль, смотрите, Оптимизируют Ансамбль Регрессии Используя Перекрестную проверку.

Для просмотра документации необходимо авторизоваться на сайте