Обучите обобщенную аддитивную модель для двоичной классификации

Этот пример показывает, как обучить обобщенную аддитивную модель (GAM) с оптимальными параметрами и как оценить прогнозирующую эффективность обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов), а затем находит значения для двухмерного GAM (параметры для членов взаимодействия). Кроме того, пример объясняет, как интерпретировать обученную модель, исследуя локальные эффекты членов на конкретном предсказании и вычисляя частичную зависимость предсказаний от предикторов.

Загрузка выборочных данных

Загрузка данных переписи 1994 года, хранящихся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США, чтобы предсказать, составляет ли индивидуум более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию заработной платы людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.

load census1994

census1994 содержит обучающий набор обучающих данных adultdata и набор тестовых данных adulttest. Чтобы уменьшить время работы для этого примера, выделите 500 обучающих наблюдений и 500 тестовых наблюдений при помощи datasample функция.

rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);

Нахождение оптимальных параметров для одномерной GAM

Оптимизируйте параметры для одномерной GAM относительно перекрестной валидации с помощью bayesopt функция.

Подготовка optimizableVariable объекты для аргументов имя-значение одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создайте целевую функцию, которая принимает вход z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает перекрестно проверенное значение потерь в параметрах в z.

minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если вы задаете опцию перекрестной проверки 'CrossVal','on', затем fitcgam функция возвращает перекрестно проверенный объект модели ClassificationPartitionedGAM. The kfoldLoss функция возвращает классификационные потери, полученные перекрестной проверенной моделью. Поэтому указатель на функцию minfun вычисляет потери перекрестной валидации у параметров в z.

Поиск оптимальных параметров с помощью bayesopt. Для повторяемости выберите 'expected-improvement-plus' функция сбора. Функция сбора по умолчанию зависит от времени выполнения и, следовательно, может давать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |     0.18549 |      5.6957 |     0.18549 |     0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |     0.19145 |      20.383 |     0.18549 |     0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |     0.17703 |      13.412 |     0.17703 |     0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |     0.14955 |       0.402 |     0.14955 |     0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |     0.15999 |      12.363 |     0.14955 |     0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |     0.15158 |      1.5035 |     0.14955 |     0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |     0.16181 |     0.18204 |     0.14955 |     0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |     0.15079 |     0.38418 |     0.14955 |     0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |     0.16102 |     0.55525 |     0.14955 |     0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |     0.19259 |      8.6487 |     0.14955 |     0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |     0.18628 |     0.20681 |     0.14955 |     0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |     0.15653 |     0.24643 |     0.14955 |     0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |     0.14699 |     0.82743 |     0.14699 |     0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |     0.14634 |     0.47528 |     0.14634 |     0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |     0.14312 |     0.34493 |     0.14312 |     0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |     0.14334 |     0.51583 |     0.14312 |     0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |     0.13791 |     0.32248 |     0.13791 |     0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |     0.14875 |      0.3551 |     0.13791 |     0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |      0.1651 |      1.3731 |     0.13791 |     0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |     0.15895 |     0.37324 |     0.13791 |     0.13791 |      0.32985 |            7 |            5 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |     0.13946 |     0.26793 |     0.13791 |     0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |     0.16719 |      1.1276 |     0.13791 |     0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |     0.17017 |        1.35 |     0.13791 |     0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |     0.15519 |     0.46246 |     0.13791 |     0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |     0.15312 |     0.26445 |     0.13791 |     0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |     0.15852 |     0.31045 |     0.13791 |     0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |     0.16691 |     0.50559 |     0.13791 |     0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |     0.14384 |     0.35136 |     0.13791 |     0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |     0.14773 |     0.33296 |     0.13791 |     0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |     0.17604 |     0.85847 |     0.13791 |     0.13791 |      0.36565 |            6 |           15 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084

Получите лучшую точку от results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Обучите одномерную GAM с оптимальными параметрами

Обучите оптимизированную GAM с помощью zbest1 значения. Рекомендуемая практика состоит в том, чтобы задать имена классов.

Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500


  Properties, Methods

Mdl1 является ClassificationGAM объект модели. Отображение модели показывает частичный список свойств модели. Чтобы просмотреть полный список свойств модели, дважды кликните имя переменной Mdl1 в Рабочей области. Откроется редактор Переменных для Mdl1. Также можно отобразить свойства в Командном окне с помощью записи через точку. Для примера отобразите ReasonForTermination свойство.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

The PredictorTrees поле значения свойства указывает, что Mdl1 включает указанное количество деревьев. NumTreesPerPredictor от fitcgam задает максимальное количество деревьев на предиктор, и функция может остановиться перед обучением требуемого количества деревьев. Можно использовать ReasonForTermination свойство, чтобы определить, содержит ли обученная модель заданное количество деревьев.

Если вы задаете, чтобы включить условия взаимодействия, так что fitcgam обучает для них деревья, затем InteractionTrees поле содержит непустое значение.

Нахождение оптимальных параметров для двухмерной GAM

Найдите параметры для условий взаимодействия двухмерной GAM при помощи bayesopt функция.

Подготовка optimizableVariable объекты для аргументов имя-значение для условий взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Используйте оптимальные значения параметров в zbest1 так, что программное обеспечение находит оптимальные значения параметров для членов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Поиск оптимальных параметров с помощью bayesopt. Процесс оптимизации обучает несколько моделей и отображает предупреждающие сообщения, если модели не включают условия взаимодействия. Отключите все предупреждения перед вызовом bayesopt и восстановите состояние предупреждения после запуска bayesopt. Можно оставить состояние предупреждения без изменений для просмотра предупреждающих сообщений.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |     0.19671 |      10.999 |     0.19671 |     0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |       0.189 |       30.57 |       0.189 |       0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |     0.16538 |      18.643 |     0.16538 |     0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |     0.15243 |      0.4285 |     0.15243 |     0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |     0.16065 |     0.69005 |     0.15243 |     0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |     0.14831 |     0.36629 |     0.14831 |     0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |     0.14887 |     0.36443 |     0.14831 |     0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |     0.15039 |     0.42139 |     0.14831 |     0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |     0.14787 |     0.42482 |     0.14787 |     0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |     0.13902 |     0.38822 |     0.13902 |     0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |     0.14721 |     0.39532 |     0.13902 |     0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |     0.14586 |     0.39205 |     0.13902 |     0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |     0.15073 |       0.383 |     0.13902 |     0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |     0.14966 |     0.42744 |     0.13902 |     0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |     0.13716 |     0.37599 |     0.13716 |     0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |     0.15094 |     0.38197 |     0.13716 |     0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |     0.13972 |      4.5994 |     0.13716 |     0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |     0.14788 |      31.639 |     0.13716 |     0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |     0.14565 |       1.276 |     0.13716 |     0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |     0.16502 |      28.315 |     0.13716 |     0.13716 |    0.0063353 |            4 |          457 |           16 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |     0.15693 |      4.9653 |     0.13716 |     0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |     0.16312 |      29.942 |     0.13716 |     0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |     0.15719 |      4.7423 |     0.13716 |     0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |       0.129 |      6.4419 |       0.129 |       0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |     0.15118 |      6.6757 |       0.129 |       0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |     0.15343 |      2.2035 |       0.129 |       0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |     0.12879 |      6.8017 |     0.12879 |     0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |     0.19093 |      5.9262 |     0.12879 |     0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |     0.16767 |      6.3779 |     0.12879 |     0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |     0.17636 |      11.026 |     0.12879 |     0.12879 |     0.054697 |            5 |          383 |            7 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
warning(orig_state)

Получите лучшую точку от results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Обучите двухмерную GAM с оптимальными параметрами

Обучите оптимизированную GAM с помощью zbest1 и zbest2 значения.

Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500


  Properties, Methods

Кроме того, можно добавить условия взаимодействия к одномерной GAM с помощью addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной параметр задает максимальное количество членов взаимодействия, и NumTreesPerInteraction аргумент name-value задает максимальное количество деревьев на срок взаимодействия. The addInteractions функция может включать меньше условий взаимодействия и остановку перед обучением требуемого количества деревьев. Можно использовать Interactions и ReasonForTermination свойства для проверки фактического количества членов взаимодействия и количества деревьев в обученной модели.

Отобразите условия взаимодействия в Mdl.

Mdl.Interactions
ans = 4×2

     7    10
     4     7
     7     9
     5    10

Каждая строка Interactions представляет один термин взаимодействия и содержит индексы столбцов переменных предиктора для термина взаимодействия. Можно использовать Interactions свойство для проверки членов в модели взаимодействия и порядка, в котором fitcgam добавляет их в модель.

Отобразите условия взаимодействия в Mdl использование имен предикторов.

Mdl.PredictorNames(Mdl.Interactions)
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}

Отобразите причину прекращения, чтобы определить, содержит ли модель заданное количество деревьев для каждого линейного термина и каждого термина взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оцените прогнозирующую эффективность при новых наблюдениях

Оцените эффективность обученной модели с помощью тестовой выборки adulttest и функции объекта predict, loss, edge, и margin. Можно использовать полную или компактную модель с этими функциями.

  • predict - Классификация наблюдений

  • loss - Вычислите классификационные потери (коэффициент неправильной классификации в десятичном числе, по умолчанию)

  • margin - Вычисление классификационных полей

  • edge - Вычислите классификационное ребро (среднее значение классификационных полей)

Если вы хотите оценить эффективность обучающих данных набора, используйте функции объекта реституции: resubPredict, resubLoss, resubMargin, и resubEdge. Чтобы использовать эти функции, вы должны использовать полную модель, которая содержит обучающие данные.

Создайте компактную модель, чтобы уменьшить размер обученной модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               

Спрогнозируйте метки и счета для тестовых данных набора (adulttest) и вычислите статистику модели (потеря, маржа и ребро) с помощью тестовых данных набора.

[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);

Прогнозируйте метки и счета и вычисляйте статистику, не включая условия взаимодействия в обученную модель.

[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);

Сравните результаты, полученные путем включения как линейных, так и условий взаимодействия с результатами, полученными путем включения только линейных терминов.

Составьте таблицу, содержащую наблюдаемые метки, предсказанные метки и счета. Отображение первых восьми строк таблицы.

t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
ans=8×5 table
    True Labels    Predicted Labels           Scores            Predicted Labels without interactions    Scores without interactions
    ___________    ________________    _____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                    <=50K                       0.98005     0.019951    
       <=50K            <=50K                1     8.258e-17                    <=50K                        0.9713     0.028696    
       <=50K            <=50K                1    1.8297e-19                    <=50K                       0.99449    0.0055054    
       <=50K            <=50K          0.87422       0.12578                    <=50K                       0.87729      0.12271    
       <=50K            <=50K                1    3.5643e-07                    <=50K                       0.99882    0.0011769    
       <=50K            <=50K          0.60371       0.39629                    <=50K                       0.77861      0.22139    
       <=50K            >50K           0.49917       0.50083                    >50K                        0.46877      0.53123    
       >50K             >50K            0.3109        0.6891                    <=50K                       0.53571      0.46429    

Создайте график неточностей из истинных меток adulttest.salary и предсказанные метки.

tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')

Отобразите вычисленные значения потерь и ребер.

table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                    0.13852     
    Edge              0.63926                    0.58405     

Модель достигает меньших потерь, когда включены только линейные условия, но достигает более высокого значения ребра, когда включены как линейные, так и условия взаимодействия.

Отображение распределений полей с помощью прямоугольных графиков.

figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')

Интерпретируйте предсказание

Интерпретируйте предсказание для первого тестового наблюдения с помощью plotLocalEffects функция. Кроме того, создайте графики частичной зависимости для некоторых важных членов в модели с помощью plotPartialDependence функция.

Классифицируйте первое наблюдение тестовых данных и постройте график локальных эффектов терминов в CMdl на предсказании. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

label = predict(CMdl,adulttest(1,:))
label = categorical
     <=50K 

f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';

The predict функция классифицирует первые adulttest(1,:) наблюдения как '<=50K'. The plotLocalEffects функция создает горизонтальный столбчатый график, которая показывает локальные эффекты 10 наиболее важных членов на предсказание. Каждое значение локального эффекта показывает вклад каждого члена в классификационную оценку для '<=50K', который является логитом апостериорной вероятности того, что классификация '<=50K' для наблюдения.

Создайте график частичной зависимости для термина age. Задайте как обучающие, так и тестовые данные, чтобы вычислить значения частичной зависимости с помощью обоих наборов.

figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])

Построенная линия представляет усредненные частичные отношения между предиктором age и счет класса <=50K в обученной модели. The x-sis minor ticks представляют уникальные значения в предикторе age.

Создайте графики частичной зависимости для терминов education_num и relationship.

f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';

График показывает частичную зависимость от education_num, который имеет разный тренд в зависимости от relationship значение.

См. также

| | | | | |

Похожие темы