Этот пример показов, как создать несколько моделей машинного обучения для данных обучающих данных набора, а затем объединить модели с помощью метода, называемого stacking, чтобы улучшить точность на тестовые данные наборе по сравнению с точностью отдельных моделей.
Stacking - это метод, используемый для объединения нескольких гетерогенных моделей путем настройки дополнительной модели, часто называемой сложенной моделью ансамбля, или сложенной обучающейся, на k-складной перекрестной проверенных предсказаниях (классификационные оценки для классификационных моделей и предсказанных откликов для регрессионных моделей) исходных (базовых) моделей. Концепция укладки состоит в том, что определенные модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать этого. Алгоритм учится на этом разнообразии предсказаний и пытается объединить модели, чтобы улучшить предсказанную точность базовых моделей.
В этом примере вы обучаете несколько гетерогенных моделей классификации на наборе данных, а затем комбинируете модели с помощью стека.
Этот пример использует данные переписи 1994 года, хранящиеся в census1994.mat
. Набор данных состоит из демографических данных Бюро переписи населения США, чтобы предсказать, составляет ли индивидуум более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию заработной платы людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.
Загрузите выборочные данные census1994
и отобразите переменные в наборе данных.
load census1994
whos
Name Size Bytes Class Attributes Description 20x74 2960 char adultdata 32561x15 1872567 table adulttest 16281x15 944467 table
census1994
содержит обучающий набор обучающих данных adultdata
и набор тестовых данных adulttest
. В данном примере, чтобы уменьшить время работы, выделите 5000 обучений и тестовых наблюдений каждый из исходных таблиц adultdata
и adulttest
, при помощи datasample
функция. (Можно пропустить этот шаг, если необходимо использовать полные наборы данных.)
NumSamples = 5e3; s = RandStream('mlfg6331_64','seed',0); % For reproducibility adultdata = datasample(s,adultdata,NumSamples,'Replace',false); adulttest = datasample(s,adulttest,NumSamples,'Replace',false);
Некоторые модели, такие как машины опорных векторов (SVM), удаляют наблюдения, содержащие отсутствующие значения, в то время как другие, такие как деревья решений, не удаляют такие наблюдения. Чтобы сохранить согласованность между моделями, удалите строки, содержащие отсутствующие значения, перед подгонкой моделей.
adultdata = rmmissing(adultdata); adulttest = rmmissing(adulttest);
Предварительный просмотр первых нескольких строк обучающих данных набора.
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ___________ __________ ____________ _____________ __________________ _________________ ______________ _____ ______ ____________ ____________ ______________ ______________ ______
39 Private 4.91e+05 Bachelors 13 Never-married Exec-managerial Other-relative Black Male 0 0 45 United-States <=50K
25 Private 2.2022e+05 11th 7 Never-married Handlers-cleaners Own-child White Male 0 0 45 United-States <=50K
24 Private 2.2761e+05 10th 6 Divorced Handlers-cleaners Unmarried White Female 0 0 58 United-States <=50K
51 Private 1.7329e+05 HS-grad 9 Divorced Other-service Not-in-family White Female 0 0 40 United-States <=50K
54 Private 2.8029e+05 Some-college 10 Married-civ-spouse Sales Husband White Male 0 0 32 United-States <=50K
53 Federal-gov 39643 HS-grad 9 Widowed Exec-managerial Not-in-family White Female 0 0 58 United-States <=50K
52 Private 81859 HS-grad 9 Married-civ-spouse Machine-op-inspct Husband White Male 0 0 48 United-States >50K
37 Private 1.2429e+05 Some-college 10 Married-civ-spouse Adm-clerical Husband White Male 0 0 50 United-States <=50K
Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и род занятий. Последний столбец salary
показывает, имеет ли человек зарплату менее 50 000 долл. США в год или более 50 000 долл. США в год.
Statistics and Machine Learning Toolbox™ предоставляет несколько опции классификации, включая деревья классификации, дискриминантный анализ, наивный Байес, ближайшие соседи, SVM и классификационные ансамбли. Полный список алгоритмов см. в разделе Классификация.
Прежде чем выбирать алгоритмы, которые будут использоваться для вашей задачи, смотрите свой набор данных. Данные переписи населения имеют несколько примечательных характеристик:
Данные являются табличными и содержат как числовые, так и категориальные переменные.
Данные содержат отсутствующие значения.
Переменная отклика (salary
) имеет два класса (двоичная классификация).
Не делая никаких предположений или используя предыдущие знания алгоритмов, которые вы ожидаете хорошо работать с вашими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и двоичную классификацию. Модели выходных кодов с исправлением ошибок (ECOC) используются для данных с более чем двумя классами. Дискриминантный анализ и ближайшие соседние алгоритмы не анализируют данные, которые содержат как числовые, так и категориальные переменные. Поэтому алгоритмами, подходящими для этого примера, являются SVM, дерево решений, ансамбль деревьев решений и наивная модель Байеса.
Подгонка двух моделей SVM, одной с Гауссовым ядром и одной с полиномиальным ядром. Также подбирается дерево решений, наивная модель Байеса и ансамбль деревьев решений.
% SVM with Gaussian kernel rng('default') % For reproducibility mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ... 'Standardize',true,'KernelScale','auto'); % SVM with polynomial kernel rng('default') mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ... 'Standardize',true,'KernelScale','auto'); % Decision tree rng('default') mdls{3} = fitctree(adultdata,'salary'); % Naive Bayes rng('default') mdls{4} = fitcnb(adultdata,'salary'); % Ensemble of decision trees rng('default') mdls{5} = fitcensemble(adultdata,'salary');
Если вы используете только счета предсказания базовых моделей в обучающих данных, сложенный ансамбль может подвергнуться сверхподбору кривой. Чтобы уменьшить сверхподбор кривой, используйте вместо этого перекрестные счета k-fold. Чтобы убедиться, что вы обучаете каждую модель, используя одно и то же разделение данных k-fold, создайте cvpartition
объект и передать этот объект в crossval
функция каждой базовой модели. Этот пример является двоичной задачей классификации, поэтому вам нужно рассматривать счета только для положительного или отрицательного класса.
Получите k-кратные перекрестные счета валидации.
rng('default') % For reproducibility N = numel(mdls); Scores = zeros(size(adultdata,1),N); cv = cvpartition(adultdata.salary,"KFold",5); for ii = 1:N m = crossval(mdls{ii},'cvpartition',cv); [~,s] = kfoldPredict(m); Scores(:,ii) = s(:,m.ClassNames=='<=50K'); end
Создайте сложенный ансамбль путем настройки его на перекрестных проверенных классификационных оценках Scores
с этими опциями:
Чтобы получить лучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Вы можете легко подогнать обучающий набор обучающих данных и настроить параметры, вызвав функцию fitting и установив ее 'OptimizeHyperparameters'
аргумент пары "имя-значение" в 'auto'
.
Задайте 'Verbose'
0, чтобы отключить отображения сообщений.
Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus'
функция сбора. Кроме того, для воспроизводимости алгоритма случайного леса задайте 'Reproducible'
аргумент пары "имя-значение" как true
для учащихся дерева.
rng('default') % For reproducibility t = templateTree('Reproducible',true); stckdMdl = fitcensemble(Scores,adultdata.salary, ... 'OptimizeHyperparameters','auto', ... 'Learners',t, ... 'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));
Проверьте эффективность классификатора с набором тестовых данных с помощью матрицы неточностей и теста гипотезы Макнемара.
Найдите предсказанные метки, счета и значения потерь тестовых данных набора для базовых моделей и сложенного ансамбля.
Во-первых, итерация по базовым моделям для вычисления предсказанных меток, счетов и значений потерь.
label = []; score = zeros(size(adulttest,1),N); mdlLoss = zeros(1,numel(mdls)); for i = 1:N [lbl,s] = predict(mdls{i},adulttest); label = [label,lbl]; score(:,i) = s(:,m.ClassNames=='<=50K'); mdlLoss(i) = mdls{i}.loss(adulttest); end
Приложите предсказания от сложенного ансамбля к label
и mdlLoss
.
[lbl,s] = predict(stckdMdl,score); label = [label,lbl]; mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);
Конкатенируйте счет сложенного ансамбля с счетами базовых моделей.
score = [score,s(:,1)];
Отобразите значения потерь.
names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ... 'Ensemble of Decision Trees','Stacked Ensemble'}; array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
SVM-Gaussian SVM-Polynomial Decision Tree Naive Bayes Ensemble of Decision Trees Stacked Ensemble
____________ ______________ _____________ ___________ __________________________ ________________
0.15668 0.17473 0.1975 0.16764 0.15833 0.14519
Значение потерь сложенного ансамбля ниже, чем значения потерь базовых моделей.
Вычислите матрицу неточностей с предсказанными классами и известными (true) классами тестовых данных набора при помощи confusionchart
функция.
figure c = cell(N+1,1); for i = 1:numel(c) subplot(2,3,i) c{i} = confusionchart(adulttest.salary,label(:,i)); title(names{i}) end
Диагональные элементы указывают количество правильно классифицированных образцов данного класса. Не-диагональные элементы являются образцами неправильно классифицированных наблюдений.
Чтобы проверить, является ли улучшение предсказания значительным, используйте testcholdout
функция, которая проводит тест гипотезы Макнемара. Сравните сложенный ансамбль с наивной моделью Байеса.
[hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
1
pNB = 9.7646e-07
Сравните сложенный ансамбль с ансамблем деревьев решений.
[hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
1
pE = 1.9357e-04
В обоих случаях низкое p-значение сложенного ансамбля подтверждает, что его предсказания статистически превосходят прогнозы других моделей.