exponenta event banner

Объединение разнородных моделей в сложенный ансамбль

В этом примере показано, как построить несколько моделей машинного обучения для данного набора обучающих данных, а затем объединить модели, используя метод, называемый укладкой, для повышения точности набора тестовых данных по сравнению с точностью отдельных моделей.

Стэкинг - это методика, используемая для объединения нескольких гетерогенных моделей путем обучения дополнительной модели, часто называемой сложенной в стопу моделью ансамбля или накопленным учеником, на k-кратных перекрестно проверенных прогнозах (оценки классификации для классификационных моделей и прогнозируемые ответы для регрессионных моделей) исходных (базовых) моделей. Концепция укладки состоит в том, что некоторые модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать этого. Алгоритм извлекает уроки из этого разнообразия предсказаний и пытается объединить модели, чтобы улучшить предсказанную точность базовых моделей.

В этом примере выполняется обучение нескольких гетерогенных классификационных моделей на наборе данных, а затем выполняется объединение моделей с использованием стека.

Загрузить данные образца

В этом примере используются данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США для прогнозирования того, составляет ли человек более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию зарплаты людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.

Загрузка данных образца census1994 и отображение переменных в наборе данных.

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1872567  table              
  adulttest        16281x15             944467  table              

census1994 содержит набор данных обучения adultdata и набор тестовых данных adulttest. Для этого примера, чтобы сократить время работы, выполните выборку 5000 учебных и тестовых наблюдений из исходных таблиц. adultdata и adulttest, с помощью datasample функция. (Этот шаг можно пропустить, если требуется использовать полные наборы данных.)

NumSamples = 5e3;
s = RandStream('mlfg6331_64','seed',0); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

Некоторые модели, такие как вспомогательные векторные машины (SVM), удаляют наблюдения, содержащие отсутствующие значения, в то время как другие, такие как деревья принятия решений, не удаляют такие наблюдения. Чтобы сохранить согласованность между моделями, удалите строки, содержащие отсутствующие значения, перед подгонкой моделей.

adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);

Просмотрите первые несколько строк набора данных обучения.

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и род занятий. Последний столбец salary показывает, имеет ли человек зарплату меньше или равную 50 000 долл. США в год или больше 50 000 долл. США в год.

Понимание данных и выбор классификационных моделей

Toolbox™ статистики и машинного обучения предоставляет несколько вариантов классификации, включая деревья классификации, дискриминантный анализ, наивный Байес, ближайшие соседи, SVM и классификационные ансамбли. Полный список алгоритмов см. в разделе Классификация.

Прежде чем выбирать алгоритмы для решения проблемы, проверьте набор данных. Данные переписи имеют несколько заслуживающих внимания характеристик:

  • Данные являются табличными и содержат как числовые, так и категориальные переменные.

  • Данные содержат отсутствующие значения.

  • Переменная ответа (salary) имеет два класса (бинарная классификация).

Без каких-либо предположений или использования предварительного знания алгоритмов, которые вы ожидаете хорошо работать с вашими данными, вы просто тренируете все алгоритмы, которые поддерживают табличные данные и двоичную классификацию. Модели выходных кодов с исправлением ошибок (ECOC) используются для данных с более чем двумя классами. Дискриминантный анализ и алгоритмы ближайших соседей не анализируют данные, содержащие как числовые, так и категориальные переменные. Поэтому подходящими для этого примера алгоритмами являются SVM, дерево решений, ансамбль деревьев решений и наивная модель Байеса.

Создание базовых моделей

Подходят две модели SVM, одна с гауссовым ядром, а другая с полиномиальным ядром. Также подходят дерево решений, наивная модель Байеса и ансамбль деревьев решений.

% SVM with Gaussian kernel
rng('default') % For reproducibility
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ...
    'Standardize',true,'KernelScale','auto');

% SVM with polynomial kernel
rng('default')
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ...
    'Standardize',true,'KernelScale','auto');

% Decision tree
rng('default')
mdls{3} = fitctree(adultdata,'salary');

% Naive Bayes
rng('default')
mdls{4} = fitcnb(adultdata,'salary');

% Ensemble of decision trees
rng('default')
mdls{5} = fitcensemble(adultdata,'salary');

Объединение моделей с использованием штабелирования

Если вы используете только оценки прогнозирования базовых моделей на данных обучения, сложенный ансамбль может быть переоборудован. Чтобы уменьшить переоборудование, используйте k-кратные перекрестные оценки. Чтобы обеспечить обучение каждой модели с использованием одного и того же k-кратного разделения данных, создайте cvpartition и передать этот объект в crossval функция каждой базовой модели. Этот пример является проблемой двоичной классификации, поэтому необходимо учитывать только оценки для положительного или отрицательного класса.

Получить k-кратные оценки перекрестной проверки.

rng('default') % For reproducibility
N = numel(mdls);
Scores = zeros(size(adultdata,1),N);
cv = cvpartition(adultdata.salary,"KFold",5);
for ii = 1:N
    m = crossval(mdls{ii},'cvpartition',cv);
    [~,s] = kfoldPredict(m);
    Scores(:,ii) = s(:,m.ClassNames=='<=50K');
end

Создание составленного ансамбля путем его обучения по перекрестно проверенным классификационным баллам Scores с этими опциями:

  • Чтобы получить наилучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Можно легко подобрать набор учебных данных и настроить параметры, вызвав функцию фитинга и установив ее 'OptimizeHyperparameters' аргумент пары имя-значение для 'auto'.

  • Определить 'Verbose' as 0 для отключения отображения сообщений.

  • Для воспроизводимости задайте случайное начальное число и используйте 'expected-improvement-plus' функция приобретения. Кроме того, для воспроизводимости алгоритма случайного леса укажите 'Reproducible' аргумент пары имя-значение как true для обучающихся на деревьях.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
stckdMdl = fitcensemble(Scores,adultdata.salary, ...
    'OptimizeHyperparameters','auto', ...
    'Learners',t, ...
    'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));

Сравнение предиктивной точности

Проверьте производительность классификатора с помощью набора тестовых данных с помощью матрицы путаницы и теста гипотез Макнемара.

Прогнозирование меток и оценок на тестовых данных

Найдите прогнозируемые метки, оценки и значения потерь набора тестовых данных для базовых моделей и составленного ансамбля.

Сначала выполните итерацию по базовым моделям для вычисления прогнозируемых меток, оценок и значений потерь.

label = [];
score = zeros(size(adulttest,1),N);
mdlLoss = zeros(1,numel(mdls));
for i = 1:N
    [lbl,s] = predict(mdls{i},adulttest);
    label = [label,lbl];
    score(:,i) = s(:,m.ClassNames=='<=50K');
    mdlLoss(i) = mdls{i}.loss(adulttest);
end

Прикрепить прогнозы от сложенного ансамбля к label и mdlLoss.

[lbl,s] = predict(stckdMdl,score);
label = [label,lbl];
mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);

Соедините партитуру сложенного ансамбля с партитурой базовых моделей.

score = [score,s(:,1)];

Просмотрите значения потерь.

names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ...
    'Ensemble of Decision Trees','Stacked Ensemble'};
array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
    SVM-Gaussian    SVM-Polynomial    Decision Tree    Naive Bayes    Ensemble of Decision Trees    Stacked Ensemble
    ____________    ______________    _____________    ___________    __________________________    ________________

      0.15668          0.17473           0.1975          0.16764               0.15833                  0.14519     

Значение потерь сложенного ансамбля ниже значений потерь базовых моделей.

Матрица путаницы

Вычислите матрицу путаницы с предсказанными классами и известными (истинными) классами набора тестовых данных, используя confusionchart функция.

figure
c = cell(N+1,1);
for i = 1:numel(c)
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label(:,i));
    title(names{i})
end

Диагональные элементы указывают количество правильно классифицированных экземпляров данного класса. Внедиагональные элементы являются экземплярами неправильно классифицированных наблюдений.

Тест гипотезы Макнемара

Чтобы проверить, является ли улучшение прогнозирования значительным, используйте testcholdout функция, которая проводит тест гипотезы Макнемара. Сравните сложенный ансамбль с наивной моделью Байеса.

 [hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
   1

pNB = 9.7646e-07

Сравните сложенный ансамбль с ансамблем деревьев решений.

 [hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
   1

pE = 1.9357e-04

В обоих случаях низкое значение p сложенного ансамбля подтверждает, что его прогнозы статистически превосходят прогнозы других моделей.