Объедините гетерогенные модели в сложенный ансамбль

Этот пример показов, как создать несколько моделей машинного обучения для данных обучающих данных набора, а затем объединить модели с помощью метода, называемого stacking, чтобы улучшить точность на тестовые данные наборе по сравнению с точностью отдельных моделей.

Stacking - это метод, используемый для объединения нескольких гетерогенных моделей путем настройки дополнительной модели, часто называемой сложенной моделью ансамбля, или сложенной обучающейся, на k-складной перекрестной проверенных предсказаниях (классификационные оценки для классификационных моделей и предсказанных откликов для регрессионных моделей) исходных (базовых) моделей. Концепция укладки состоит в том, что определенные модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать этого. Алгоритм учится на этом разнообразии предсказаний и пытается объединить модели, чтобы улучшить предсказанную точность базовых моделей.

В этом примере вы обучаете несколько гетерогенных моделей классификации на наборе данных, а затем комбинируете модели с помощью стека.

Загрузка выборочных данных

Этот пример использует данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США, чтобы предсказать, составляет ли индивидуум более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию заработной платы людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.

Загрузите выборочные данные census1994 и отобразите переменные в наборе данных.

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1872567  table              
  adulttest        16281x15             944467  table              

census1994 содержит обучающий набор обучающих данных adultdata и набор тестовых данных adulttest. В данном примере, чтобы уменьшить время работы, выделите 5000 обучений и тестовых наблюдений каждый из исходных таблиц adultdata и adulttest, при помощи datasample функция. (Можно пропустить этот шаг, если необходимо использовать полные наборы данных.)

NumSamples = 5e3;
s = RandStream('mlfg6331_64','seed',0); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

Некоторые модели, такие как машины опорных векторов (SVM), удаляют наблюдения, содержащие отсутствующие значения, в то время как другие, такие как деревья решений, не удаляют такие наблюдения. Чтобы сохранить согласованность между моделями, удалите строки, содержащие отсутствующие значения, перед подгонкой моделей.

adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);

Предварительный просмотр первых нескольких строк обучающих данных набора.

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и род занятий. Последний столбец salary показывает, имеет ли человек зарплату менее 50 000 долл. США в год или более 50 000 долл. США в год.

Осмыслите данные и выберите классификационные модели

Statistics and Machine Learning Toolbox™ предоставляет несколько опции классификации, включая деревья классификации, дискриминантный анализ, наивный Байес, ближайшие соседи, SVM и классификационные ансамбли. Полный список алгоритмов см. в разделе Классификация.

Прежде чем выбирать алгоритмы, которые будут использоваться для вашей задачи, смотрите свой набор данных. Данные переписи населения имеют несколько примечательных характеристик:

  • Данные являются табличными и содержат как числовые, так и категориальные переменные.

  • Данные содержат отсутствующие значения.

  • Переменная отклика (salary) имеет два класса (двоичная классификация).

Не делая никаких предположений или используя предыдущие знания алгоритмов, которые вы ожидаете хорошо работать с вашими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и двоичную классификацию. Модели выходных кодов с исправлением ошибок (ECOC) используются для данных с более чем двумя классами. Дискриминантный анализ и ближайшие соседние алгоритмы не анализируют данные, которые содержат как числовые, так и категориальные переменные. Поэтому алгоритмами, подходящими для этого примера, являются SVM, дерево решений, ансамбль деревьев решений и наивная модель Байеса.

Построение базовых моделей

Подгонка двух моделей SVM, одной с Гауссовым ядром и одной с полиномиальным ядром. Также подбирается дерево решений, наивная модель Байеса и ансамбль деревьев решений.

% SVM with Gaussian kernel
rng('default') % For reproducibility
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ...
    'Standardize',true,'KernelScale','auto');

% SVM with polynomial kernel
rng('default')
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ...
    'Standardize',true,'KernelScale','auto');

% Decision tree
rng('default')
mdls{3} = fitctree(adultdata,'salary');

% Naive Bayes
rng('default')
mdls{4} = fitcnb(adultdata,'salary');

% Ensemble of decision trees
rng('default')
mdls{5} = fitcensemble(adultdata,'salary');

Объединение моделей с использованием штабелирования

Если вы используете только счета предсказания базовых моделей в обучающих данных, сложенный ансамбль может подвергнуться сверхподбору кривой. Чтобы уменьшить сверхподбор кривой, используйте вместо этого перекрестные счета k-fold. Чтобы убедиться, что вы обучаете каждую модель, используя одно и то же разделение данных k-fold, создайте cvpartition объект и передать этот объект в crossval функция каждой базовой модели. Этот пример является двоичной задачей классификации, поэтому вам нужно рассматривать счета только для положительного или отрицательного класса.

Получите k-кратные перекрестные счета валидации.

rng('default') % For reproducibility
N = numel(mdls);
Scores = zeros(size(adultdata,1),N);
cv = cvpartition(adultdata.salary,"KFold",5);
for ii = 1:N
    m = crossval(mdls{ii},'cvpartition',cv);
    [~,s] = kfoldPredict(m);
    Scores(:,ii) = s(:,m.ClassNames=='<=50K');
end

Создайте сложенный ансамбль путем настройки его на перекрестных проверенных классификационных оценках Scores с этими опциями:

  • Чтобы получить лучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Вы можете легко подогнать обучающий набор обучающих данных и настроить параметры, вызвав функцию fitting и установив ее 'OptimizeHyperparameters' аргумент пары "имя-значение" в 'auto'.

  • Задайте 'Verbose' 0, чтобы отключить отображения сообщений.

  • Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus' функция сбора. Кроме того, для воспроизводимости алгоритма случайного леса задайте 'Reproducible' аргумент пары "имя-значение" как true для учащихся дерева.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
stckdMdl = fitcensemble(Scores,adultdata.salary, ...
    'OptimizeHyperparameters','auto', ...
    'Learners',t, ...
    'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));

Сравнение прогнозирующей точности

Проверьте эффективность классификатора с набором тестовых данных с помощью матрицы неточностей и теста гипотезы Макнемара.

Предсказание меток и счетов по тестовым данным

Найдите предсказанные метки, счета и значения потерь тестовых данных набора для базовых моделей и сложенного ансамбля.

Во-первых, итерация по базовым моделям для вычисления предсказанных меток, счетов и значений потерь.

label = [];
score = zeros(size(adulttest,1),N);
mdlLoss = zeros(1,numel(mdls));
for i = 1:N
    [lbl,s] = predict(mdls{i},adulttest);
    label = [label,lbl];
    score(:,i) = s(:,m.ClassNames=='<=50K');
    mdlLoss(i) = mdls{i}.loss(adulttest);
end

Приложите предсказания от сложенного ансамбля к label и mdlLoss.

[lbl,s] = predict(stckdMdl,score);
label = [label,lbl];
mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);

Конкатенируйте счет сложенного ансамбля с счетами базовых моделей.

score = [score,s(:,1)];

Отобразите значения потерь.

names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ...
    'Ensemble of Decision Trees','Stacked Ensemble'};
array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
    SVM-Gaussian    SVM-Polynomial    Decision Tree    Naive Bayes    Ensemble of Decision Trees    Stacked Ensemble
    ____________    ______________    _____________    ___________    __________________________    ________________

      0.15668          0.17473           0.1975          0.16764               0.15833                  0.14519     

Значение потерь сложенного ансамбля ниже, чем значения потерь базовых моделей.

Матрица неточностей

Вычислите матрицу неточностей с предсказанными классами и известными (true) классами тестовых данных набора при помощи confusionchart функция.

figure
c = cell(N+1,1);
for i = 1:numel(c)
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label(:,i));
    title(names{i})
end

Диагональные элементы указывают количество правильно классифицированных образцов данного класса. Не-диагональные элементы являются образцами неправильно классифицированных наблюдений.

Тест гипотезы Макнемара

Чтобы проверить, является ли улучшение предсказания значительным, используйте testcholdout функция, которая проводит тест гипотезы Макнемара. Сравните сложенный ансамбль с наивной моделью Байеса.

 [hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
   1

pNB = 9.7646e-07

Сравните сложенный ансамбль с ансамблем деревьев решений.

 [hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
   1

pE = 1.9357e-04

В обоих случаях низкое p-значение сложенного ансамбля подтверждает, что его предсказания статистически превосходят прогнозы других моделей.

Для просмотра документации необходимо авторизоваться на сайте