exponenta event banner

Обобщенная аддитивная модель поезда для регрессии

В этом примере показано, как обучить обобщенную аддитивную модель (GAM) оптимальным параметрам и как оценить прогностические характеристики обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов), а затем находит значения для двухмерного GAM (параметры для членов взаимодействия). Кроме того, в примере объясняется, как интерпретировать обученную модель, изучая локальные эффекты терминов на конкретное предсказание и вычисляя частичную зависимость предсказаний от предикторов.

Загрузить данные образца

Загрузить набор данных образца NYCHousing2015.

load NYCHousing2015

Набор данных включает 10 переменных с информацией о продажах недвижимости в Нью-Йорке в 2015 году. В этом примере эти переменные используются для анализа продажных цен (SALEPRICE).

Выполните предварительную обработку набора данных. Предположим, что SALEPRICE не более 1000 долл. США означает передачу права собственности без денежного вознаграждения. Удалить образцы, которые имеют это SALEPRICE. Кроме того, удалите отклонения, определенные isoutlier функция. Затем преобразуйте datetime массив (SALEDATE) на номера месяцев и переместите переменную ответа (SALEPRICE) к последнему столбцу. Изменить нули в LANDSQUAREFEET, GROSSSQUAREFEET, и YEARBUILT кому NaNs.

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

Отображение первых трех строк таблицы.

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

Случайный выбор 1000 образцов с помощью datasample функция и секционирование наблюдений в обучающий набор и тестовый набор с помощью cvpartition функция. Укажите пробу удержания 10% для тестирования.

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

Извлеките показатели обучения и тестирования и создайте таблицы для наборов данных обучения и тестирования.

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

Поиск оптимальных параметров для одномерного GAM

Оптимизируйте параметры одномерного GAM в отношении перекрестной проверки с помощью bayesopt функция.

Подготовиться optimizableVariable объекты для аргументов «имя-значение» одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создание целевой функции, которая принимает входные данные z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает значение перекрестной проверки потерь для параметров в z.

minfun1 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если задан параметр перекрестной проверки 'CrossVal','on', то fitrgam функция возвращает объект модели с перекрестной проверкой RegressionPartitionedGAM. kfoldLoss функция возвращает потерю регрессии (среднеквадратичную ошибку), полученную перекрестно проверенной моделью. Поэтому дескриптор функции minfun1 вычисляет потери перекрестной проверки в параметрах в z.

Поиск наилучших параметров с помощью bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция сбора данных по умолчанию зависит от времени выполнения и, следовательно, может давать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |  8.4558e+10 |      1.5106 |  8.4558e+10 |  8.4558e+10 |      0.36695 |            2 |           30 |
|    2 | Accept |  8.6891e+10 |       12.01 |  8.4558e+10 |  8.4558e+10 |     0.008213 |            5 |          271 |
|    3 | Accept |  9.6521e+10 |      1.9121 |  8.4558e+10 |  8.4558e+10 |      0.22984 |            9 |           37 |
|    4 | Accept |  1.3402e+11 |      14.388 |  8.4558e+10 |  8.4558e+10 |      0.99932 |            3 |          344 |
|    5 | Accept |  8.7852e+10 |      13.595 |  8.4558e+10 |  8.4558e+10 |      0.16575 |            1 |          456 |
|    6 | Accept |  9.3041e+10 |      11.002 |  8.4558e+10 |  8.4558e+10 |      0.49477 |            1 |          360 |
|    7 | Accept |  1.0558e+11 |      7.7647 |  8.4558e+10 |  8.4558e+10 |      0.24562 |            4 |          175 |
|    8 | Accept |  8.8841e+10 |      1.5763 |  8.4558e+10 |  8.4558e+10 |      0.39298 |            2 |           41 |
|    9 | Accept |  9.9227e+10 |      14.377 |  8.4558e+10 |  8.4558e+10 |     0.091879 |            3 |          358 |
|   10 | Accept |  9.8611e+10 |     0.14914 |  8.4558e+10 |  8.4558e+10 |      0.22487 |            2 |            2 |
|   11 | Accept |  1.2998e+11 |      23.962 |  8.4558e+10 |  8.4558e+10 |      0.25341 |            5 |          500 |
|   12 | Accept |  8.8968e+10 |      5.0028 |  8.4558e+10 |  8.4558e+10 |      0.33109 |            1 |          175 |
|   13 | Accept |  1.2018e+11 |      1.8004 |  8.4558e+10 |  8.4558e+10 |    0.0030413 |            6 |           40 |
|   14 | Accept |  8.7503e+10 |     0.79283 |  8.4558e+10 |  8.4558e+10 |      0.33877 |            1 |           25 |
|   15 | Accept |  9.3798e+10 |      2.9578 |  8.4558e+10 |  8.4558e+10 |      0.32926 |            2 |           80 |
|   16 | Accept |  9.5165e+10 |      8.0635 |  8.4558e+10 |  8.4558e+10 |      0.33878 |            1 |          282 |
|   17 | Best   |  8.3549e+10 |     0.24446 |  8.3549e+10 |  8.3549e+10 |       0.3552 |            2 |            5 |
|   18 | Best   |  8.3104e+10 |      1.4534 |  8.3104e+10 |  8.3104e+10 |       0.2526 |            1 |           49 |
|   19 | Accept |  8.6938e+10 |      3.3234 |  8.3104e+10 |  8.3104e+10 |      0.18293 |            1 |          110 |
|   20 | Accept |  8.7531e+10 |      2.8096 |  8.3104e+10 |  8.3104e+10 |       0.2781 |            1 |           93 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |  9.1613e+10 |      13.347 |  8.3104e+10 |  8.3104e+10 |      0.31722 |            1 |          464 |
|   22 | Accept |   8.678e+10 |      10.358 |  8.3104e+10 |  8.3104e+10 |      0.11269 |            1 |          358 |
|   23 | Accept |  8.3614e+10 |     0.47001 |  8.3104e+10 |  8.3104e+10 |      0.22278 |            1 |           14 |
|   24 | Accept |  1.3203e+11 |       1.069 |  8.3104e+10 |  8.3104e+10 |    0.0021552 |            5 |           23 |
|   25 | Accept |    8.66e+10 |       7.233 |  8.3104e+10 |  8.3104e+10 |      0.11469 |            1 |          236 |
|   26 | Accept |  8.4535e+10 |      8.7657 |  8.3104e+10 |  8.3104e+10 |    0.0090628 |            1 |          292 |
|   27 | Accept |  1.0315e+11 |      12.297 |  8.3104e+10 |  8.3104e+10 |    0.0014094 |            1 |          413 |
|   28 | Accept |  9.6736e+10 |      5.8323 |  8.3104e+10 |  8.3104e+10 |    0.0040429 |            1 |          202 |
|   29 | Accept |  8.3651e+10 |      8.4999 |  8.3104e+10 |  8.3104e+10 |      0.09375 |            1 |          295 |
|   30 | Accept |  8.7977e+10 |      13.521 |  8.3104e+10 |  8.3104e+10 |     0.016448 |            6 |          292 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 245.1541 seconds
Total objective function evaluation time: 210.0881

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Observed objective function value = 83103839919.908
Estimated objective function value = 83103840296.3186
Function evaluation time = 1.4534

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Estimated objective function value = 83103840296.3186
Estimated function evaluation time = 1.803

Получите наилучшую точку из results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.2526                           1                         49         

Одномерный GAM поезда с оптимальными параметрами

Обучение оптимизированного GAM с помощью zbest1 значения.

Mdl1 = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9806e+05
          NumObservations: 900


  Properties, Methods

Mdl1 является RegressionGAM объект модели. На экране модели отображается частичный список свойств модели. Чтобы просмотреть полный список свойств, дважды щелкните имя переменной Mdl1 в рабочей области. Откроется редактор переменных для Mdl1. Кроме того, свойства можно отобразить в окне команд с помощью точечной нотации. Например, отобразите ReasonForTermination собственность.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

PredictorTrees поле значения свойства указывает, что Mdl1 включает указанное количество деревьев. NumTreesPerPredictor из fitrgam указывает максимальное количество деревьев на один предиктор, и функция может остановиться перед обучением запрошенному количеству деревьев. Вы можете использовать ReasonForTermination , чтобы определить, содержит ли обучаемая модель указанное количество деревьев.

Если указать включить термины взаимодействия так, чтобы fitrgam тренирует для них деревья. InteractionTrees содержит непустое значение.

Поиск оптимальных параметров для Bivariate GAM

Найдите параметры для условий взаимодействия двухмерного GAM с помощью bayesopt функция.

Подготовиться optimizableVariable для аргументов «имя-значение» для терминов взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Использовать оптимальные значения параметров в zbest1 чтобы программное обеспечение обнаруживало оптимальные значения параметров для терминов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitrgam(tbl_training,'SALEPRICE', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Поиск наилучших параметров с помощью bayesopt. Процесс оптимизации обучает несколько моделей и отображает предупреждающие сообщения, если модели не содержат терминов взаимодействия. Отключить все предупреждения перед вызовом bayesopt и восстановить состояние предупреждения после запуска bayesopt. Для просмотра предупреждающих сообщений можно оставить состояние предупреждения неизменным.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |  8.4721e+10 |      1.6996 |  8.4721e+10 |  8.4721e+10 |      0.41774 |            1 |          346 |           28 |
|    2 | Accept |  9.1765e+10 |      8.3313 |  8.4721e+10 |  8.4721e+10 |       0.9565 |            3 |          231 |           14 |
|    3 | Accept |  9.2116e+10 |      2.8341 |  8.4721e+10 |  8.4721e+10 |      0.33578 |            9 |           45 |            5 |
|    4 | Accept |   1.784e+11 |      76.237 |  8.4721e+10 |  8.4721e+10 |      0.91186 |           10 |          479 |           27 |
|    5 | Accept |  8.4906e+10 |      1.8275 |  8.4721e+10 |  8.4721e+10 |        0.296 |            4 |            1 |           27 |
|    6 | Best   |  8.4172e+10 |        1.73 |  8.4172e+10 |  8.4172e+10 |      0.68133 |            1 |           86 |            1 |
|    7 | Best   |   8.234e+10 |      1.7164 |   8.234e+10 |   8.234e+10 |      0.13943 |            1 |          228 |           26 |
|    8 | Accept |  8.3488e+10 |      1.6382 |   8.234e+10 |   8.234e+10 |      0.46764 |            1 |            1 |            5 |
|    9 | Accept |  8.7977e+10 |      1.5655 |   8.234e+10 |   8.234e+10 |       0.8385 |           10 |            1 |            5 |
|   10 | Accept |  8.4431e+10 |      1.5744 |   8.234e+10 |   8.234e+10 |      0.95535 |            1 |          261 |            4 |
|   11 | Accept |  8.5784e+10 |      1.7478 |   8.234e+10 |   8.234e+10 |     0.023058 |            7 |            1 |           14 |
|   12 | Accept |  8.6068e+10 |      1.7304 |   8.234e+10 |   8.234e+10 |      0.77118 |            1 |            5 |           28 |
|   13 | Accept |  8.7004e+10 |      1.5903 |   8.234e+10 |   8.234e+10 |     0.016991 |            1 |          263 |            2 |
|   14 | Accept |  8.3325e+10 |      1.5895 |   8.234e+10 |   8.234e+10 |       0.9468 |            4 |            7 |            1 |
|   15 | Accept |  8.4097e+10 |      1.6357 |   8.234e+10 |   8.234e+10 |      0.97988 |            1 |          250 |           28 |
|   16 | Accept |  8.3106e+10 |      1.6081 |   8.234e+10 |   8.234e+10 |     0.024052 |            1 |          121 |           28 |
|   17 | Accept |   8.469e+10 |      1.6235 |   8.234e+10 |   8.234e+10 |     0.047902 |            3 |            3 |           12 |
|   18 | Best   |  8.1641e+10 |      1.5833 |  8.1641e+10 |  8.1641e+10 |      0.99848 |            6 |            1 |            3 |
|   19 | Accept |  8.5957e+10 |      1.6305 |  8.1641e+10 |  8.1641e+10 |      0.99826 |            6 |            1 |           13 |
|   20 | Accept |  8.2486e+10 |      1.6515 |  8.1641e+10 |  8.1641e+10 |      0.36059 |            7 |            2 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |  8.6534e+10 |       1.647 |  8.1641e+10 |  8.1641e+10 |    0.0089186 |            1 |          192 |           18 |
|   22 | Accept |  8.5425e+10 |      1.5316 |  8.1641e+10 |  8.1641e+10 |      0.99842 |            1 |          497 |            1 |
|   23 | Accept |   8.515e+10 |      1.5728 |  8.1641e+10 |  8.1641e+10 |      0.99934 |            1 |            3 |            2 |
|   24 | Accept |   8.593e+10 |      1.6086 |  8.1641e+10 |  8.1641e+10 |    0.0099052 |            1 |            2 |           28 |
|   25 | Accept |  8.7394e+10 |       1.577 |  8.1641e+10 |  8.1641e+10 |      0.96502 |            7 |            5 |            2 |
|   26 | Accept |   8.618e+10 |      1.5714 |  8.1641e+10 |  8.1641e+10 |     0.097871 |            5 |            3 |            1 |
|   27 | Accept |  8.5704e+10 |       1.665 |  8.1641e+10 |  8.1641e+10 |     0.056356 |           10 |            6 |            3 |
|   28 | Accept |  9.5451e+10 |      2.8821 |  8.1641e+10 |  8.1641e+10 |      0.91844 |            3 |           12 |           28 |
|   29 | Accept |  8.4013e+10 |      1.5633 |  8.1641e+10 |  8.1641e+10 |      0.68016 |            6 |            1 |            1 |
|   30 | Accept |  8.3928e+10 |      1.7715 |  8.1641e+10 |  8.1641e+10 |      0.07259 |            5 |            5 |           14 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 155.1459 seconds
Total objective function evaluation time: 132.9347

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Observed objective function value = 81640836929.8637
Estimated objective function value = 81640841484.6238
Function evaluation time = 1.5833

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Estimated objective function value = 81640841484.6238
Estimated function evaluation time = 1.5784
warning(orig_state)

Получите наилучшую точку из results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

                0.99848                            6                           1                      3       

Двухмерный GAM поезда с оптимальными параметрами

Обучение оптимизированного GAM с помощью zbest1 и zbest2 значения.

Mdl = fitrgam(tbl_training,'SALEPRICE', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  RegressionGAM
           PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
             ResponseName: 'SALEPRICE'
    CategoricalPredictors: [2 3]
        ResponseTransform: 'none'
                Intercept: 4.9741e+05
             Interactions: [3×2 double]
          NumObservations: 900


  Properties, Methods

Кроме того, можно добавить термины взаимодействия в одномерный GAM с помощью addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной аргумент указывает максимальное количество членов взаимодействия, а NumTreesPerInteraction аргумент name-value указывает максимальное количество деревьев на член взаимодействия. addInteractions функция может включать меньшее количество терминов взаимодействия и останавливаться перед обучением требуемому количеству деревьев. Вы можете использовать Interactions и ReasonForTermination свойства для проверки фактического количества членов взаимодействия и количества деревьев в обучаемой модели.

Отображение терминов взаимодействия в Mdl.

Mdl.Interactions
ans = 3×2

     3     6
     4     6
     6     8

Каждая строка Interactions представляет один член взаимодействия и содержит индексы столбцов переменных предиктора для члена взаимодействия. Вы можете использовать Interactions свойство для проверки терминов взаимодействия в модели и порядка, в котором fitrgam добавляет их в модель.

Отображение терминов взаимодействия в Mdl используя имена предикторов.

Mdl.PredictorNames(Mdl.Interactions)
ans = 3×2 cell
    {'BUILDINGCLASSCATEGORY'}    {'LANDSQUAREFEET'}
    {'RESIDENTIALUNITS'     }    {'LANDSQUAREFEET'}
    {'LANDSQUAREFEET'       }    {'YEARBUILT'     }

Отображение причины завершения для определения того, содержит ли модель указанное количество деревьев для каждого линейного члена и каждого члена взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оценка прогностической эффективности новых наблюдений

Оценка эффективности обучаемой модели с использованием тестового образца tbl_test и функции объекта predict и loss. С помощью этих функций можно использовать полную или компактную модель.

  • predict - Прогнозирование ответов

  • loss - Вычислить потери регрессии (среднеквадратическая ошибка, по умолчанию)

При необходимости оценки производительности набора данных обучения используйте функции объекта повторного замещения: resubPredict и resubLoss. Для использования этих функций необходимо использовать полную модель, содержащую данные обучения.

Создайте компактную модель, чтобы уменьшить размер обучаемой модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size             Bytes  Class                                          Attributes

  CMdl      1x1             370211  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             528102  RegressionGAM                                            

Прогнозирование ответов и вычисление среднеквадратичных ошибок для набора тестовых данных tbl_test.

yFit = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2855e+11

Поиск прогнозируемых ответов и ошибок без включения терминов взаимодействия в обучаемую модель.

yFit_nointeraction = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.3031e+11

В модели достигается меньшая ошибка для набора тестовых данных, когда включены как линейные, так и интерактивные термины.

Сравните результаты, полученные путем включения как линейных членов во взаимодействие, так и результаты, полученные путем включения только линейных членов. Создайте таблицу, содержащую наблюдаемые и прогнозируемые ответы. Отображение первых восьми строк таблицы.

t = table(tbl_test.SALEPRICE,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Value','Predicted Response','Predicted Response Without Interactions'});
head(t)
ans=8×3 table
    Observed Value    Predicted Response    Predicted Response Without Interactions
    ______________    __________________    _______________________________________

         3.6e+05          4.9812e+05                      5.2712e+05               
         1.8e+05          2.7349e+05                      2.7415e+05               
         1.9e+05          3.3682e+05                      3.3748e+05               
        4.26e+05            6.15e+05                      5.6542e+05               
        3.91e+05          3.1262e+05                      3.1328e+05               
         2.3e+05          1.0606e+05                      1.0672e+05               
      4.7333e+05          1.0773e+06                      1.1399e+06               
           2e+05          2.9506e+05                       3.305e+05               

Интерпретировать предсказание

Интерпретировать прогноз для первого тестового наблюдения с помощью plotLocalEffects функция. Кроме того, создайте графики частичной зависимости для некоторых важных терминов в модели с помощью plotPartialDependence функция.

Спрогнозировать значение ответа для первого наблюдения тестовых данных и построить график локальных эффектов терминов в CMdl по прогнозу. Определить 'IncludeIntercept',true включить в сюжет термин перехвата.

yFit = predict(CMdl,tbl_test(1,:))
yFit = 4.9812e+05
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

predict функция возвращает продажную цену для первого наблюдения tbl_test(1,:). plotLocalEffects функция создает горизонтальную гистограмму, которая показывает локальные эффекты терминов в CMdl по прогнозу. Каждое значение локального эффекта показывает вклад каждого срока в прогнозируемую цену продажи для tbl_test(1,:).

Вычислить значения частичной зависимости для BUILDINGCLASSCATEGORY и постройте график отсортированных значений. Укажите наборы обучающих и тестовых данных для вычисления значений частичной зависимости с использованием обоих наборов.

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Patial Dependence Plot')

Построенная на графике линия представляет усредненные частичные отношения между предиктором BUILDINGCLASSCATEGORY и ответ SALEPRICE в обучаемой модели.

Создание графика частичной зависимости для терминов RESIDENTIALUNITS и LANDSQUAREFEET.

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],[tbl_training; tbl_test])

Минорные засечки на оси X (RESIDENTIALUNITS) и ось y (LANDSQUAREFEET) представляют уникальные значения предикторов в указанных данных. Предикторные значения включают несколько отклонений, и большая часть RESIDENTIALUNITS и LANDSQUAREFEET значения меньше 10 и 50000 соответственно. График показывает, что SALEPRICE значения не изменяются значительно, когда RESIDENTIALUNITS и LANDSQUAREFEET значения больше 10 и 50 000.

См. также

| | | | | |

Связанные темы