Объедините неоднородные модели в сложенный ансамбль

В этом примере показано, как создать несколько моделей машинного обучения для данного обучающего набора данных, и затем объединиться, модели с помощью метода вызвали укладку, чтобы улучшить точность относительно набора тестовых данных по сравнению с точностью отдельных моделей.

Укладка является методом, используемым, чтобы объединить несколько неоднородных моделей по образованию дополнительная модель, часто называемая сложенной моделью ансамбля или сложенным учеником, на k-сгибе перекрестные подтвержденные предсказания (классификационные оценки для моделей классификации и предсказанные ответы для моделей регрессии) исходных (основных) моделей. Концепция позади укладки - то, что определенные модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать так. Алгоритм извлекает уроки из этого разнообразия предсказаний и попыток объединить модели, чтобы улучшить предсказанную точность базовых моделей.

В этом примере вы обучаете несколько неоднородных моделей классификации на наборе данных, и затем комбинируете укладку использования моделей.

Загрузка демонстрационных данных

Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat. Набор данных состоит из демографических данных Бюро переписи США, чтобы предсказать, передает ли индивидуум 50 000$ в год. Задача классификации состоит в том, чтобы подобрать модель, которая предсказывает категорию зарплаты людей, учитывая их возраст, рабочий класс, образовательный уровень, семейное положение, гонку, и так далее.

Загрузите выборочные данные census1994 и отобразите переменные в наборе данных.

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1872567  table              
  adulttest        16281x15             944467  table              

census1994 содержит обучающий набор данных adultdata и тестовые данные устанавливают adulttest. В данном примере уменьшать время выполнения, поддемонстрационные 5 000 обучения и тестовых наблюдений каждый, из исходных таблиц adultdata и adulttest, при помощи datasample функция. (Можно пропустить этот шаг, если вы хотите использовать наборы полных данных.)

NumSamples = 5e3;
s = RandStream('mlfg6331_64','seed',0); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

Некоторые модели, такие как машины опорных векторов (SVMs), удаляют наблюдения, содержащие отсутствующие значения, тогда как другие, такие как деревья решений, не удаляют такие наблюдения. Чтобы обеспечить непротиворечивость между моделями, удалите строки, содержащие отсутствующие значения прежде, чем подбирать модели.

adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);

Предварительно просмотрите первые несколько строк обучающего набора данных.

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и размещение. Последний столбец salary показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.

Изучите данные и выберите модели классификации

Statistics and Machine Learning Toolbox™ предоставляет несколько возможностей для классификации, включая деревья классификации, дискриминантный анализ, наивного Бейеса, самых близких соседей, SVMs и ансамбли классификации. Для полного списка алгоритмов смотрите Классификацию.

Прежде, чем выбрать алгоритмы, чтобы использовать для вашей проблемы, смотрите свой набор данных. Данные о переписи имеют несколько примечательных характеристик:

  • Данные являются табличными и содержат и числовые и категориальные переменные.

  • Данные содержат отсутствующие значения.

  • Переменная отклика (salary) имеет два класса (бинарная классификация).

Не делая предположений или с помощью предварительных знаний алгоритмов, что вы ожидаете работать хорошо над своими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и бинарную классификацию. Модели выходных кодов с коррекцией ошибок (ECOC) используются для данных больше чем с двумя классами. Дискриминантный анализ и самые близкие соседние алгоритмы не анализируют данные, которые содержат и числовые и категориальные переменные. Поэтому алгоритмы, подходящие для этого примера, являются SVM, деревом решений, ансамблем деревьев решений и наивной моделью Bayes.

Создайте базовые модели

Подбирайте две модели SVM, один с Гауссовым ядром и один с полиномиальным ядром. Кроме того, соответствуйте дереву решений, наивной модели Bayes и ансамблю деревьев решений.

% SVM with Gaussian kernel
rng('default') % For reproducibility
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ...
    'Standardize',true,'KernelScale','auto');

% SVM with polynomial kernel
rng('default')
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ...
    'Standardize',true,'KernelScale','auto');

% Decision tree
rng('default')
mdls{3} = fitctree(adultdata,'salary');

% Naive Bayes
rng('default')
mdls{4} = fitcnb(adultdata,'salary');

% Ensemble of decision trees
rng('default')
mdls{5} = fitcensemble(adultdata,'salary');

Объедините модели Используя укладку

Если вы используете только множество предсказания базовых моделей на обучающих данных, сложенный ансамбль может подвергнуться сверхподбору кривой. Чтобы уменьшать сверхподбор кривой, используйте k-сгиб перекрестные подтвержденные баллы вместо этого. Чтобы гарантировать, что вы обучаете каждую модель с помощью того же разделения данных k-сгиба, создайте cvpartition объект и передача, которые возражают против crossval функция каждой базовой модели. Этим примером является бинарная проблема классификации, таким образом, только необходимо рассмотреть музыку или к положительному или к отрицательному классу.

Получите баллы перекрестной проверки k-сгиба.

rng('default') % For reproducibility
N = numel(mdls);
Scores = zeros(size(adultdata,1),N);
cv = cvpartition(adultdata.salary,"KFold",5);
for ii = 1:N
    m = crossval(mdls{ii},'cvpartition',cv);
    [~,s] = kfoldPredict(m);
    Scores(:,ii) = s(:,m.ClassNames=='<=50K');
end

Создайте сложенный ансамбль по образованию это на перекрестных подтвержденных классификационных оценках Scores с этими опциями:

  • Чтобы получить лучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Можно соответствовать обучающему набору данных и настройкам параметров легко путем вызывания подходящей функции и установки ее 'OptimizeHyperparameters' аргумент пары "имя-значение" 'auto'.

  • Задайте 'Verbose' как 0, чтобы отключить индикаторы сообщения.

  • Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus' функция приобретения. Кроме того, для воспроизводимости случайного лесного алгоритма задайте 'Reproducible' аргумент пары "имя-значение" как true для древовидных учеников.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
stckdMdl = fitcensemble(Scores,adultdata.salary, ...
    'OptimizeHyperparameters','auto', ...
    'Learners',t, ...
    'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));

Сравните прогнозирующую точность

Проверяйте эффективность классификатора с набором тестовых данных при помощи матрицы беспорядка и теста гипотезы Макнемэра.

Предскажите метки и баллы на тестовых данных

Найдите предсказанные метки, баллы и значения потерь набора тестовых данных для базовых моделей и сложенного ансамбля.

Во-первых, выполните итерации по базовым моделям к вычислить предсказанным меткам, баллам и значениям потерь.

label = [];
score = zeros(size(adulttest,1),N);
mdlLoss = zeros(1,numel(mdls));
for i = 1:N
    [lbl,s] = predict(mdls{i},adulttest);
    label = [label,lbl];
    score(:,i) = s(:,m.ClassNames=='<=50K');
    mdlLoss(i) = mdls{i}.loss(adulttest);
end

Присоедините предсказания от сложенного ансамбля к label и mdlLoss.

[lbl,s] = predict(stckdMdl,score);
label = [label,lbl];
mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);

Конкатенация счета сложенного ансамбля ко множеству базовых моделей.

score = [score,s(:,1)];

Отобразите значения потерь.

names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ...
    'Ensemble of Decision Trees','Stacked Ensemble'};
array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
    SVM-Gaussian    SVM-Polynomial    Decision Tree    Naive Bayes    Ensemble of Decision Trees    Stacked Ensemble
    ____________    ______________    _____________    ___________    __________________________    ________________

      0.15668          0.17473           0.1975          0.16764               0.15833                  0.14519     

Значение потерь сложенного ансамбля ниже, чем значения потерь базовых моделей.

Матрица беспорядка

Вычислите матрицу беспорядка с предсказанными классами и известными (TRUE) классами набора тестовых данных при помощи confusionchart функция.

figure
c = cell(N+1,1);
for i = 1:numel(c)
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label(:,i));
    title(names{i})
end

Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональными элементами являются экземпляры неправильно классифицированных наблюдений.

Тест гипотезы Макнемэра

Чтобы протестировать, является ли улучшение предсказания значительным, используйте testcholdout функция, которая проводит тест гипотезы Макнемэра. Сравните сложенный ансамбль с наивной моделью Bayes.

 [hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
   1

pNB = 9.7646e-07

Сравните сложенный ансамбль с ансамблем деревьев решений.

 [hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
   1

pE = 1.9357e-04

В обоих случаях низкое p-значение сложенного ансамбля подтверждает, что его предсказания статистически превосходят те из других моделей.

Для просмотра документации необходимо авторизоваться на сайте