exponenta event banner

Обобщенная аддитивная модель поезда для двоичной классификации

В этом примере показано, как обучить обобщенную аддитивную модель (GAM) оптимальным параметрам и как оценить прогностические характеристики обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов), а затем находит значения для двухмерного GAM (параметры для членов взаимодействия). Кроме того, в примере объясняется, как интерпретировать обученную модель путем изучения локальных эффектов терминов на конкретное предсказание и путем вычисления частичной зависимости предсказаний от предикторов.

Загрузить данные образца

Загрузка данных переписи 1994 года, хранящихся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США для прогнозирования того, составляет ли человек более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию зарплаты людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.

load census1994

census1994 содержит набор данных обучения adultdata и набор тестовых данных adulttest. Чтобы сократить время работы для этого примера, выполните пример 500 учебных наблюдений и 500 тестовых наблюдений с использованием datasample функция.

rng('default')
NumSamples = 5e2;
adultdata = datasample(adultdata,NumSamples,'Replace',false);
adulttest = datasample(adulttest,NumSamples,'Replace',false);

Поиск оптимальных параметров для одномерного GAM

Оптимизируйте параметры одномерного GAM в отношении перекрестной проверки с помощью bayesopt функция.

Подготовиться optimizableVariable объекты для аргументов «имя-значение» одномерного GAM: MaxNumSplitsPerPredictor, NumTreesPerPredictor, и InitialLearnRateForPredictors.

maxNumSplitsPerPredictor = optimizableVariable('maxNumSplitsPerPredictor',[1,10],'Type','integer');
numTreesPerPredictor = optimizableVariable('numTreesPerPredictor',[1,500],'Type','integer');
initialLearnRateForPredictors = optimizableVariable('initialLearnRateForPredictors',[1e-3,1],'Type','real');

Создание целевой функции, которая принимает входные данные z = [maxNumSplitsPerPredictor,numTreesPerPredictor,initialLearnRateForPredictors] и возвращает значение перекрестной проверки потерь для параметров в z.

minfun1 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',z.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',z.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',z.numTreesPerPredictor));

Если задан параметр перекрестной проверки 'CrossVal','on', то fitcgam функция возвращает объект модели с перекрестной проверкой ClassificationPartitionedGAM. kfoldLoss функция возвращает потери классификации, полученные перекрестно проверенной моделью. Поэтому дескриптор функции minfun вычисляет потери перекрестной проверки в параметрах в z.

Поиск наилучших параметров с помощью bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция сбора данных по умолчанию зависит от времени выполнения и, следовательно, может давать различные результаты.

results1 = bayesopt(minfun1, ...
    [initialLearnRateForPredictors,maxNumSplitsPerPredictor,numTreesPerPredictor], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|    1 | Best   |     0.18549 |      5.6957 |     0.18549 |     0.18549 |      0.73503 |            7 |           99 |
|    2 | Accept |     0.19145 |      20.383 |     0.18549 |     0.18549 |      0.72917 |           10 |          399 |
|    3 | Best   |     0.17703 |      13.412 |     0.17703 |     0.17703 |     0.079299 |            8 |          267 |
|    4 | Best   |     0.14955 |       0.402 |     0.14955 |     0.14955 |      0.24236 |            4 |            3 |
|    5 | Accept |     0.15999 |      12.363 |     0.14955 |     0.14955 |      0.25509 |            1 |          377 |
|    6 | Accept |     0.15158 |      1.5035 |     0.14955 |     0.14955 |      0.23051 |            7 |           29 |
|    7 | Accept |     0.16181 |     0.18204 |     0.14955 |     0.14955 |      0.34396 |            4 |            1 |
|    8 | Accept |     0.15079 |     0.38418 |     0.14955 |     0.14955 |      0.26669 |           10 |            5 |
|    9 | Accept |     0.16102 |     0.55525 |     0.14955 |     0.14955 |      0.26065 |            2 |           10 |
|   10 | Accept |     0.19259 |      8.6487 |     0.14955 |     0.14955 |      0.24894 |           10 |          182 |
|   11 | Accept |     0.18628 |     0.20681 |     0.14955 |     0.14955 |      0.13389 |            6 |            2 |
|   12 | Accept |     0.15653 |     0.24643 |     0.14955 |     0.14955 |      0.24172 |           10 |            2 |
|   13 | Best   |     0.14699 |     0.82743 |     0.14699 |     0.14699 |      0.26745 |            7 |           12 |
|   14 | Best   |     0.14634 |     0.47528 |     0.14634 |     0.14634 |      0.25025 |            6 |            6 |
|   15 | Best   |     0.14312 |     0.34493 |     0.14312 |     0.14312 |      0.30452 |            9 |            3 |
|   16 | Accept |     0.14334 |     0.51583 |     0.14312 |     0.14312 |      0.33507 |           10 |            7 |
|   17 | Best   |     0.13791 |     0.32248 |     0.13791 |     0.13791 |      0.33179 |            9 |            4 |
|   18 | Accept |     0.14875 |      0.3551 |     0.13791 |     0.13791 |      0.36806 |            8 |            5 |
|   19 | Accept |      0.1651 |      1.3731 |     0.13791 |     0.13791 |      0.32691 |            8 |           27 |
|   20 | Accept |     0.15895 |     0.37324 |     0.13791 |     0.13791 |      0.32985 |            7 |            5 |
|====================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerP-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForPredi | PerPredictor | redictor     |
|====================================================================================================================|
|   21 | Accept |     0.13946 |     0.26793 |     0.13791 |     0.13791 |      0.36721 |            9 |            3 |
|   22 | Accept |     0.16719 |      1.1276 |     0.13791 |     0.13791 |      0.25385 |            5 |           23 |
|   23 | Accept |     0.17017 |        1.35 |     0.13791 |     0.13791 |      0.23809 |            9 |           26 |
|   24 | Accept |     0.15519 |     0.46246 |     0.13791 |     0.13791 |      0.34831 |            9 |            7 |
|   25 | Accept |     0.15312 |     0.26445 |     0.13791 |     0.13791 |      0.33416 |           10 |            3 |
|   26 | Accept |     0.15852 |     0.31045 |     0.13791 |     0.13791 |       0.6142 |            9 |            4 |
|   27 | Accept |     0.16691 |     0.50559 |     0.13791 |     0.13791 |      0.31446 |            5 |            7 |
|   28 | Accept |     0.14384 |     0.35136 |     0.13791 |     0.13791 |      0.40215 |            9 |            4 |
|   29 | Accept |     0.14773 |     0.33296 |     0.13791 |     0.13791 |      0.34255 |            9 |            4 |
|   30 | Accept |     0.17604 |     0.85847 |     0.13791 |     0.13791 |      0.36565 |            6 |           15 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 97.6656 seconds
Total objective function evaluation time: 74.4022

Best observed feasible point:
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Observed objective function value = 0.13791
Estimated objective function value = 0.13791
Function evaluation time = 0.32248

Best estimated feasible point (according to models):
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Estimated objective function value = 0.13791
Estimated function evaluation time = 0.33084

Получите наилучшую точку из results1.

zbest1 = bestPoint(results1)
zbest1=1×3 table
    initialLearnRateForPredictors    maxNumSplitsPerPredictor    numTreesPerPredictor
    _____________________________    ________________________    ____________________

               0.33179                          9                         4          

Одномерный GAM поезда с оптимальными параметрами

Обучение оптимизированного GAM с помощью zbest1 значения. Рекомендуется указывать имена классов.

Mdl1 = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor) 
Mdl1 = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7383
          NumObservations: 500


  Properties, Methods

Mdl1 является ClassificationGAM объект модели. На экране модели отображается частичный список свойств модели. Чтобы просмотреть полный список свойств модели, дважды щелкните имя переменной. Mdl1 в рабочей области. Откроется редактор переменных для Mdl1. Кроме того, свойства можно отобразить в окне команд с помощью точечной нотации. Например, отобразите ReasonForTermination собственность.

Mdl1.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: ''

PredictorTrees поле значения свойства указывает, что Mdl1 включает указанное количество деревьев. NumTreesPerPredictor из fitcgam указывает максимальное количество деревьев на один предиктор, и функция может остановиться перед обучением запрошенному количеству деревьев. Вы можете использовать ReasonForTermination , чтобы определить, содержит ли обучаемая модель указанное количество деревьев.

Если указать включить термины взаимодействия так, чтобы fitcgam тренирует деревья для них, затем InteractionTrees содержит непустое значение.

Поиск оптимальных параметров для Bivariate GAM

Найдите параметры для условий взаимодействия двухмерного GAM с помощью bayesopt функция.

Подготовиться optimizableVariable объекты для аргументов «имя-значение» для терминов взаимодействия: InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, NumTreesPerInteraction, и InitialLearnRateForInteractions.

initialLearnRateForInteractions = optimizableVariable('initialLearnRateForInteractions',[1e-3,1],'Type','real');
maxNumSplitsPerInteraction = optimizableVariable('maxNumSplitsPerInteraction',[1,10],'Type','integer');
numTreesPerInteraction = optimizableVariable('numTreesPerInteraction',[1,500],'Type','integer');
numInteractions = optimizableVariable('numInteractions',[1,28],'Type','integer');

Создайте целевую функцию для оптимизации. Использовать оптимальные значения параметров в zbest1 чтобы программное обеспечение обнаруживало оптимальные значения параметров для терминов взаимодействия на основе zbest1 значения.

minfun2 = @(z)kfoldLoss(fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'CrossVal','on', ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',z.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',z.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',z.numTreesPerInteraction, ...
    'Interactions',z.numInteractions));

Поиск наилучших параметров с помощью bayesopt. Процесс оптимизации обучает несколько моделей и отображает предупреждающие сообщения, если модели не содержат терминов взаимодействия. Отключить все предупреждения перед вызовом bayesopt и восстановить состояние предупреждения после запуска bayesopt. Для просмотра предупреждающих сообщений можно оставить состояние предупреждения неизменным.

orig_state = warning('query'); 
warning('off')
results2 = bayesopt(minfun2, ...
    [initialLearnRateForInteractions,maxNumSplitsPerInteraction,numTreesPerInteraction,numInteractions], ...
    'IsObjectiveDeterministic',true, ...
    'AcquisitionFunctionName','expected-improvement-plus');
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|    1 | Best   |     0.19671 |      10.999 |     0.19671 |     0.19671 |      0.96444 |            8 |          109 |           22 |
|    2 | Best   |       0.189 |       30.57 |       0.189 |       0.189 |      0.98548 |            6 |          457 |           17 |
|    3 | Best   |     0.16538 |      18.643 |     0.16538 |     0.16538 |      0.28678 |            4 |          383 |           13 |
|    4 | Best   |     0.15243 |      0.4285 |     0.15243 |     0.15243 |      0.28044 |            1 |           45 |            3 |
|    5 | Accept |     0.16065 |     0.69005 |     0.15243 |     0.15243 |      0.20151 |            7 |           60 |            1 |
|    6 | Best   |     0.14831 |     0.36629 |     0.14831 |     0.14831 |     0.032423 |            1 |          151 |            1 |
|    7 | Accept |     0.14887 |     0.36443 |     0.14831 |     0.14831 |     0.021093 |            1 |           15 |            1 |
|    8 | Accept |     0.15039 |     0.42139 |     0.14831 |     0.14831 |     0.012128 |            2 |          482 |            1 |
|    9 | Best   |     0.14787 |     0.42482 |     0.14787 |     0.14787 |      0.10119 |            1 |          121 |            6 |
|   10 | Best   |     0.13902 |     0.38822 |     0.13902 |     0.13902 |       0.1233 |            1 |          281 |            3 |
|   11 | Accept |     0.14721 |     0.39532 |     0.13902 |     0.13902 |     0.065618 |            1 |          291 |            3 |
|   12 | Accept |     0.14586 |     0.39205 |     0.13902 |     0.13902 |      0.18711 |            1 |          117 |            1 |
|   13 | Accept |     0.15073 |       0.383 |     0.13902 |     0.13902 |      0.15072 |            1 |           15 |            3 |
|   14 | Accept |     0.14966 |     0.42744 |     0.13902 |     0.13902 |      0.17155 |            1 |          497 |            4 |
|   15 | Best   |     0.13716 |     0.37599 |     0.13716 |     0.13716 |      0.12601 |            1 |          281 |            1 |
|   16 | Accept |     0.15094 |     0.38197 |     0.13716 |     0.13716 |      0.13962 |            2 |          284 |            1 |
|   17 | Accept |     0.13972 |      4.5994 |     0.13716 |     0.13716 |    0.0028545 |            5 |          481 |            2 |
|   18 | Accept |     0.14788 |      31.639 |     0.13716 |     0.13716 |    0.0024433 |            6 |          489 |           15 |
|   19 | Accept |     0.14565 |       1.276 |     0.13716 |     0.13716 |     0.013118 |            5 |          257 |            1 |
|   20 | Accept |     0.16502 |      28.315 |     0.13716 |     0.13716 |    0.0063353 |            4 |          457 |           16 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | initialLearn-| maxNumSplits-| numTreesPerI-| numInteracti-|
|      | result |             | runtime     | (observed)  | (estim.)    | RateForInter | PerInteracti | nteraction   | ons          |
|===================================================================================================================================|
|   21 | Accept |     0.15693 |      4.9653 |     0.13716 |     0.13716 |     0.016486 |            6 |          466 |            2 |
|   22 | Accept |     0.16312 |      29.942 |     0.13716 |     0.13716 |     0.019904 |            5 |          488 |           15 |
|   23 | Accept |     0.15719 |      4.7423 |     0.13716 |     0.13716 |     0.020155 |            4 |          456 |            3 |
|   24 | Best   |       0.129 |      6.4419 |       0.129 |       0.129 |     0.090858 |            5 |          478 |            3 |
|   25 | Accept |     0.15118 |      6.6757 |       0.129 |       0.129 |      0.15943 |            5 |          494 |            3 |
|   26 | Accept |     0.15343 |      2.2035 |       0.129 |       0.129 |     0.070349 |            5 |          489 |            1 |
|   27 | Best   |     0.12879 |      6.8017 |     0.12879 |     0.12879 |     0.091985 |            5 |          387 |            4 |
|   28 | Accept |     0.19093 |      5.9262 |     0.12879 |     0.12879 |     0.067405 |            5 |          331 |            4 |
|   29 | Accept |     0.16767 |      6.3779 |     0.12879 |     0.12879 |      0.31419 |            5 |          472 |            3 |
|   30 | Accept |     0.17636 |      11.026 |     0.12879 |     0.12879 |     0.054697 |            5 |          383 |            7 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 239.1035 seconds
Total objective function evaluation time: 216.5833

Best observed feasible point:
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Observed objective function value = 0.12879
Estimated objective function value = 0.12879
Function evaluation time = 6.8017

Best estimated feasible point (according to models):
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Estimated objective function value = 0.12879
Estimated function evaluation time = 6.7245
warning(orig_state)

Получите наилучшую точку из results2.

zbest2 = bestPoint(results2)
zbest2=1×4 table
    initialLearnRateForInteractions    maxNumSplitsPerInteraction    numTreesPerInteraction    numInteractions
    _______________________________    __________________________    ______________________    _______________

               0.091985                            5                          387                     4       

Двухмерный GAM поезда с оптимальными параметрами

Обучение оптимизированного GAM с помощью zbest1 и zbest2 значения.

Mdl = fitcgam(adultdata,'salary','Weights','fnlwgt', ...
    'ClassNames',categorical({'<=50K','>50K'}), ...
    'InitialLearnRateForPredictors',zbest1.initialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',zbest1.maxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',zbest1.numTreesPerPredictor, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction, ...   
    'Interactions',zbest2.numInteractions) 
Mdl = 
  ClassificationGAM
           PredictorNames: {'age'  'workClass'  'education'  'education_num'  'marital_status'  'occupation'  'relationship'  'race'  'sex'  'capital_gain'  'capital_loss'  'hours_per_week'  'native_country'}
             ResponseName: 'salary'
    CategoricalPredictors: [2 3 5 6 7 8 9 13]
               ClassNames: [<=50K    >50K]
           ScoreTransform: 'logit'
                Intercept: -1.7755
             Interactions: [4×2 double]
          NumObservations: 500


  Properties, Methods

Кроме того, можно добавить термины взаимодействия в одномерный GAM с помощью addInteractions функция.

Mdl2 = addInteractions(Mdl1,zbest2.numInteractions, ...
    'InitialLearnRateForInteractions',zbest2.initialLearnRateForInteractions, ...
    'MaxNumSplitsPerInteraction',zbest2.maxNumSplitsPerInteraction, ...
    'NumTreesPerInteraction',zbest2.numTreesPerInteraction); 

Второй входной аргумент указывает максимальное количество членов взаимодействия, а NumTreesPerInteraction аргумент name-value указывает максимальное количество деревьев на член взаимодействия. addInteractions функция может включать меньшее количество терминов взаимодействия и останавливаться перед обучением требуемому количеству деревьев. Вы можете использовать Interactions и ReasonForTermination свойства для проверки фактического количества членов взаимодействия и количества деревьев в обучаемой модели.

Отображение терминов взаимодействия в Mdl.

Mdl.Interactions
ans = 4×2

     7    10
     4     7
     7     9
     5    10

Каждая строка Interactions представляет один член взаимодействия и содержит индексы столбцов переменных предиктора для члена взаимодействия. Вы можете использовать Interactions свойство для проверки терминов взаимодействия в модели и порядка, в котором fitcgam добавляет их в модель.

Отображение терминов взаимодействия в Mdl используя имена предикторов.

Mdl.PredictorNames(Mdl.Interactions)
ans = 4×2 cell
    {'relationship'  }    {'capital_gain'}
    {'education_num' }    {'relationship'}
    {'relationship'  }    {'sex'         }
    {'marital_status'}    {'capital_gain'}

Отображение причины завершения для определения того, содержит ли модель указанное количество деревьев для каждого линейного члена и каждого члена взаимодействия.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Terminated after training the requested number of trees.'

Оценка прогностической эффективности новых наблюдений

Оценка эффективности обучаемой модели с использованием тестового образца adulttest и функции объекта predict, loss, edge, и margin. С помощью этих функций можно использовать полную или компактную модель.

  • predict - Классификация наблюдений

  • loss - Вычислить потерю классификации (коэффициент неправильной классификации в десятичном формате, по умолчанию)

  • margin - Вычислить поля классификации

  • edge - Вычислить границу классификации (среднее значение полей классификации)

При необходимости оценки производительности набора данных обучения используйте функции объекта повторного замещения: resubPredict, resubLoss, resubMargin, и resubEdge. Для использования этих функций необходимо использовать полную модель, содержащую данные обучения.

Создайте компактную модель, чтобы уменьшить размер обучаемой модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                                 Attributes

  CMdl      1x1             3272176  classreg.learning.classif.CompactClassificationGAM              
  Mdl       1x1             3389515  ClassificationGAM                                               

Прогнозировать метки и оценки для набора тестовых данных (adulttest) и вычислить статистику модели (потери, запас и край) с использованием набора тестовых данных.

[labels,scores] = predict(CMdl,adulttest);
L = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt);
M = margin(CMdl,adulttest);
E = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt);

Прогнозирование меток и оценок и вычисление статистики без включения терминов взаимодействия в обучаемую модель.

[labels_nointeraction,scores_nointeraction] = predict(CMdl,adulttest,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);
M_nointeractions = margin(CMdl,adulttest,'IncludeInteractions',false);
E_nointeractions = edge(CMdl,adulttest,'Weights',adulttest.fnlwgt,'IncludeInteractions',false);

Сравните результаты, полученные путем включения как линейных членов, так и членов взаимодействия, с результатами, полученными путем включения только линейных членов.

Создайте таблицу, содержащую наблюдаемые метки, прогнозируемые метки и оценки. Отображение первых восьми строк таблицы.

t = table(adulttest.salary,labels,scores,labels_nointeraction,scores_nointeraction, ...
    'VariableNames',{'True Labels','Predicted Labels','Scores' ...
    'Predicted Labels without interactions','Scores without interactions'});
head(t)
ans=8×5 table
    True Labels    Predicted Labels           Scores            Predicted Labels without interactions    Scores without interactions
    ___________    ________________    _____________________    _____________________________________    ___________________________

       <=50K            <=50K          0.97921      0.020787                    <=50K                       0.98005     0.019951    
       <=50K            <=50K                1     8.258e-17                    <=50K                        0.9713     0.028696    
       <=50K            <=50K                1    1.8297e-19                    <=50K                       0.99449    0.0055054    
       <=50K            <=50K          0.87422       0.12578                    <=50K                       0.87729      0.12271    
       <=50K            <=50K                1    3.5643e-07                    <=50K                       0.99882    0.0011769    
       <=50K            <=50K          0.60371       0.39629                    <=50K                       0.77861      0.22139    
       <=50K            >50K           0.49917       0.50083                    >50K                        0.46877      0.53123    
       >50K             >50K            0.3109        0.6891                    <=50K                       0.53571      0.46429    

Создание таблицы путаницы из истинных меток adulttest.salary и прогнозируемые метки.

tiledlayout(1,2);
nexttile
confusionchart(adulttest.salary,labels)
title('Linear and Interaction Terms')
nexttile
confusionchart(adulttest.salary,labels_nointeraction)
title('Linear Terms Only')

Отображение вычисленных значений потерь и кромок.

table([L; E], [L_nointeractions; E_nointeractions], ...
    'VariableNames',{'Linear and Interaction Terms','Only Linear Terms'}, ...
    'RowNames',{'Loss','Edge'})
ans=2×2 table
            Linear and Interaction Terms    Only Linear Terms
            ____________________________    _________________

    Loss              0.14868                    0.13852     
    Edge              0.63926                    0.58405     

Модель достигает меньших потерь, когда включены только линейные члены, но достигает более высокого значения кромки, когда включены и линейные члены, и элементы взаимодействия.

Отображение распределений полей с помощью оконных графиков.

figure
boxplot([M M_nointeractions],'Labels',{'Linear and Interaction Terms','Linear Terms Only'})
title('Box Plots of Test Sample Margins')

Интерпретировать предсказание

Интерпретировать прогноз для первого тестового наблюдения с помощью plotLocalEffects функция. Кроме того, создайте графики частичной зависимости для некоторых важных терминов в модели с помощью plotPartialDependence функция.

Классифицируйте первое наблюдение тестовых данных и постройте график локальных эффектов терминов в CMdl по прогнозу. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

label = predict(CMdl,adulttest(1,:))
label = categorical
     <=50K 

f1 = figure;
plotLocalEffects(CMdl,adulttest(1,:))
f1.CurrentAxes.TickLabelInterpreter = 'none';

predict функция классифицирует первое наблюдение adulttest(1,:) как '<=50K'. plotLocalEffects функция создает горизонтальную гистограмму, которая показывает локальные эффекты 10 наиболее важных терминов на прогноз. Каждое локальное значение эффекта показывает вклад каждого термина в оценку классификации для '<=50K', который является логитом задней вероятности того, что классификация '<=50K' для наблюдения.

Создание графика частичной зависимости для термина age. Укажите наборы обучающих и тестовых данных для вычисления значений частичной зависимости с использованием обоих наборов.

figure
plotPartialDependence(CMdl,'age',label,[adultdata; adulttest])

Построенная на графике линия представляет усредненные частичные отношения между предиктором age и оценка класса <=50K в обучаемой модели. x-axis minor ticks представляет уникальные значения в предикторе age.

Создание графиков частичной зависимости для терминов education_num и relationship.

f2 = figure;
plotPartialDependence(CMdl,["education_num","relationship"],label,[adultdata; adulttest])
f2.CurrentAxes.TickLabelInterpreter = 'none';

На сюжете показана частичная зависимость от education_num, который имеет различную тенденцию в зависимости от relationship значение.

См. также

| | | | | |

Связанные темы