Настройте RobustBoost

RobustBoost алгоритм может сделать хорошие предсказания классификации, даже когда обучающие данные имеют шум. Однако RobustBoost по умолчанию параметры могут произвести ансамбль, который не предсказывает хорошо. Этот пример показывает один способ настроить параметры для лучшей прогнозирующей точности.

Сгенерируйте данные с шумом метки. Этот пример имеет двадцать универсальных случайных чисел на наблюдение и классифицирует наблюдение как 1 если сумма первых пяти чисел превышает 2.5 (так больше, чем среднее значение), и 0 в противном случае:

rng(0,'twister') % for reproducibility
Xtrain = rand(2000,20);
Ytrain = sum(Xtrain(:,1:5),2) > 2.5;

Чтобы добавить шум, случайным образом переключите 10% классификаций:

idx = randsample(2000,200);
Ytrain(idx) = ~Ytrain(idx);

Создайте ансамбль с AdaBoostM1 в целях сравнения:

ada = fitcensemble(Xtrain,Ytrain,'Method','AdaBoostM1', ...
    'NumLearningCycles',300,'Learners','Tree','LearnRate',0.1);

Создайте ансамбль с RobustBoost. Поскольку данные имеют 10%-ю неправильную классификацию, возможно, ошибочная цель 15% разумна.

rb1 = fitcensemble(Xtrain,Ytrain,'Method','RobustBoost', ...
    'NumLearningCycles',300,'Learners','Tree','RobustErrorGoal',0.15, ...
    'RobustMaxMargin',1);

Обратите внимание на то, что, если вы устанавливаете ошибочную цель к достаточно высокому значению, затем программное обеспечение возвращает ошибку.

Создайте ансамбль с очень оптимистической ошибочной целью, 0.01:

rb2 = fitcensemble(Xtrain,Ytrain,'Method','RobustBoost', ...
    'NumLearningCycles',300,'Learners','Tree','RobustErrorGoal',0.01);

Сравните ошибку перезамены этих трех ансамблей:

figure
plot(resubLoss(rb1,'Mode','Cumulative'));
hold on
plot(resubLoss(rb2,'Mode','Cumulative'),'r--');
plot(resubLoss(ada,'Mode','Cumulative'),'g.');
hold off;
xlabel('Number of trees');
ylabel('Resubstitution error');
legend('ErrorGoal=0.15','ErrorGoal=0.01',...
    'AdaBoostM1','Location','NE');

Figure contains an axes object. The axes object contains 3 objects of type line. These objects represent ErrorGoal=0.15, ErrorGoal=0.01, AdaBoostM1.

Весь RobustBoost кривые показывают более низкую ошибку перезамены, чем AdaBoostM1 кривая. Ошибочная цель 0.01 изогнитесь показывает самую низкую ошибку перезамены в большей части области значений.

Xtest = rand(2000,20);
Ytest = sum(Xtest(:,1:5),2) > 2.5;
idx = randsample(2000,200);
Ytest(idx) = ~Ytest(idx);
figure;
plot(loss(rb1,Xtest,Ytest,'Mode','Cumulative'));
hold on
plot(loss(rb2,Xtest,Ytest,'Mode','Cumulative'),'r--');
plot(loss(ada,Xtest,Ytest,'Mode','Cumulative'),'g.');
hold off;
xlabel('Number of trees');
ylabel('Test error');
legend('ErrorGoal=0.15','ErrorGoal=0.01',...
    'AdaBoostM1','Location','NE');

Figure contains an axes object. The axes object contains 3 objects of type line. These objects represent ErrorGoal=0.15, ErrorGoal=0.01, AdaBoostM1.

Кривая ошибок для ошибочной цели 0.15 является самой низкой (лучше всего) в нанесенной на график области значений. AdaBoostM1 имеет более высокую ошибку, чем кривая для ошибочной цели 0.15. Кривая для также оптимистической ошибочной цели 0.01 остается существенно выше (хуже), чем другие алгоритмы для большей части нанесенной на график области значений.

Смотрите также

| |

Похожие темы