В этом примере показано, как создать несколько моделей машинного обучения для данного обучающего набора данных, и затем объединиться, модели с помощью метода вызвали укладку, чтобы улучшить точность относительно набора тестовых данных по сравнению с точностью отдельных моделей.
Укладка является методом, используемым, чтобы объединить несколько неоднородных моделей по образованию дополнительная модель, часто называемая сложенной моделью ансамбля или сложенным учеником, на k-сгибе перекрестные подтвержденные предсказания (классификационные оценки для моделей классификации и предсказанные ответы для моделей регрессии) исходных (основных) моделей. Концепция позади укладки - то, что определенные модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать так. Алгоритм извлекает уроки из этого разнообразия предсказаний и попыток объединить модели, чтобы улучшить предсказанную точность базовых моделей.
В этом примере вы обучаете несколько неоднородных моделей классификации на наборе данных, и затем комбинируете укладку использования моделей.
Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat
. Набор данных состоит из демографических данных Бюро переписи США, чтобы предсказать, передает ли индивидуум 50 000$ в год. Задача классификации состоит в том, чтобы подобрать модель, которая предсказывает категорию зарплаты людей, учитывая их возраст, рабочий класс, образовательный уровень, семейное положение, гонку, и так далее.
Загрузите выборочные данные census1994
и отобразите переменные в наборе данных.
load census1994
whos
Name Size Bytes Class Attributes Description 20x74 2960 char adultdata 32561x15 1872567 table adulttest 16281x15 944467 table
census1994
содержит обучающий набор данных adultdata
и тестовые данные устанавливают adulttest
. В данном примере уменьшать время выполнения, поддемонстрационные 5 000 обучения и тестовых наблюдений каждый, из исходных таблиц adultdata
и adulttest
, при помощи datasample
функция. (Можно пропустить этот шаг, если вы хотите использовать наборы полных данных.)
NumSamples = 5e3; s = RandStream('mlfg6331_64','seed',0); % For reproducibility adultdata = datasample(s,adultdata,NumSamples,'Replace',false); adulttest = datasample(s,adulttest,NumSamples,'Replace',false);
Некоторые модели, такие как машины опорных векторов (SVMs), удаляют наблюдения, содержащие отсутствующие значения, тогда как другие, такие как деревья решений, не удаляют такие наблюдения. Чтобы обеспечить непротиворечивость между моделями, удалите строки, содержащие отсутствующие значения прежде, чем подбирать модели.
adultdata = rmmissing(adultdata); adulttest = rmmissing(adulttest);
Предварительно просмотрите первые несколько строк обучающего набора данных.
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ___________ __________ ____________ _____________ __________________ _________________ ______________ _____ ______ ____________ ____________ ______________ ______________ ______
39 Private 4.91e+05 Bachelors 13 Never-married Exec-managerial Other-relative Black Male 0 0 45 United-States <=50K
25 Private 2.2022e+05 11th 7 Never-married Handlers-cleaners Own-child White Male 0 0 45 United-States <=50K
24 Private 2.2761e+05 10th 6 Divorced Handlers-cleaners Unmarried White Female 0 0 58 United-States <=50K
51 Private 1.7329e+05 HS-grad 9 Divorced Other-service Not-in-family White Female 0 0 40 United-States <=50K
54 Private 2.8029e+05 Some-college 10 Married-civ-spouse Sales Husband White Male 0 0 32 United-States <=50K
53 Federal-gov 39643 HS-grad 9 Widowed Exec-managerial Not-in-family White Female 0 0 58 United-States <=50K
52 Private 81859 HS-grad 9 Married-civ-spouse Machine-op-inspct Husband White Male 0 0 48 United-States >50K
37 Private 1.2429e+05 Some-college 10 Married-civ-spouse Adm-clerical Husband White Male 0 0 50 United-States <=50K
Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и размещение. Последний столбец salary
показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.
Statistics and Machine Learning Toolbox™ предоставляет несколько возможностей для классификации, включая деревья классификации, дискриминантный анализ, наивного Бейеса, самых близких соседей, SVMs и ансамбли классификации. Для полного списка алгоритмов смотрите Классификацию.
Прежде, чем выбрать алгоритмы, чтобы использовать для вашей проблемы, смотрите свой набор данных. Данные о переписи имеют несколько примечательных характеристик:
Данные являются табличными и содержат и числовые и категориальные переменные.
Данные содержат отсутствующие значения.
Переменная отклика (salary
) имеет два класса (бинарная классификация).
Не делая предположений или с помощью предварительных знаний алгоритмов, что вы ожидаете работать хорошо над своими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и бинарную классификацию. Модели выходных кодов с коррекцией ошибок (ECOC) используются для данных больше чем с двумя классами. Дискриминантный анализ и самые близкие соседние алгоритмы не анализируют данные, которые содержат и числовые и категориальные переменные. Поэтому алгоритмы, подходящие для этого примера, являются SVM, деревом решений, ансамблем деревьев решений и наивной моделью Bayes.
Подбирайте две модели SVM, один с Гауссовым ядром и один с полиномиальным ядром. Кроме того, соответствуйте дереву решений, наивной модели Bayes и ансамблю деревьев решений.
% SVM with Gaussian kernel rng('default') % For reproducibility mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ... 'Standardize',true,'KernelScale','auto'); % SVM with polynomial kernel rng('default') mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ... 'Standardize',true,'KernelScale','auto'); % Decision tree rng('default') mdls{3} = fitctree(adultdata,'salary'); % Naive Bayes rng('default') mdls{4} = fitcnb(adultdata,'salary'); % Ensemble of decision trees rng('default') mdls{5} = fitcensemble(adultdata,'salary');
Если вы используете только множество предсказания базовых моделей на обучающих данных, сложенный ансамбль может подвергнуться сверхподбору кривой. Чтобы уменьшать сверхподбор кривой, используйте k-сгиб перекрестные подтвержденные баллы вместо этого. Чтобы гарантировать, что вы обучаете каждую модель с помощью того же разделения данных k-сгиба, создайте cvpartition
объект и передача, которые возражают против crossval
функция каждой базовой модели. Этим примером является бинарная проблема классификации, таким образом, только необходимо рассмотреть музыку или к положительному или к отрицательному классу.
Получите баллы перекрестной проверки k-сгиба.
rng('default') % For reproducibility N = numel(mdls); Scores = zeros(size(adultdata,1),N); cv = cvpartition(adultdata.salary,"KFold",5); for ii = 1:N m = crossval(mdls{ii},'cvpartition',cv); [~,s] = kfoldPredict(m); Scores(:,ii) = s(:,m.ClassNames=='<=50K'); end
Создайте сложенный ансамбль по образованию это на перекрестных подтвержденных классификационных оценках Scores
с этими опциями:
Чтобы получить лучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Можно соответствовать обучающему набору данных и настройкам параметров легко путем вызывания подходящей функции и установки ее 'OptimizeHyperparameters'
аргумент пары "имя-значение" 'auto'
.
Задайте 'Verbose'
как 0, чтобы отключить индикаторы сообщения.
Для воспроизводимости установите случайный seed и используйте 'expected-improvement-plus'
функция захвата. Кроме того, для воспроизводимости случайного лесного алгоритма задайте 'Reproducible'
аргумент пары "имя-значение" как true
для древовидных учеников.
rng('default') % For reproducibility t = templateTree('Reproducible',true); stckdMdl = fitcensemble(Scores,adultdata.salary, ... 'OptimizeHyperparameters','auto', ... 'Learners',t, ... 'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));
Проверяйте эффективность классификатора с набором тестовых данных при помощи матрицы беспорядка и теста гипотезы Макнемэра.
Найдите предсказанные метки, баллы и значения потерь набора тестовых данных для базовых моделей и сложенного ансамбля.
Во-первых, выполните итерации по базовым моделям к вычислить предсказанным меткам, баллам и значениям потерь.
label = []; score = zeros(size(adulttest,1),N); mdlLoss = zeros(1,numel(mdls)); for i = 1:N [lbl,s] = predict(mdls{i},adulttest); label = [label,lbl]; score(:,i) = s(:,m.ClassNames=='<=50K'); mdlLoss(i) = mdls{i}.loss(adulttest); end
Присоедините предсказания от сложенного ансамбля к label
и mdlLoss
.
[lbl,s] = predict(stckdMdl,score); label = [label,lbl]; mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);
Конкатенация счета сложенного ансамбля ко множеству базовых моделей.
score = [score,s(:,1)];
Отобразите значения потерь.
names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ... 'Ensemble of Decision Trees','Stacked Ensemble'}; array2table(mdlLoss,'VariableNames',names)
ans=1×6 table
SVM-Gaussian SVM-Polynomial Decision Tree Naive Bayes Ensemble of Decision Trees Stacked Ensemble
____________ ______________ _____________ ___________ __________________________ ________________
0.15668 0.17473 0.1975 0.16764 0.15833 0.14519
Значение потерь сложенного ансамбля ниже, чем значения потерь базовых моделей.
Вычислите матрицу беспорядка с предсказанными классами и известными (TRUE) классами набора тестовых данных при помощи confusionchart
функция.
figure c = cell(N+1,1); for i = 1:numel(c) subplot(2,3,i) c{i} = confusionchart(adulttest.salary,label(:,i)); title(names{i}) end
Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональными элементами являются экземпляры неправильно классифицированных наблюдений.
Чтобы протестировать, является ли улучшение предсказания значительным, используйте testcholdout
функция, которая проводит тест гипотезы Макнемэра. Сравните сложенный ансамбль с наивной моделью Bayes.
[hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
1
pNB = 9.7646e-07
Сравните сложенный ансамбль с ансамблем деревьев решений.
[hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
1
pE = 1.9357e-04
В обоих случаях низкое p-значение сложенного ансамбля подтверждает, что его предсказания статистически превосходят те из других моделей.