Обучите обобщенную аддитивную модель бинарной классификации

В этом примере показано, как обучить обобщенную аддитивную модель (GAM) оптимальными параметрами и как оценить прогнозирующую эффективность обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов) и затем находит значения для двумерного GAM (параметры в течение многих периодов взаимодействия). Кроме того, пример объясняет, как интерпретировать обученную модель путем исследования локальных эффектов условий на определенном предсказании и путем вычисления частичной зависимости предсказаний на предикторах.

Загрузка демонстрационных данных

Загрузите 1 994 данных о переписи, хранимых в census1994.mat. Набор данных состоит из демографических данных Бюро переписи США, чтобы предсказать, передает ли индивидуум 50 000$ в год. Задача классификации состоит в том, чтобы подобрать модель, которая предсказывает категорию зарплаты людей, учитывая их возраст, рабочий класс, образовательный уровень, семейное положение, гонку, и так далее.

load census1994

census1994 содержит обучающий набор данных adultdata и тестовые данные устанавливают adulttest. Уменьшать время выполнения для этого примера, поддемонстрационных 500 учебных наблюдений и 500 тестовых наблюдений при помощи datasample функция.

rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);

Найдите оптимальные параметры для одномерного GAM

Оптимизируйте параметры для одномерного GAM относительно перекрестной проверки при помощи bayesopt функция.

Подготовьте optimizableVariable объекты для аргументов значения имени одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создайте целевую функцию, которая берет вход z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает перекрестное подтвержденное значение потерь в параметрах в z.

minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если вы задаете опцию перекрестной проверки 'CrossVal','on', затем fitcgam функция возвращает перекрестный подтвержденный объект модели ClassificationPartitionedGAM. kfoldLoss функция возвращает потерю классификации, полученную перекрестной подтвержденной моделью. Поэтому указатель на функцию minfun вычисляет потерю перекрестной проверки в параметрах в z.

Ищите лучшие параметры с помощью bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция приобретения по умолчанию зависит от времени выполнения и, поэтому, может дать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |     0.18549 |      5.6957 |     0.18549 |     0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |     0.19145 |      20.383 |     0.18549 |     0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |     0.17703 |      13.412 |     0.17703 |     0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |     0.14955 |       0.402 |     0.14955 |     0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |     0.15999 |      12.363 |     0.14955 |     0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |     0.15158 |      1.5035 |     0.14955 |     0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |     0.16181 |     0.18204 |     0.14955 |     0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |     0.15079 |     0.38418 |     0.14955 |     0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |     0.16102 |     0.55525 |     0.14955 |     0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |     0.19259 |      8.6487 |     0.14955 |     0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |     0.18628 |     0.20681 |     0.14955 |     0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |     0.15653 |     0.24643 |     0.14955 |     0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |     0.14699 |     0.82743 |     0.14699 |     0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |     0.14634 |     0.47528 |     0.14634 |     0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |     0.14312 |     0.34493 |     0.14312 |     0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |     0.14334 |     0.51583 |     0.14312 |     0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |     0.13791 |     0.32248 |     0.13791 |     0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |     0.14875 |      0.3551 |     0.13791 |     0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |      0.1651 |      1.3731 |     0.13791 |     0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |     0.15895 |     0.37324 |     0.13791 |     0.13791 |      0.32985 |            7 |            5 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |     0.13946 |     0.26793 |     0.13791 |     0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |     0.16719 |      1.1276 |     0.13791 |     0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |     0.17017 |        1.35 |     0.13791 |     0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |     0.15519 |     0.46246 |     0.13791 |     0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |     0.15312 |     0.26445 |     0.13791 |     0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |     0.15852 |     0.31045 |     0.13791 |     0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |     0.16691 |     0.50559 |     0.13791 |     0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |     0.14384 |     0.35136 |     0.13791 |     0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |     0.14773 |     0.33296 |     0.13791 |     0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |     0.17604 |     0.85847 |     0.13791 |     0.13791 |      0.36565 |            6 |           15 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084

Получите лучшую точку из results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Обучите одномерный GAM оптимальными параметрами

Обучите оптимизированный GAM с помощью zbest1 значения. Методические рекомендации должны задать имена классов.

Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500


  Properties, Methods

Mdl1 ClassificationGAM объект модели. Отображение модели показывает частичный список свойств модели. Чтобы просмотреть полный список свойств модели, дважды кликните имя переменной Mdl1 в Рабочей области. Редактор Переменных открывается для Mdl1. В качестве альтернативы можно отобразить свойства в Командном окне при помощи записи через точку. Например, отобразите ReasonForTermination свойство.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

PredictorTrees поле значения свойства указывает на тот Mdl1 включает конкретное количество деревьев. NumTreesPerPredictor из fitcgam задает максимальное количество деревьев на предиктор, и функция может остановиться перед обучением требуемое количество деревьев. Можно использовать ReasonForTermination свойство определить, содержит ли обученная модель конкретное количество деревьев.

Если вы задаете, чтобы включать периоды взаимодействия так, чтобы fitcgam обучает деревья им, затем InteractionTrees поле содержит непустое значение.

Найдите оптимальные параметры для двумерного GAM

Найдите параметры в течение многих периодов взаимодействия двумерного GAM при помощи bayesopt функция.

Подготовьте optimizableVariable объекты для аргументов значения имени в течение периодов взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Используйте оптимальные значения параметров в zbest1 так, чтобы программное обеспечение нашло оптимальные значения параметров в течение многих периодов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Ищите лучшие параметры с помощью bayesopt. Процесс оптимизации обучает многоуровневые модели и отображает предупреждающие сообщения, если модели не включают периодов взаимодействия. Отключите все предупреждения прежде, чем вызвать bayesopt и восстановите состояние предупреждения после выполнения bayesopt. Можно оставить состояние предупреждения без изменений, чтобы просмотреть предупреждающие сообщения.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |     0.19671 |      10.999 |     0.19671 |     0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |       0.189 |       30.57 |       0.189 |       0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |     0.16538 |      18.643 |     0.16538 |     0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |     0.15243 |      0.4285 |     0.15243 |     0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |     0.16065 |     0.69005 |     0.15243 |     0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |     0.14831 |     0.36629 |     0.14831 |     0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |     0.14887 |     0.36443 |     0.14831 |     0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |     0.15039 |     0.42139 |     0.14831 |     0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |     0.14787 |     0.42482 |     0.14787 |     0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |     0.13902 |     0.38822 |     0.13902 |     0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |     0.14721 |     0.39532 |     0.13902 |     0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |     0.14586 |     0.39205 |     0.13902 |     0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |     0.15073 |       0.383 |     0.13902 |     0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |     0.14966 |     0.42744 |     0.13902 |     0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |     0.13716 |     0.37599 |     0.13716 |     0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |     0.15094 |     0.38197 |     0.13716 |     0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |     0.13972 |      4.5994 |     0.13716 |     0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |     0.14788 |      31.639 |     0.13716 |     0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |     0.14565 |       1.276 |     0.13716 |     0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |     0.16502 |      28.315 |     0.13716 |     0.13716 |    0.0063353 |            4 |          457 |           16 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |     0.15693 |      4.9653 |     0.13716 |     0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |     0.16312 |      29.942 |     0.13716 |     0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |     0.15719 |      4.7423 |     0.13716 |     0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |       0.129 |      6.4419 |       0.129 |       0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |     0.15118 |      6.6757 |       0.129 |       0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |     0.15343 |      2.2035 |       0.129 |       0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |     0.12879 |      6.8017 |     0.12879 |     0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |     0.19093 |      5.9262 |     0.12879 |     0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |     0.16767 |      6.3779 |     0.12879 |     0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |     0.17636 |      11.026 |     0.12879 |     0.12879 |     0.054697 |            5 |          383 |            7 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
warning(orig_state)

Получите лучшую точку из results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Обучите двумерный GAM оптимальными параметрами

Обучите оптимизированный GAM с помощью zbest1 и zbest2 значения.

Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500


  Properties, Methods

В качестве альтернативы можно добавить периоды взаимодействия в одномерный GAM при помощи addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной параметр задает максимальное количество периодов взаимодействия и NumTreesPerInteraction аргумент значения имени задает максимальное количество деревьев в период взаимодействия. addInteractions функция может включать меньше периодов взаимодействия и остановки перед обучением требуемое количество деревьев. Можно использовать Interactions и ReasonForTermination свойства проверять фактический номер периодов взаимодействия и количество деревьев в обученной модели.

Отобразите периоды взаимодействия в Mdl.

Mdl.Interactions
ans = 4×2

     7    10
     4     7
     7     9
     5    10

Каждая строка Interactions представляет один период взаимодействия и содержит индексы столбца переменных предикторов в течение периода взаимодействия. Можно использовать Interactions свойство проверять периоды взаимодействия в модель и порядок, в который fitcgam добавляет их в модель.

Отобразите периоды взаимодействия в Mdl использование имен предиктора.

Mdl.PredictorNames(Mdl.Interactions)
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}

Отобразите причину завершения, чтобы определить, содержит ли модель конкретное количество деревьев для каждого линейного члена и каждый период взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оцените прогнозирующую эффективность на новых наблюдениях

Оцените эффективность обученной модели при помощи тестовой выборки adulttest и объект функционирует predict, loss, edge, и margin. Можно использовать полную или компактную модель с этими функциями.

  • predict — Классифицируйте наблюдения

  • loss — Вычислите потерю классификации (misclassification уровень в десятичном числе, по умолчанию)

  • margin — Вычислите поля классификации

  • edge — Вычислите ребро классификации (среднее значение полей классификации)

Если вы хотите оценить эффективность обучающего набора данных, используйте функции объекта перезамены: resubPredict, resubLoss, resubMargin, и resubEdge. Чтобы использовать эти функции, необходимо использовать полную модель, которая содержит обучающие данные.

Создайте компактную модель, чтобы уменьшать размер обученной модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               

Предскажите метки и музыку к набору тестовых данных (adulttest), и вычислите статистику модели (потеря, поле и ребро) использование набора тестовых данных.

[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);

Предскажите метки и баллы и вычислите статистику без включения периодов взаимодействия в обученной модели.

[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);

Сравните результаты, полученные включением и линейные члены и периоды взаимодействия к результатам, полученным включением только линейных членов.

Составьте таблицу, содержащую наблюдаемые метки, предсказанные метки и баллы. Отобразите первые восемь строк таблицы.

t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
ans=8×5 table
    True Labels    Predicted Labels           Scores            Predicted Labels without interactions    Scores without interactions
    ___________    ________________    _____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                    <=50K                       0.98005     0.019951    
       <=50K            <=50K                1     8.258e-17                    <=50K                        0.9713     0.028696    
       <=50K            <=50K                1    1.8297e-19                    <=50K                       0.99449    0.0055054    
       <=50K            <=50K          0.87422       0.12578                    <=50K                       0.87729      0.12271    
       <=50K            <=50K                1    3.5643e-07                    <=50K                       0.99882    0.0011769    
       <=50K            <=50K          0.60371       0.39629                    <=50K                       0.77861      0.22139    
       <=50K            >50K           0.49917       0.50083                    >50K                        0.46877      0.53123    
       >50K             >50K            0.3109        0.6891                    <=50K                       0.53571      0.46429    

Создайте график беспорядка от истины, маркирует adulttest.salary и предсказанные метки.

tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')

Отобразите вычисленную потерю и значения ребра.

table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                    0.13852     
    Edge              0.63926                    0.58405     

Модель достигает меньшей потери, когда только линейные члены включены, но достигает более высокого значения ребра, когда и линейные члены и периоды взаимодействия включены.

Отобразите распределения полей с помощью диаграмм.

figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')

Интерпретируйте предсказание

Интерпретируйте предсказание для первого тестового наблюдения при помощи plotLocalEffects функция. Кроме того, создайте частичные графики зависимости для некоторых важных членов в модели при помощи plotPartialDependence функция.

Классифицируйте первое наблюдение за тестовыми данными и постройте локальные эффекты условий в CMdl на предсказании. Чтобы отобразить существующее подчеркивание на любое имя предиктора, измените TickLabelInterpreter значение осей к 'none'.

label = predict(CMdl,adulttest(1,:))
label = categorical
     <=50K 

f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';

predict функция классифицирует первое наблюдение adulttest(1,:) как '<=50K'. plotLocalEffects функция создает горизонтальный столбчатый график, который показывает локальные эффекты 10 самых важных условий на предсказании. Каждое локальное значение эффекта показывает вклад каждого термина к классификационной оценке для '<=50K', который является логитом апостериорной вероятности, что классификацией является '<=50K' для наблюдения.

Создайте частичный график зависимости для термина age. Задайте и обучение и наборы тестовых данных, чтобы вычислить частичные значения зависимости с помощью обоих наборов.

figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])

Построенная линия представляет усредненные частичные отношения между предиктором age и счет класса <=50K в обученной модели. x- ось незначительные метки деления представляет уникальные значения в предикторе age.

Создайте частичные графики зависимости для условий education_num и relationship.

f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';

График показывает частичную зависимость от education_num, который имеет различный тренд в зависимости от relationship значение.

Смотрите также

| | | | | |

Похожие темы