exponenta event banner

Качество тестового ансамбля

Невозможно оценить прогностическое качество ансамбля на основе его производительности по данным обучения. Ансамбли, как правило, «перетренируются», что означает, что они дают чрезмерно оптимистичные оценки своей прогностической силы. Это означает результат resubLoss для классификации (resubLoss для регрессии) обычно указывает на более низкую ошибку, чем при получении новых данных.

Чтобы получить лучшее представление о качестве ансамбля, используйте один из следующих методов:

  • Оцените ансамбль на независимом тестовом наборе (полезно при наличии большого количества обучающих данных).

  • Оцените ансамбль путем перекрестной проверки (полезно, когда у вас нет большого количества учебных данных).

  • Оцените ансамбль на основе данных вне мешка (полезно при создании упакованного ансамбля с помощью fitcensemble или fitrensemble).

В этом примере используется пакетированный ансамбль, поэтому можно использовать все три метода оценки качества ансамбля.

Создайте искусственный набор данных с 20 предикторами. Каждая запись является случайным числом от 0 до 1. Начальная классификация: Y = 1, если X1 + X2 + X3 + X4 + X5 > 2,5 и Y = 0 в противном случае.

rng(1,'twister') % For reproducibility
X = rand(2000,20);
Y = sum(X(:,1:5),2) > 2.5;

Кроме того, для добавления шума к результатам случайным образом переключают 10% классификаций.

idx = randsample(2000,200);
Y(idx) = ~Y(idx);

Независимый набор тестов

Создание независимых обучающих и тестовых наборов данных. Использование 70% данных для обучающего аппарата по телефону cvpartition с использованием holdout вариант.

cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);

Создайте пакетированный классификационный ансамбль из 200 деревьев на основе данных обучения.

t = templateTree('Reproducible',true);  % For reproducibility of random predictor selections
bag = fitcensemble(Xtrain,Ytrain,'Method','Bag','NumLearningCycles',200,'Learners',t)
bag = 
  ClassificationBaggedEnsemble
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: [0 1]
           ScoreTransform: 'none'
          NumObservations: 1400
               NumTrained: 200
                   Method: 'Bag'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: []
       FitInfoDescription: 'None'
                FResample: 1
                  Replace: 1
         UseObsForLearner: [1400x200 logical]


  Properties, Methods

Постройте график потери (неправильной классификации) тестовых данных в зависимости от количества обученных деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
xlabel('Number of trees')
ylabel('Test classification error')

Figure contains an axes. The axes contains an object of type line.

Перекрестная проверка

Создание пятикратного перекрестно проверенного пакетированного ансамбля.

cv = fitcensemble(X,Y,'Method','Bag','NumLearningCycles',200,'Kfold',5,'Learners',t)
cv = 
  ClassificationPartitionedEnsemble
    CrossValidatedModel: 'Bag'
         PredictorNames: {1x20 cell}
           ResponseName: 'Y'
        NumObservations: 2000
                  KFold: 5
              Partition: [1x1 cvpartition]
      NumTrainedPerFold: [200 200 200 200 200]
             ClassNames: [0 1]
         ScoreTransform: 'none'


  Properties, Methods

Проверьте потерю перекрестной проверки как функцию количества деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Location','NE')

Figure contains an axes. The axes contains 2 objects of type line. These objects represent Test, Cross-validation.

Перекрестная проверка дает сопоставимые оценки с оценками независимого набора.

Оценки вне пакета

Создайте кривую потерь для оценок вне пакета и постройте ее график вместе с другими кривыми.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
plot(oobLoss(bag,'mode','cumulative'),'k--')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Out of bag','Location','NE')

Figure contains an axes. The axes contains 3 objects of type line. These objects represent Test, Cross-validation, Out of bag.

Оценки вне пакета снова сопоставимы с оценками других методов.

См. также

| | | | | | |

Связанные темы