Качество тестового ансамбля

Вы не можете оценить прогнозирующее качество ансамбля на основе его эффективности на обучающих данных. Ансамбли, как правило, «переобучаются», что означает, что они дают слишком оптимистичные оценки своей прогностической степени. Это означает результат resubLoss для классификации (resubLoss для регрессии) обычно указывает на более низкую ошибку, чем вы получаете на новых данных.

Чтобы получить лучшее представление о качестве ансамбля, используйте один из следующих методов:

  • Оцените ансамбль на независимом тестовом наборе (полезном, когда у вас есть много обучающих данных).

  • Оцените ансамбль по перекрестной валидации (полезно, когда у вас нет много обучающих данных).

  • Оцените ансамбль по данным вне сумки (полезно, когда вы создаете упакованный ансамбль с fitcensemble или fitrensemble).

Этот пример использует упакованный ансамбль, чтобы он мог использовать все три метода оценки качества ансамбля.

Сгенерируйте искусственный набор данных с 20 предикторами. Каждая запись является случайным числом от 0 до 1. Начальная классификация Y=1 если X1+X2+X3+X4+X5>2.5 и Y=0 в противном случае.

rng(1,'twister') % For reproducibility
X = rand(2000,20);
Y = sum(X(:,1:5),2) > 2.5;

В сложение, чтобы добавить шум к результатам, случайным образом переключите 10% классификаций.

idx = randsample(2000,200);
Y(idx) = ~Y(idx);

Независимый тестовый набор

Создайте независимые наборы обучающих и тестовых данных. Используйте 70% данных для набора обучающих данных по телефону cvpartition использование holdout опция.

cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);

Создайте пакетированный классификационный ансамбль из 200 деревьев по обучающие данные.

t = templateTree('Reproducible',true);  % For reproducibility of random predictor selections
bag = fitcensemble(Xtrain,Ytrain,'Method','Bag','NumLearningCycles',200,'Learners',t)
bag = 
  ClassificationBaggedEnsemble
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: [0 1]
           ScoreTransform: 'none'
          NumObservations: 1400
               NumTrained: 200
                   Method: 'Bag'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: []
       FitInfoDescription: 'None'
                FResample: 1
                  Replace: 1
         UseObsForLearner: [1400x200 logical]


  Properties, Methods

Постройте график потерь (неправильной классификации) тестовых данных как функции от количества обученных деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
xlabel('Number of trees')
ylabel('Test classification error')

Figure contains an axes. The axes contains an object of type line.

Перекрестная валидация

Сгенерируйте пятикратный перекрестно проверенный упакованный ансамбль.

cv = fitcensemble(X,Y,'Method','Bag','NumLearningCycles',200,'Kfold',5,'Learners',t)
cv = 
  ClassificationPartitionedEnsemble
    CrossValidatedModel: 'Bag'
         PredictorNames: {1x20 cell}
           ResponseName: 'Y'
        NumObservations: 2000
                  KFold: 5
              Partition: [1x1 cvpartition]
      NumTrainedPerFold: [200 200 200 200 200]
             ClassNames: [0 1]
         ScoreTransform: 'none'


  Properties, Methods

Исследуйте потери перекрестной валидации как функцию от количества деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Location','NE')

Figure contains an axes. The axes contains 2 objects of type line. These objects represent Test, Cross-validation.

Перекрестная валидация дает сопоставимые оценки с оценками независимого набора.

Оценки вне мешка

Сгенерируйте кривую потерь для оценок вне мешка и постройте ее график вместе с другими кривыми.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
plot(oobLoss(bag,'mode','cumulative'),'k--')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Out of bag','Location','NE')

Figure contains an axes. The axes contains 3 objects of type line. These objects represent Test, Cross-validation, Out of bag.

Оценки вне мешка снова сопоставимы с оценками других методов.

См. также

| | | | | | |

Похожие темы