Обучите обобщенную аддитивную модель регрессии

В этом примере показано, как обучить обобщенную аддитивную модель (GAM) оптимальными параметрами и как оценить прогнозирующую эффективность обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов) и затем находит значения для двумерного GAM (параметры в течение многих периодов взаимодействия). Кроме того, пример объясняет, как интерпретировать обученную модель путем исследования локальных эффектов условий на определенном предсказании и путем вычисления частичной зависимости предсказаний на предикторах.

Загрузка демонстрационных данных

Загрузите набор выборочных данных NYCHousing2015.

load NYCHousing2015

Набор данных включает 10 переменных с информацией о продажах свойств в Нью-Йорке в 2 015. Этот пример использует эти переменные, чтобы анализировать отпускные цены (SALEPRICE).

Предварительно обработайте набор данных. Примите что SALEPRICE меньше чем или равный 1 000$ указывает на передачу владения без суммы. Удалите выборки, которые имеют этот SALEPRICE. Кроме того, удалите выбросы, идентифицированные isoutlier функция. Затем преобразуйте datetime массив (SALEDATE) к числам месяца и перемещению переменная отклика (SALEPRICE) к последнему столбцу. Измените нули в LANDSQUAREFEET, GROSSSQUAREFEET, и YEARBUILT к NaNs.

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

Отобразите первые три строки таблицы.

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

Случайным образом выберите 1 000 выборок при помощи datasample функция и наблюдения раздела в набор обучающих данных и набор тестов при помощи cvpartition функция. Задайте 10%-ю выборку затяжки для тестирования.

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

Извлеките обучение и протестируйте индексы и составьте таблицы для наборов тестовых данных и обучения.

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

Найдите оптимальные параметры для одномерного GAM

Оптимизируйте параметры для одномерного GAM относительно перекрестной проверки при помощи bayesopt функция.

Подготовьте optimizableVariable объекты для аргументов значения имени одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создайте целевую функцию, которая берет вход z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает перекрестное подтвержденное значение потерь в параметрах в z.

minfun1 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если вы задаете опцию перекрестной проверки 'CrossVal','on', затем fitrgam функция возвращает перекрестный подтвержденный объект модели RegressionPartitionedGAM. kfoldLoss функция возвращает потерю регрессии (среднеквадратическая ошибка), полученная перекрестной подтвержденной моделью. Поэтому указатель на функцию minfun1 вычисляет потерю перекрестной проверки в параметрах в z.

Ищите лучшие параметры с помощью bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция приобретения по умолчанию зависит от времени выполнения и, поэтому, может дать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |  8.4558e+10 |      1.5106 |  8.4558e+10 |  8.4558e+10 |      0.36695 |            2 |           30 |
|    2 | Accept |  8.6891e+10 |       12.01 |  8.4558e+10 |  8.4558e+10 |     0.008213 |            5 |          271 |
|    3 | Accept |  9.6521e+10 |      1.9121 |  8.4558e+10 |  8.4558e+10 |      0.22984 |            9 |           37 |
|    4 | Accept |  1.3402e+11 |      14.388 |  8.4558e+10 |  8.4558e+10 |      0.99932 |            3 |          344 |
|    5 | Accept |  8.7852e+10 |      13.595 |  8.4558e+10 |  8.4558e+10 |      0.16575 |            1 |          456 |
|    6 | Accept |  9.3041e+10 |      11.002 |  8.4558e+10 |  8.4558e+10 |      0.49477 |            1 |          360 |
|    7 | Accept |  1.0558e+11 |      7.7647 |  8.4558e+10 |  8.4558e+10 |      0.24562 |            4 |          175 |
|    8 | Accept |  8.8841e+10 |      1.5763 |  8.4558e+10 |  8.4558e+10 |      0.39298 |            2 |           41 |
|    9 | Accept |  9.9227e+10 |      14.377 |  8.4558e+10 |  8.4558e+10 |     0.091879 |            3 |          358 |
|   10 | Accept |  9.8611e+10 |     0.14914 |  8.4558e+10 |  8.4558e+10 |      0.22487 |            2 |            2 |
|   11 | Accept |  1.2998e+11 |      23.962 |  8.4558e+10 |  8.4558e+10 |      0.25341 |            5 |          500 |
|   12 | Accept |  8.8968e+10 |      5.0028 |  8.4558e+10 |  8.4558e+10 |      0.33109 |            1 |          175 |
|   13 | Accept |  1.2018e+11 |      1.8004 |  8.4558e+10 |  8.4558e+10 |    0.0030413 |            6 |           40 |
|   14 | Accept |  8.7503e+10 |     0.79283 |  8.4558e+10 |  8.4558e+10 |      0.33877 |            1 |           25 |
|   15 | Accept |  9.3798e+10 |      2.9578 |  8.4558e+10 |  8.4558e+10 |      0.32926 |            2 |           80 |
|   16 | Accept |  9.5165e+10 |      8.0635 |  8.4558e+10 |  8.4558e+10 |      0.33878 |            1 |          282 |
|   17 | Best   |  8.3549e+10 |     0.24446 |  8.3549e+10 |  8.3549e+10 |       0.3552 |            2 |            5 |
|   18 | Best   |  8.3104e+10 |      1.4534 |  8.3104e+10 |  8.3104e+10 |       0.2526 |            1 |           49 |
|   19 | Accept |  8.6938e+10 |      3.3234 |  8.3104e+10 |  8.3104e+10 |      0.18293 |            1 |          110 |
|   20 | Accept |  8.7531e+10 |      2.8096 |  8.3104e+10 |  8.3104e+10 |       0.2781 |            1 |           93 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |  9.1613e+10 |      13.347 |  8.3104e+10 |  8.3104e+10 |      0.31722 |            1 |          464 |
|   22 | Accept |   8.678e+10 |      10.358 |  8.3104e+10 |  8.3104e+10 |      0.11269 |            1 |          358 |
|   23 | Accept |  8.3614e+10 |     0.47001 |  8.3104e+10 |  8.3104e+10 |      0.22278 |            1 |           14 |
|   24 | Accept |  1.3203e+11 |       1.069 |  8.3104e+10 |  8.3104e+10 |    0.0021552 |            5 |           23 |
|   25 | Accept |    8.66e+10 |       7.233 |  8.3104e+10 |  8.3104e+10 |      0.11469 |            1 |          236 |
|   26 | Accept |  8.4535e+10 |      8.7657 |  8.3104e+10 |  8.3104e+10 |    0.0090628 |            1 |          292 |
|   27 | Accept |  1.0315e+11 |      12.297 |  8.3104e+10 |  8.3104e+10 |    0.0014094 |            1 |          413 |
|   28 | Accept |  9.6736e+10 |      5.8323 |  8.3104e+10 |  8.3104e+10 |    0.0040429 |            1 |          202 |
|   29 | Accept |  8.3651e+10 |      8.4999 |  8.3104e+10 |  8.3104e+10 |      0.09375 |            1 |          295 |
|   30 | Accept |  8.7977e+10 |      13.521 |  8.3104e+10 |  8.3104e+10 |     0.016448 |            6 |          292 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 245.1541 seconds
Total objective function evaluation time: 210.0881

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Observed objective function value = 83103839919.908
Estimated objective function value = 83103840296.3186
Function evaluation time = 1.4534

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Estimated objective function value = 83103840296.3186
Estimated function evaluation time = 1.803

Получите лучшую точку из results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Обучите одномерный GAM оптимальными параметрами

Обучите оптимизированный GAM с помощью zbest1 значения.

Mdl1 = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9806e+05
          NumObservations: 900


  Properties, Methods

Mdl1 RegressionGAM объект модели. Отображение модели показывает частичный список свойств модели. Чтобы просмотреть полный список свойств, дважды кликните имя переменной Mdl1 в Рабочей области. Редактор Переменных открывается для Mdl1. В качестве альтернативы можно отобразить свойства в Командном окне при помощи записи через точку. Например, отобразите ReasonForTermination свойство.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

PredictorTrees поле значения свойства указывает на тот Mdl1 включает конкретное количество деревьев. NumTreesPerPredictor из fitrgam задает максимальное количество деревьев на предиктор, и функция может остановиться перед обучением требуемое количество деревьев. Можно использовать ReasonForTermination свойство определить, содержит ли обученная модель конкретное количество деревьев.

Если вы задаете, чтобы включать периоды взаимодействия так, чтобы fitrgam обучает деревья им. InteractionTrees поле содержит непустое значение.

Найдите оптимальные параметры для двумерного GAM

Найдите параметры в течение многих периодов взаимодействия двумерного GAM при помощи bayesopt функция.

Подготовьте optimizableVariable для аргументов значения имени в течение периодов взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Используйте оптимальные значения параметров в zbest1 так, чтобы программное обеспечение нашло оптимальные значения параметров в течение многих периодов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Ищите лучшие параметры с помощью bayesopt. Процесс оптимизации обучает многоуровневые модели и отображает предупреждающие сообщения, если модели не включают периодов взаимодействия. Отключите все предупреждения прежде, чем вызвать bayesopt и восстановите состояние предупреждения после выполнения bayesopt. Можно оставить состояние предупреждения без изменений, чтобы просмотреть предупреждающие сообщения.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |  8.4721e+10 |      1.6996 |  8.4721e+10 |  8.4721e+10 |      0.41774 |            1 |          346 |           28 |
|    2 | Accept |  9.1765e+10 |      8.3313 |  8.4721e+10 |  8.4721e+10 |       0.9565 |            3 |          231 |           14 |
|    3 | Accept |  9.2116e+10 |      2.8341 |  8.4721e+10 |  8.4721e+10 |      0.33578 |            9 |           45 |            5 |
|    4 | Accept |   1.784e+11 |      76.237 |  8.4721e+10 |  8.4721e+10 |      0.91186 |           10 |          479 |           27 |
|    5 | Accept |  8.4906e+10 |      1.8275 |  8.4721e+10 |  8.4721e+10 |        0.296 |            4 |            1 |           27 |
|    6 | Best   |  8.4172e+10 |        1.73 |  8.4172e+10 |  8.4172e+10 |      0.68133 |            1 |           86 |            1 |
|    7 | Best   |   8.234e+10 |      1.7164 |   8.234e+10 |   8.234e+10 |      0.13943 |            1 |          228 |           26 |
|    8 | Accept |  8.3488e+10 |      1.6382 |   8.234e+10 |   8.234e+10 |      0.46764 |            1 |            1 |            5 |
|    9 | Accept |  8.7977e+10 |      1.5655 |   8.234e+10 |   8.234e+10 |       0.8385 |           10 |            1 |            5 |
|   10 | Accept |  8.4431e+10 |      1.5744 |   8.234e+10 |   8.234e+10 |      0.95535 |            1 |          261 |            4 |
|   11 | Accept |  8.5784e+10 |      1.7478 |   8.234e+10 |   8.234e+10 |     0.023058 |            7 |            1 |           14 |
|   12 | Accept |  8.6068e+10 |      1.7304 |   8.234e+10 |   8.234e+10 |      0.77118 |            1 |            5 |           28 |
|   13 | Accept |  8.7004e+10 |      1.5903 |   8.234e+10 |   8.234e+10 |     0.016991 |            1 |          263 |            2 |
|   14 | Accept |  8.3325e+10 |      1.5895 |   8.234e+10 |   8.234e+10 |       0.9468 |            4 |            7 |            1 |
|   15 | Accept |  8.4097e+10 |      1.6357 |   8.234e+10 |   8.234e+10 |      0.97988 |            1 |          250 |           28 |
|   16 | Accept |  8.3106e+10 |      1.6081 |   8.234e+10 |   8.234e+10 |     0.024052 |            1 |          121 |           28 |
|   17 | Accept |   8.469e+10 |      1.6235 |   8.234e+10 |   8.234e+10 |     0.047902 |            3 |            3 |           12 |
|   18 | Best   |  8.1641e+10 |      1.5833 |  8.1641e+10 |  8.1641e+10 |      0.99848 |            6 |            1 |            3 |
|   19 | Accept |  8.5957e+10 |      1.6305 |  8.1641e+10 |  8.1641e+10 |      0.99826 |            6 |            1 |           13 |
|   20 | Accept |  8.2486e+10 |      1.6515 |  8.1641e+10 |  8.1641e+10 |      0.36059 |            7 |            2 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |  8.6534e+10 |       1.647 |  8.1641e+10 |  8.1641e+10 |    0.0089186 |            1 |          192 |           18 |
|   22 | Accept |  8.5425e+10 |      1.5316 |  8.1641e+10 |  8.1641e+10 |      0.99842 |            1 |          497 |            1 |
|   23 | Accept |   8.515e+10 |      1.5728 |  8.1641e+10 |  8.1641e+10 |      0.99934 |            1 |            3 |            2 |
|   24 | Accept |   8.593e+10 |      1.6086 |  8.1641e+10 |  8.1641e+10 |    0.0099052 |            1 |            2 |           28 |
|   25 | Accept |  8.7394e+10 |       1.577 |  8.1641e+10 |  8.1641e+10 |      0.96502 |            7 |            5 |            2 |
|   26 | Accept |   8.618e+10 |      1.5714 |  8.1641e+10 |  8.1641e+10 |     0.097871 |            5 |            3 |            1 |
|   27 | Accept |  8.5704e+10 |       1.665 |  8.1641e+10 |  8.1641e+10 |     0.056356 |           10 |            6 |            3 |
|   28 | Accept |  9.5451e+10 |      2.8821 |  8.1641e+10 |  8.1641e+10 |      0.91844 |            3 |           12 |           28 |
|   29 | Accept |  8.4013e+10 |      1.5633 |  8.1641e+10 |  8.1641e+10 |      0.68016 |            6 |            1 |            1 |
|   30 | Accept |  8.3928e+10 |      1.7715 |  8.1641e+10 |  8.1641e+10 |      0.07259 |            5 |            5 |           14 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 155.1459 seconds
Total objective function evaluation time: 132.9347

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Observed objective function value = 81640836929.8637
Estimated objective function value = 81640841484.6238
Function evaluation time = 1.5833

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Estimated objective function value = 81640841484.6238
Estimated function evaluation time = 1.5784
warning(orig_state)

Получите лучшую точку из results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Обучите двумерный GAM оптимальными параметрами

Обучите оптимизированный GAM с помощью zbest1 и zbest2 значения.

Mdl = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9741e+05
             Interactions: [3×2 double]
          NumObservations: 900


  Properties, Methods

В качестве альтернативы можно добавить периоды взаимодействия в одномерный GAM при помощи addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной параметр задает максимальное количество периодов взаимодействия и NumTreesPerInteraction аргумент значения имени задает максимальное количество деревьев в период взаимодействия. addInteractions функция может включать меньше периодов взаимодействия и остановки перед обучением требуемое количество деревьев. Можно использовать Interactions и ReasonForTermination свойства проверять фактический номер периодов взаимодействия и количество деревьев в обученной модели.

Отобразите периоды взаимодействия в Mdl.

Mdl.Interactions
ans = 3×2

     3     6
     4     6
     6     8

Каждая строка Interactions представляет один период взаимодействия и содержит индексы столбца переменных предикторов в течение периода взаимодействия. Можно использовать Interactions свойство проверять периоды взаимодействия в модель и порядок, в который fitrgam добавляет их в модель.

Отобразите периоды взаимодействия в Mdl использование имен предиктора.

Mdl.PredictorNames(Mdl.Interactions)
ans = 3×2 cell
    {'BUILDINGCLASSCATEGORY'}    {'LANDSQUAREFEET'}
    {'RESIDENTIALUNITS'     }    {'LANDSQUAREFEET'}
    {'LANDSQUAREFEET'       }    {'YEARBUILT'     }

Отобразите причину завершения, чтобы определить, содержит ли модель конкретное количество деревьев для каждого линейного члена и каждый период взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оцените прогнозирующую эффективность на новых наблюдениях

Оцените эффективность обученной модели при помощи тестовой выборки tbl_test и объект функционирует predict и loss. Можно использовать полную или компактную модель с этими функциями.

  • predict — Предскажите ответы

  • loss — Вычислите потерю регрессии (среднеквадратическая ошибка, по умолчанию)

Если вы хотите оценить эффективность обучающего набора данных, используйте функции объекта перезамены: resubPredict и resubLoss. Чтобы использовать эти функции, необходимо использовать полную модель, которая содержит обучающие данные.

Создайте компактную модель, чтобы уменьшать размер обученной модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size             Bytes  Class                                          Attributes

  CMdl      1x1             370211  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             528102  RegressionGAM                                            

Предскажите ответы и вычислите среднеквадратические ошибки для набора тестовых данных tbl_test.

yFit = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2855e+11

Найдите предсказанные ответы и ошибки без включения периодов взаимодействия в обученной модели.

yFit_nointeraction = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.3031e+11

Модель достигает меньшей ошибки для набора тестовых данных, когда и линейные члены и периоды взаимодействия включены.

Сравните результаты, полученные включением и линейного к периодам взаимодействия и результатам, полученным включением только линейных членов. Составьте таблицу, содержащую наблюдаемые ответы и предсказанные ответы. Отобразите первые восемь строк таблицы.

t = table(tbl_test.SALEPRICE,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Value','Predicted Response','Predicted Response Without Interactions'});
head(t)
ans=8×3 table
    Observed Value    Predicted Response    Predicted Response Without Interactions
    ______________    __________________    _______________________________________

         3.6e+05          4.9812e+05                      5.2712e+05               
         1.8e+05          2.7349e+05                      2.7415e+05               
         1.9e+05          3.3682e+05                      3.3748e+05               
        4.26e+05            6.15e+05                      5.6542e+05               
        3.91e+05          3.1262e+05                      3.1328e+05               
         2.3e+05          1.0606e+05                      1.0672e+05               
      4.7333e+05          1.0773e+06                      1.1399e+06               
           2e+05          2.9506e+05                       3.305e+05               

Интерпретируйте предсказание

Интерпретируйте предсказание для первого тестового наблюдения при помощи plotLocalEffects функция. Кроме того, создайте частичные графики зависимости для некоторых важных членов в модели при помощи plotPartialDependence функция.

Предскажите значение отклика для первого наблюдения за тестовыми данными и постройте локальные эффекты условий в CMdl на предсказании. Задайте 'IncludeIntercept',true включать термин точки пересечения в графике.

yFit = predict(CMdl,tbl_test(1,:))
yFit = 4.9812e+05
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

predict функция возвращает отпускную цену за первое наблюдение tbl_test(1,:). plotLocalEffects функция создает горизонтальный столбчатый график, который показывает локальные эффекты условий в CMdl на предсказании. Каждое локальное значение эффекта показывает вклад каждого термина к предсказанной отпускной цене за tbl_test(1,:).

Вычислите частичные значения зависимости для BUILDINGCLASSCATEGORY и постройте отсортированные значения. Задайте и обучение и наборы тестовых данных, чтобы вычислить частичные значения зависимости с помощью обоих наборов.

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Patial Dependence Plot')

Построенная линия представляет усредненные частичные отношения между предиктором BUILDINGCLASSCATEGORY и ответ SALEPRICE в обученной модели.

Создайте частичный график зависимости для условий RESIDENTIALUNITS и LANDSQUAREFEET.

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],[tbl_training; tbl_test])

Незначительные метки деления в оси X (RESIDENTIALUNITS) и ось Y (LANDSQUAREFEET) представляйте уникальные значения предикторов в заданных данных. Значения предиктора включают несколько выбросов и большую часть RESIDENTIALUNITS и LANDSQUAREFEET значения меньше 10 и 50,000, соответственно. График показывает что SALEPRICE значения значительно не варьируются когда RESIDENTIALUNITS и LANDSQUAREFEET значения больше 10 и 50,000.

Смотрите также

| | | | | |

Похожие темы