LPBoost и TotalBoost для малочисленных ансамблей

В этом примере показано, как получить преимущества LPBoost и TotalBoost алгоритмы. Эти алгоритмы совместно используют две выгодных характеристики:

  • Они самоотключают, что означает, что вы не должны выяснять сколько членов, чтобы включать.

  • Они производят ансамбли с некоторыми очень маленькими весами, позволяя вам безопасно удалить члены ансамбля.

Загрузите данные

Загрузите ionosphere набор данных.

load ionosphere

Создайте ансамбли классификации

Создайте ансамбли для классификации ionosphere данные с помощью LPBoost, TotalBoost, и, для сравнения, AdaBoostM1 алгоритмы. Трудно знать сколько членов включать в ансамбль. Для LPBoost и TotalBoost, попытайтесь использовать 500. Для сравнения также используйте 500 для AdaBoostM1.

Слабые ученики по умолчанию для повышения методов являются деревьями решений с MaxNumSplits набор свойств к 10. Эти деревья имеют тенденцию соответствовать лучше, чем пни (с 1 максимальное разделение), и может сверхсоответствовать больше. Поэтому, чтобы предотвратить сверхподбор кривой, используйте пни в качестве слабых учеников для ансамблей.

rng('default') % For reproducibility
T = 500;
treeStump = templateTree('MaxNumSplits',1);
adaStump = fitcensemble(X,Y,'Method','AdaBoostM1','NumLearningCycles',T,'Learners',treeStump);
totalStump = fitcensemble(X,Y,'Method','TotalBoost','NumLearningCycles',T,'Learners',treeStump);
lpStump = fitcensemble(X,Y,'Method','LPBoost','NumLearningCycles',T,'Learners',treeStump);

figure
plot(resubLoss(adaStump,'Mode','Cumulative'));
hold on
plot(resubLoss(totalStump,'Mode','Cumulative'),'r');
plot(resubLoss(lpStump,'Mode','Cumulative'),'g');
hold off
xlabel('Number of stumps');
ylabel('Training error');
legend('AdaBoost','TotalBoost','LPBoost','Location','NE');

Figure contains an axes object. The axes object contains 3 objects of type line. These objects represent AdaBoost, TotalBoost, LPBoost.

Все три алгоритма достигают совершенного предсказания на обучающих данных через некоторое время.

Исследуйте число членов во всех трех ансамблях.

[adaStump.NTrained totalStump.NTrained lpStump.NTrained]
ans = 1×3

   500    52    77

AdaBoostM1 обученный весь 500 члены. Другие два алгоритма остановили обучение рано.

Крест подтверждает ансамбли

Крест подтверждает ансамбли, чтобы лучше определить точность ансамбля.

cvlp = crossval(lpStump,'KFold',5);
cvtotal = crossval(totalStump,'KFold',5);
cvada = crossval(adaStump,'KFold',5);

figure
plot(kfoldLoss(cvada,'Mode','Cumulative'));
hold on
plot(kfoldLoss(cvtotal,'Mode','Cumulative'),'r');
plot(kfoldLoss(cvlp,'Mode','Cumulative'),'g');
hold off
xlabel('Ensemble size');
ylabel('Cross-validated error');
legend('AdaBoost','TotalBoost','LPBoost','Location','NE');

Figure contains an axes object. The axes object contains 3 objects of type line. These objects represent AdaBoost, TotalBoost, LPBoost.

Результаты показывают, что каждый повышающий алгоритм достигает потери 10% или ниже с 50 членами ансамбля.

Компактный и удаляют члены ансамбля

Чтобы уменьшать размеры ансамбля, уплотните их, и затем используйте removeLearners. Вопрос, сколько учеников необходимо удалить? Перекрестные подтвержденные кривые потерь дают вам одну меру. Для другого исследуйте веса ученика на LPBoost и TotalBoost после уплотнения.

cada = compact(adaStump);
clp = compact(lpStump);
ctotal = compact(totalStump);

figure
subplot(2,1,1)
plot(clp.TrainedWeights)
title('LPBoost weights')
subplot(2,1,2)
plot(ctotal.TrainedWeights)
title('TotalBoost weights')

Figure contains 2 axes objects. Axes object 1 with title LPBoost weights contains an object of type line. Axes object 2 with title TotalBoost weights contains an object of type line.

Оба LPBoost и TotalBoost покажите ясные точки, где веса члена ансамбля становятся незначительными.

Удалите неважные члены ансамбля.

cada = removeLearners(cada,150:cada.NTrained);
clp = removeLearners(clp,60:clp.NTrained);
ctotal = removeLearners(ctotal,40:ctotal.NTrained);

Проверяйте, что удаление этих учеников не влияет на точность ансамбля на обучающих данных.

[loss(cada,X,Y) loss(clp,X,Y) loss(ctotal,X,Y)]
ans = 1×3

     0     0     0

Проверяйте получившиеся компактные размеры ансамбля.

s(1) = whos('cada');
s(2) = whos('clp');
s(3) = whos('ctotal');
s.bytes
ans = 590844
ans = 236030
ans = 157190

Размеры компактных ансамблей приблизительно пропорциональны числу членов в каждом.

Смотрите также

| | | | | |

Похожие темы