В этом примере показано, как построить несколько моделей машинного обучения для данного набора обучающих данных, а затем объединить модели, используя метод, называемый укладкой, для повышения точности набора тестовых данных по сравнению с точностью отдельных моделей.
Стэкинг - это методика, используемая для объединения нескольких гетерогенных моделей путем обучения дополнительной модели, часто называемой сложенной в стопу моделью ансамбля или накопленным учеником, на k-кратных перекрестно проверенных прогнозах (оценки классификации для классификационных моделей и прогнозируемые ответы для регрессионных моделей) исходных (базовых) моделей. Концепция укладки состоит в том, что некоторые модели могут правильно классифицировать тестовое наблюдение, в то время как другие могут не сделать этого. Алгоритм извлекает уроки из этого разнообразия предсказаний и пытается объединить модели, чтобы улучшить предсказанную точность базовых моделей.
В этом примере выполняется обучение нескольких гетерогенных классификационных моделей на наборе данных, а затем выполняется объединение моделей с использованием стека.
В этом примере используются данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США для прогнозирования того, составляет ли человек более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию зарплаты людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.
Загрузка данных образца census1994 и отображение переменных в наборе данных.
load census1994
whosName Size Bytes Class Attributes Description 20x74 2960 char adultdata 32561x15 1872567 table adulttest 16281x15 944467 table
census1994 содержит набор данных обучения adultdata и набор тестовых данных adulttest. Для этого примера, чтобы сократить время работы, выполните выборку 5000 учебных и тестовых наблюдений из исходных таблиц. adultdata и adulttest, с помощью datasample функция. (Этот шаг можно пропустить, если требуется использовать полные наборы данных.)
NumSamples = 5e3; s = RandStream('mlfg6331_64','seed',0); % For reproducibility adultdata = datasample(s,adultdata,NumSamples,'Replace',false); adulttest = datasample(s,adulttest,NumSamples,'Replace',false);
Некоторые модели, такие как вспомогательные векторные машины (SVM), удаляют наблюдения, содержащие отсутствующие значения, в то время как другие, такие как деревья принятия решений, не удаляют такие наблюдения. Чтобы сохранить согласованность между моделями, удалите строки, содержащие отсутствующие значения, перед подгонкой моделей.
adultdata = rmmissing(adultdata); adulttest = rmmissing(adulttest);
Просмотрите первые несколько строк набора данных обучения.
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ___________ __________ ____________ _____________ __________________ _________________ ______________ _____ ______ ____________ ____________ ______________ ______________ ______
39 Private 4.91e+05 Bachelors 13 Never-married Exec-managerial Other-relative Black Male 0 0 45 United-States <=50K
25 Private 2.2022e+05 11th 7 Never-married Handlers-cleaners Own-child White Male 0 0 45 United-States <=50K
24 Private 2.2761e+05 10th 6 Divorced Handlers-cleaners Unmarried White Female 0 0 58 United-States <=50K
51 Private 1.7329e+05 HS-grad 9 Divorced Other-service Not-in-family White Female 0 0 40 United-States <=50K
54 Private 2.8029e+05 Some-college 10 Married-civ-spouse Sales Husband White Male 0 0 32 United-States <=50K
53 Federal-gov 39643 HS-grad 9 Widowed Exec-managerial Not-in-family White Female 0 0 58 United-States <=50K
52 Private 81859 HS-grad 9 Married-civ-spouse Machine-op-inspct Husband White Male 0 0 48 United-States >50K
37 Private 1.2429e+05 Some-college 10 Married-civ-spouse Adm-clerical Husband White Male 0 0 50 United-States <=50K
Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и род занятий. Последний столбец salary показывает, имеет ли человек зарплату меньше или равную 50 000 долл. США в год или больше 50 000 долл. США в год.
Toolbox™ статистики и машинного обучения предоставляет несколько вариантов классификации, включая деревья классификации, дискриминантный анализ, наивный Байес, ближайшие соседи, SVM и классификационные ансамбли. Полный список алгоритмов см. в разделе Классификация.
Прежде чем выбирать алгоритмы для решения проблемы, проверьте набор данных. Данные переписи имеют несколько заслуживающих внимания характеристик:
Данные являются табличными и содержат как числовые, так и категориальные переменные.
Данные содержат отсутствующие значения.
Переменная ответа (salary) имеет два класса (бинарная классификация).
Без каких-либо предположений или использования предварительного знания алгоритмов, которые вы ожидаете хорошо работать с вашими данными, вы просто тренируете все алгоритмы, которые поддерживают табличные данные и двоичную классификацию. Модели выходных кодов с исправлением ошибок (ECOC) используются для данных с более чем двумя классами. Дискриминантный анализ и алгоритмы ближайших соседей не анализируют данные, содержащие как числовые, так и категориальные переменные. Поэтому подходящими для этого примера алгоритмами являются SVM, дерево решений, ансамбль деревьев решений и наивная модель Байеса.
Подходят две модели SVM, одна с гауссовым ядром, а другая с полиномиальным ядром. Также подходят дерево решений, наивная модель Байеса и ансамбль деревьев решений.
% SVM with Gaussian kernel rng('default') % For reproducibility mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','gaussian', ... 'Standardize',true,'KernelScale','auto'); % SVM with polynomial kernel rng('default') mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','polynomial', ... 'Standardize',true,'KernelScale','auto'); % Decision tree rng('default') mdls{3} = fitctree(adultdata,'salary'); % Naive Bayes rng('default') mdls{4} = fitcnb(adultdata,'salary'); % Ensemble of decision trees rng('default') mdls{5} = fitcensemble(adultdata,'salary');
Если вы используете только оценки прогнозирования базовых моделей на данных обучения, сложенный ансамбль может быть переоборудован. Чтобы уменьшить переоборудование, используйте k-кратные перекрестные оценки. Чтобы обеспечить обучение каждой модели с использованием одного и того же k-кратного разделения данных, создайте cvpartition и передать этот объект в crossval функция каждой базовой модели. Этот пример является проблемой двоичной классификации, поэтому необходимо учитывать только оценки для положительного или отрицательного класса.
Получить k-кратные оценки перекрестной проверки.
rng('default') % For reproducibility N = numel(mdls); Scores = zeros(size(adultdata,1),N); cv = cvpartition(adultdata.salary,"KFold",5); for ii = 1:N m = crossval(mdls{ii},'cvpartition',cv); [~,s] = kfoldPredict(m); Scores(:,ii) = s(:,m.ClassNames=='<=50K'); end
Создание составленного ансамбля путем его обучения по перекрестно проверенным классификационным баллам Scores с этими опциями:
Чтобы получить наилучшие результаты для сложенного ансамбля, оптимизируйте его гиперпараметры. Можно легко подобрать набор учебных данных и настроить параметры, вызвав функцию фитинга и установив ее 'OptimizeHyperparameters' аргумент пары имя-значение для 'auto'.
Определить 'Verbose' as 0 для отключения отображения сообщений.
Для воспроизводимости задайте случайное начальное число и используйте 'expected-improvement-plus' функция приобретения. Кроме того, для воспроизводимости алгоритма случайного леса укажите 'Reproducible' аргумент пары имя-значение как true для обучающихся на деревьях.
rng('default') % For reproducibility t = templateTree('Reproducible',true); stckdMdl = fitcensemble(Scores,adultdata.salary, ... 'OptimizeHyperparameters','auto', ... 'Learners',t, ... 'HyperparameterOptimizationOptions',struct('Verbose',0,'AcquisitionFunctionName','expected-improvement-plus'));

Проверьте производительность классификатора с помощью набора тестовых данных с помощью матрицы путаницы и теста гипотез Макнемара.
Найдите прогнозируемые метки, оценки и значения потерь набора тестовых данных для базовых моделей и составленного ансамбля.
Сначала выполните итерацию по базовым моделям для вычисления прогнозируемых меток, оценок и значений потерь.
label = []; score = zeros(size(adulttest,1),N); mdlLoss = zeros(1,numel(mdls)); for i = 1:N [lbl,s] = predict(mdls{i},adulttest); label = [label,lbl]; score(:,i) = s(:,m.ClassNames=='<=50K'); mdlLoss(i) = mdls{i}.loss(adulttest); end
Прикрепить прогнозы от сложенного ансамбля к label и mdlLoss.
[lbl,s] = predict(stckdMdl,score); label = [label,lbl]; mdlLoss(end+1) = stckdMdl.loss(score,adulttest.salary);
Соедините партитуру сложенного ансамбля с партитурой базовых моделей.
score = [score,s(:,1)];
Просмотрите значения потерь.
names = {'SVM-Gaussian','SVM-Polynomial','Decision Tree','Naive Bayes', ...
'Ensemble of Decision Trees','Stacked Ensemble'};
array2table(mdlLoss,'VariableNames',names)ans=1×6 table
SVM-Gaussian SVM-Polynomial Decision Tree Naive Bayes Ensemble of Decision Trees Stacked Ensemble
____________ ______________ _____________ ___________ __________________________ ________________
0.15668 0.17473 0.1975 0.16764 0.15833 0.14519
Значение потерь сложенного ансамбля ниже значений потерь базовых моделей.
Вычислите матрицу путаницы с предсказанными классами и известными (истинными) классами набора тестовых данных, используя confusionchart функция.
figure c = cell(N+1,1); for i = 1:numel(c) subplot(2,3,i) c{i} = confusionchart(adulttest.salary,label(:,i)); title(names{i}) end

Диагональные элементы указывают количество правильно классифицированных экземпляров данного класса. Внедиагональные элементы являются экземплярами неправильно классифицированных наблюдений.
Чтобы проверить, является ли улучшение прогнозирования значительным, используйте testcholdout функция, которая проводит тест гипотезы Макнемара. Сравните сложенный ансамбль с наивной моделью Байеса.
[hNB,pNB] = testcholdout(label(:,6),label(:,4),adulttest.salary)
hNB = logical
1
pNB = 9.7646e-07
Сравните сложенный ансамбль с ансамблем деревьев решений.
[hE,pE] = testcholdout(label(:,6),label(:,5),adulttest.salary)
hE = logical
1
pE = 1.9357e-04
В обоих случаях низкое значение p сложенного ансамбля подтверждает, что его прогнозы статистически превосходят прогнозы других моделей.