Обучите обобщенную аддитивную модель для регрессии

Этот пример показывает, как обучить обобщенную аддитивную модель (GAM) с оптимальными параметрами и как оценить прогнозирующую эффективность обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов), а затем находит значения для двухмерного GAM (параметры для членов взаимодействия). Кроме того, пример объясняет, как интерпретировать обученную модель, исследуя локальные эффекты членов на конкретном предсказании и вычисляя частичную зависимость предсказаний от предикторов.

Загрузка выборочных данных

Загрузите набор выборочных данных NYCHousing2015.

load NYCHousing2015

Набор данных включает 10 переменных с информацией о продажах недвижимости в Нью-Йорке в 2015 году. Этот пример использует эти переменные для анализа продажных цен (SALEPRICE).

Предварительно обработайте набор данных. Предположим, что SALEPRICE менее чем или равным 1000 $ означает передачу права собственности без денежного фактора. Удалите выборки, которые имеют это SALEPRICE. Также удалите выбросы, идентифицированные isoutlier функция. Затем преобразуйте datetime массив (SALEDATE) на номера месяцев и переместить переменную отклика (SALEPRICE) до последнего столбца. Измените нули в LANDSQUAREFEET, GROSSSQUAREFEET, и YEARBUILT на NaNс.

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

Отобразите первые три строки таблицы.

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

Случайным образом выберите 1000 выборок при помощи datasample функция и наблюдения разбиения на набор обучающих данных и тестовый набор при помощи cvpartition функция. Укажите 10% -ная выборка удержания для проверки.

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

Извлеките индексы обучения и тестирования и создайте таблицы для наборов обучающих и тестовых данных.

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

Нахождение оптимальных параметров для одномерной GAM

Оптимизируйте параметры для одномерной GAM относительно перекрестной валидации с помощью bayesopt функция.

Подготовка optimizableVariable объекты для аргументов имя-значение одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создайте целевую функцию, которая принимает вход z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает перекрестно проверенное значение потерь в параметрах в z.

minfun1 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если вы задаете опцию перекрестной проверки 'CrossVal','on', затем fitrgam функция возвращает перекрестно проверенный объект модели RegressionPartitionedGAM. The kfoldLoss функция возвращает регрессионные потери (средняя квадратичная невязка), полученные перекрестной проверенной моделью. Поэтому указатель на функцию minfun1 вычисляет потери перекрестной валидации у параметров в z.

Поиск оптимальных параметров с помощью bayesopt. Для повторяемости выберите 'expected-improvement-plus' функция сбора. Функция сбора по умолчанию зависит от времени выполнения и, следовательно, может давать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |  8.4558e+10 |      1.5106 |  8.4558e+10 |  8.4558e+10 |      0.36695 |            2 |           30 |
|    2 | Accept |  8.6891e+10 |       12.01 |  8.4558e+10 |  8.4558e+10 |     0.008213 |            5 |          271 |
|    3 | Accept |  9.6521e+10 |      1.9121 |  8.4558e+10 |  8.4558e+10 |      0.22984 |            9 |           37 |
|    4 | Accept |  1.3402e+11 |      14.388 |  8.4558e+10 |  8.4558e+10 |      0.99932 |            3 |          344 |
|    5 | Accept |  8.7852e+10 |      13.595 |  8.4558e+10 |  8.4558e+10 |      0.16575 |            1 |          456 |
|    6 | Accept |  9.3041e+10 |      11.002 |  8.4558e+10 |  8.4558e+10 |      0.49477 |            1 |          360 |
|    7 | Accept |  1.0558e+11 |      7.7647 |  8.4558e+10 |  8.4558e+10 |      0.24562 |            4 |          175 |
|    8 | Accept |  8.8841e+10 |      1.5763 |  8.4558e+10 |  8.4558e+10 |      0.39298 |            2 |           41 |
|    9 | Accept |  9.9227e+10 |      14.377 |  8.4558e+10 |  8.4558e+10 |     0.091879 |            3 |          358 |
|   10 | Accept |  9.8611e+10 |     0.14914 |  8.4558e+10 |  8.4558e+10 |      0.22487 |            2 |            2 |
|   11 | Accept |  1.2998e+11 |      23.962 |  8.4558e+10 |  8.4558e+10 |      0.25341 |            5 |          500 |
|   12 | Accept |  8.8968e+10 |      5.0028 |  8.4558e+10 |  8.4558e+10 |      0.33109 |            1 |          175 |
|   13 | Accept |  1.2018e+11 |      1.8004 |  8.4558e+10 |  8.4558e+10 |    0.0030413 |            6 |           40 |
|   14 | Accept |  8.7503e+10 |     0.79283 |  8.4558e+10 |  8.4558e+10 |      0.33877 |            1 |           25 |
|   15 | Accept |  9.3798e+10 |      2.9578 |  8.4558e+10 |  8.4558e+10 |      0.32926 |            2 |           80 |
|   16 | Accept |  9.5165e+10 |      8.0635 |  8.4558e+10 |  8.4558e+10 |      0.33878 |            1 |          282 |
|   17 | Best   |  8.3549e+10 |     0.24446 |  8.3549e+10 |  8.3549e+10 |       0.3552 |            2 |            5 |
|   18 | Best   |  8.3104e+10 |      1.4534 |  8.3104e+10 |  8.3104e+10 |       0.2526 |            1 |           49 |
|   19 | Accept |  8.6938e+10 |      3.3234 |  8.3104e+10 |  8.3104e+10 |      0.18293 |            1 |          110 |
|   20 | Accept |  8.7531e+10 |      2.8096 |  8.3104e+10 |  8.3104e+10 |       0.2781 |            1 |           93 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |  9.1613e+10 |      13.347 |  8.3104e+10 |  8.3104e+10 |      0.31722 |            1 |          464 |
|   22 | Accept |   8.678e+10 |      10.358 |  8.3104e+10 |  8.3104e+10 |      0.11269 |            1 |          358 |
|   23 | Accept |  8.3614e+10 |     0.47001 |  8.3104e+10 |  8.3104e+10 |      0.22278 |            1 |           14 |
|   24 | Accept |  1.3203e+11 |       1.069 |  8.3104e+10 |  8.3104e+10 |    0.0021552 |            5 |           23 |
|   25 | Accept |    8.66e+10 |       7.233 |  8.3104e+10 |  8.3104e+10 |      0.11469 |            1 |          236 |
|   26 | Accept |  8.4535e+10 |      8.7657 |  8.3104e+10 |  8.3104e+10 |    0.0090628 |            1 |          292 |
|   27 | Accept |  1.0315e+11 |      12.297 |  8.3104e+10 |  8.3104e+10 |    0.0014094 |            1 |          413 |
|   28 | Accept |  9.6736e+10 |      5.8323 |  8.3104e+10 |  8.3104e+10 |    0.0040429 |            1 |          202 |
|   29 | Accept |  8.3651e+10 |      8.4999 |  8.3104e+10 |  8.3104e+10 |      0.09375 |            1 |          295 |
|   30 | Accept |  8.7977e+10 |      13.521 |  8.3104e+10 |  8.3104e+10 |     0.016448 |            6 |          292 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 245.1541 seconds
Total objective function evaluation time: 210.0881

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Observed objective function value = 83103839919.908
Estimated objective function value = 83103840296.3186
Function evaluation time = 1.4534

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Estimated objective function value = 83103840296.3186
Estimated function evaluation time = 1.803

Получите лучшую точку от results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Обучите одномерную GAM с оптимальными параметрами

Обучите оптимизированную GAM с помощью zbest1 значения.

Mdl1 = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9806e+05
          NumObservations: 900


  Properties, Methods

Mdl1 является RegressionGAM объект модели. Отображение модели показывает частичный список свойств модели. Чтобы просмотреть полный список свойств, дважды кликните имя переменной Mdl1 в Рабочей области. Откроется редактор Переменных для Mdl1. Также можно отобразить свойства в Командном окне с помощью записи через точку. Для примера отобразите ReasonForTermination свойство.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

The PredictorTrees поле значения свойства указывает, что Mdl1 включает указанное количество деревьев. NumTreesPerPredictor от fitrgam задает максимальное количество деревьев на предиктор, и функция может остановиться перед обучением требуемого количества деревьев. Можно использовать ReasonForTermination свойство, чтобы определить, содержит ли обученная модель заданное количество деревьев.

Если вы задаете, чтобы включить условия взаимодействия, так что fitrgam тренирует для них деревья. а InteractionTrees поле содержит непустое значение.

Нахождение оптимальных параметров для двухмерной GAM

Найдите параметры для условий взаимодействия двухмерной GAM при помощи bayesopt функция.

Подготовка optimizableVariable для аргументов имя-значение для условий взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Используйте оптимальные значения параметров в zbest1 так, что программное обеспечение находит оптимальные значения параметров для членов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Поиск оптимальных параметров с помощью bayesopt. Процесс оптимизации обучает несколько моделей и отображает предупреждающие сообщения, если модели не включают условия взаимодействия. Отключите все предупреждения перед вызовом bayesopt и восстановите состояние предупреждения после запуска bayesopt. Можно оставить состояние предупреждения без изменений для просмотра предупреждающих сообщений.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |  8.4721e+10 |      1.6996 |  8.4721e+10 |  8.4721e+10 |      0.41774 |            1 |          346 |           28 |
|    2 | Accept |  9.1765e+10 |      8.3313 |  8.4721e+10 |  8.4721e+10 |       0.9565 |            3 |          231 |           14 |
|    3 | Accept |  9.2116e+10 |      2.8341 |  8.4721e+10 |  8.4721e+10 |      0.33578 |            9 |           45 |            5 |
|    4 | Accept |   1.784e+11 |      76.237 |  8.4721e+10 |  8.4721e+10 |      0.91186 |           10 |          479 |           27 |
|    5 | Accept |  8.4906e+10 |      1.8275 |  8.4721e+10 |  8.4721e+10 |        0.296 |            4 |            1 |           27 |
|    6 | Best   |  8.4172e+10 |        1.73 |  8.4172e+10 |  8.4172e+10 |      0.68133 |            1 |           86 |            1 |
|    7 | Best   |   8.234e+10 |      1.7164 |   8.234e+10 |   8.234e+10 |      0.13943 |            1 |          228 |           26 |
|    8 | Accept |  8.3488e+10 |      1.6382 |   8.234e+10 |   8.234e+10 |      0.46764 |            1 |            1 |            5 |
|    9 | Accept |  8.7977e+10 |      1.5655 |   8.234e+10 |   8.234e+10 |       0.8385 |           10 |            1 |            5 |
|   10 | Accept |  8.4431e+10 |      1.5744 |   8.234e+10 |   8.234e+10 |      0.95535 |            1 |          261 |            4 |
|   11 | Accept |  8.5784e+10 |      1.7478 |   8.234e+10 |   8.234e+10 |     0.023058 |            7 |            1 |           14 |
|   12 | Accept |  8.6068e+10 |      1.7304 |   8.234e+10 |   8.234e+10 |      0.77118 |            1 |            5 |           28 |
|   13 | Accept |  8.7004e+10 |      1.5903 |   8.234e+10 |   8.234e+10 |     0.016991 |            1 |          263 |            2 |
|   14 | Accept |  8.3325e+10 |      1.5895 |   8.234e+10 |   8.234e+10 |       0.9468 |            4 |            7 |            1 |
|   15 | Accept |  8.4097e+10 |      1.6357 |   8.234e+10 |   8.234e+10 |      0.97988 |            1 |          250 |           28 |
|   16 | Accept |  8.3106e+10 |      1.6081 |   8.234e+10 |   8.234e+10 |     0.024052 |            1 |          121 |           28 |
|   17 | Accept |   8.469e+10 |      1.6235 |   8.234e+10 |   8.234e+10 |     0.047902 |            3 |            3 |           12 |
|   18 | Best   |  8.1641e+10 |      1.5833 |  8.1641e+10 |  8.1641e+10 |      0.99848 |            6 |            1 |            3 |
|   19 | Accept |  8.5957e+10 |      1.6305 |  8.1641e+10 |  8.1641e+10 |      0.99826 |            6 |            1 |           13 |
|   20 | Accept |  8.2486e+10 |      1.6515 |  8.1641e+10 |  8.1641e+10 |      0.36059 |            7 |            2 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |  8.6534e+10 |       1.647 |  8.1641e+10 |  8.1641e+10 |    0.0089186 |            1 |          192 |           18 |
|   22 | Accept |  8.5425e+10 |      1.5316 |  8.1641e+10 |  8.1641e+10 |      0.99842 |            1 |          497 |            1 |
|   23 | Accept |   8.515e+10 |      1.5728 |  8.1641e+10 |  8.1641e+10 |      0.99934 |            1 |            3 |            2 |
|   24 | Accept |   8.593e+10 |      1.6086 |  8.1641e+10 |  8.1641e+10 |    0.0099052 |            1 |            2 |           28 |
|   25 | Accept |  8.7394e+10 |       1.577 |  8.1641e+10 |  8.1641e+10 |      0.96502 |            7 |            5 |            2 |
|   26 | Accept |   8.618e+10 |      1.5714 |  8.1641e+10 |  8.1641e+10 |     0.097871 |            5 |            3 |            1 |
|   27 | Accept |  8.5704e+10 |       1.665 |  8.1641e+10 |  8.1641e+10 |     0.056356 |           10 |            6 |            3 |
|   28 | Accept |  9.5451e+10 |      2.8821 |  8.1641e+10 |  8.1641e+10 |      0.91844 |            3 |           12 |           28 |
|   29 | Accept |  8.4013e+10 |      1.5633 |  8.1641e+10 |  8.1641e+10 |      0.68016 |            6 |            1 |            1 |
|   30 | Accept |  8.3928e+10 |      1.7715 |  8.1641e+10 |  8.1641e+10 |      0.07259 |            5 |            5 |           14 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 155.1459 seconds
Total objective function evaluation time: 132.9347

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Observed objective function value = 81640836929.8637
Estimated objective function value = 81640841484.6238
Function evaluation time = 1.5833

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Estimated objective function value = 81640841484.6238
Estimated function evaluation time = 1.5784
warning(orig_state)

Получите лучшую точку от results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Обучите двухмерную GAM с оптимальными параметрами

Обучите оптимизированную GAM с помощью zbest1 и zbest2 значения.

Mdl = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9741e+05
             Interactions: [3×2 double]
          NumObservations: 900


  Properties, Methods

Кроме того, можно добавить условия взаимодействия к одномерной GAM с помощью addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной параметр задает максимальное количество членов взаимодействия, и NumTreesPerInteraction аргумент name-value задает максимальное количество деревьев на срок взаимодействия. The addInteractions функция может включать меньше условий взаимодействия и остановку перед обучением требуемого количества деревьев. Можно использовать Interactions и ReasonForTermination свойства для проверки фактического количества членов взаимодействия и количества деревьев в обученной модели.

Отобразите условия взаимодействия в Mdl.

Mdl.Interactions
ans = 3×2

     3     6
     4     6
     6     8

Каждая строка Interactions представляет один термин взаимодействия и содержит индексы столбцов переменных предиктора для термина взаимодействия. Можно использовать Interactions свойство для проверки членов в модели взаимодействия и порядка, в котором fitrgam добавляет их в модель.

Отобразите условия взаимодействия в Mdl использование имен предикторов.

Mdl.PredictorNames(Mdl.Interactions)
ans = 3×2 cell
    {'BUILDINGCLASSCATEGORY'}    {'LANDSQUAREFEET'}
    {'RESIDENTIALUNITS'     }    {'LANDSQUAREFEET'}
    {'LANDSQUAREFEET'       }    {'YEARBUILT'     }

Отобразите причину прекращения, чтобы определить, содержит ли модель заданное количество деревьев для каждого линейного термина и каждого термина взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оцените прогнозирующую эффективность при новых наблюдениях

Оцените эффективность обученной модели с помощью тестовой выборки tbl_test и функции объекта predict и loss. Можно использовать полную или компактную модель с этими функциями.

  • predict - Предсказать ответы

  • loss - Вычисление потерь регрессии (средняя квадратичная невязка, по умолчанию)

Если вы хотите оценить эффективность обучающих данных набора, используйте функции объекта реституции: resubPredict и resubLoss. Чтобы использовать эти функции, вы должны использовать полную модель, которая содержит обучающие данные.

Создайте компактную модель, чтобы уменьшить размер обученной модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size             Bytes  Class                                          Attributes

  CMdl      1x1             370211  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             528102  RegressionGAM                                            

Спрогнозируйте отклики и вычислите средние квадратичные невязки для тестовых данных набора tbl_test.

yFit = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2855e+11

Найдите предсказанные отклики и ошибки, не включая условия взаимодействия в обученную модель.

yFit_nointeraction = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.3031e+11

Модель достигает меньшей ошибки для набора тестовых данных, когда включены как линейные, так и условия взаимодействия.

Сравните результаты, полученные путем включения как линейных, так и условий взаимодействия и результатов, полученных путем включения только линейных терминов. Составьте таблицу, содержащую наблюдаемые отклики и предсказанные отклики. Отображение первых восьми строк таблицы.

t = table(tbl_test.SALEPRICE,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Value','Predicted Response','Predicted Response Without Interactions'});
head(t)
ans=8×3 table
    Observed Value    Predicted Response    Predicted Response Without Interactions
    ______________    __________________    _______________________________________

         3.6e+05          4.9812e+05                      5.2712e+05               
         1.8e+05          2.7349e+05                      2.7415e+05               
         1.9e+05          3.3682e+05                      3.3748e+05               
        4.26e+05            6.15e+05                      5.6542e+05               
        3.91e+05          3.1262e+05                      3.1328e+05               
         2.3e+05          1.0606e+05                      1.0672e+05               
      4.7333e+05          1.0773e+06                      1.1399e+06               
           2e+05          2.9506e+05                       3.305e+05               

Интерпретируйте предсказание

Интерпретируйте предсказание для первого тестового наблюдения с помощью plotLocalEffects функция. Кроме того, создайте графики частичной зависимости для некоторых важных членов в модели с помощью plotPartialDependence функция.

Спрогнозируйте значение отклика для первого наблюдения тестовых данных и постройте график локальных эффектов членов CMdl на предсказании. Задайте 'IncludeIntercept',true включить в график термин точки пересечения.

yFit = predict(CMdl,tbl_test(1,:))
yFit = 4.9812e+05
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

The predict функция возвращает цену продажи для первого наблюдения tbl_test(1,:). The plotLocalEffects функция создает горизонтальный столбчатый график, которая показывает локальные эффекты членов в CMdl на предсказании. Каждое значение локального эффекта показывает вклад каждого термина в прогнозируемую цену продажи для tbl_test(1,:).

Вычислите значения частичной зависимости для BUILDINGCLASSCATEGORY и постройте график отсортированных значений. Задайте как обучающие, так и тестовые данные, чтобы вычислить значения частичной зависимости с помощью обоих наборов.

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Patial Dependence Plot')

Построенная линия представляет усредненные частичные отношения между предиктором BUILDINGCLASSCATEGORY и ответ SALEPRICE в обученной модели.

Создайте график частичной зависимости для членов RESIDENTIALUNITS и LANDSQUAREFEET.

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],[tbl_training; tbl_test])

Незначительные такты на оси X (RESIDENTIALUNITS) и ось Y (LANDSQUAREFEET) представляют уникальные значения предикторов в заданных данных. Значения предиктора включают несколько выбросов и большую часть RESIDENTIALUNITS и LANDSQUAREFEET значения меньше 10 и 50 000, соответственно. График показывает, что SALEPRICE значения не изменяются значительно, когда RESIDENTIALUNITS и LANDSQUAREFEET значения больше 10 и 50 000.

См. также

| | | | | |

Похожие темы