Протестируйте качество ансамбля

Вы не можете оценить прогнозирующее качество ансамбля на основе его производительности на данных тренировки. Ансамбли склонны "перетренироваться", подразумевая, что они производят чрезмерно оптимистические оценки своей предсказательной силы. Это означает, что результат resubLoss для классификации (resubLoss для регрессии) обычно показывает на более низкую ошибку, чем вы входите в новые данные.

Чтобы получить лучшую идею качества ансамбля, используйте один из этих методов:

  • Оцените ансамбль на независимом наборе тестов (полезный, когда у вас будет много данных тренировки).

  • Оцените ансамбль перекрестной проверкой (полезный, когда у вас не будет большого количества данных тренировки).

  • Оцените ансамбль на данных из сумки (полезный, когда вы создадите уволенный ансамбль с fitcensemble или fitrensemble).

Этот пример использует уволенный ансамбль, таким образом, он может использовать все три метода оценки качества ансамбля.

Сгенерируйте искусственный набор данных с 20 предикторами. Каждая запись является случайным числом от 0 до 1. Начальная классификация Y=1 если X1+X2+X3+X4+X5>2.5 и Y=0 в противном случае.

rng(1,'twister') % For reproducibility
X = rand(2000,20);
Y = sum(X(:,1:5),2) > 2.5;

Кроме того, чтобы добавить шум в результаты, случайным образом переключите 10% классификаций.

idx = randsample(2000,200);
Y(idx) = ~Y(idx);

Независимый набор тестов

Создайте независимые наборы обучающих данных и наборы тестов данных. Используйте 70% данных для набора обучающих данных путем вызова cvpartition с помощью опции holdout.

cvpart = cvpartition(Y,'holdout',0.3);
Xtrain = X(training(cvpart),:);
Ytrain = Y(training(cvpart),:);
Xtest = X(test(cvpart),:);
Ytest = Y(test(cvpart),:);

Создайте уволенный ансамбль классификации 200 деревьев от данных тренировки.

t = templateTree('Reproducible',true);  % For reproducibility of random predictor selections
bag = fitcensemble(Xtrain,Ytrain,'Method','Bag','NumLearningCycles',200,'Learners',t)
bag = 
  classreg.learning.classif.ClassificationBaggedEnsemble
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: [0 1]
           ScoreTransform: 'none'
          NumObservations: 1400
               NumTrained: 200
                   Method: 'Bag'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: []
       FitInfoDescription: 'None'
                FResample: 1
                  Replace: 1
         UseObsForLearner: [1400x200 logical]


  Properties, Methods

Постройте потерю (misclassification) тестовых данных как функция количества обученных деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
xlabel('Number of trees')
ylabel('Test classification error')

Перекрестная проверка

Сгенерируйте пятикратный перекрестный подтвержденный уволенный ансамбль.

cv = fitcensemble(X,Y,'Method','Bag','NumLearningCycles',200,'Kfold',5,'Learners',t)
cv = 
  classreg.learning.partition.ClassificationPartitionedEnsemble
    CrossValidatedModel: 'Bag'
         PredictorNames: {1x20 cell}
           ResponseName: 'Y'
        NumObservations: 2000
                  KFold: 5
              Partition: [1x1 cvpartition]
      NumTrainedPerFold: [200 200 200 200 200]
             ClassNames: [0 1]
         ScoreTransform: 'none'


  Properties, Methods

Исследуйте потерю перекрестной проверки как функцию количества деревьев в ансамбле.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Location','NE')

Перекрестная проверка дает сопоставимые оценки тем из независимого набора.

Оценки из сумки

Сгенерируйте кривую потерь для оценок из сумки и постройте ее наряду с другими кривыми.

figure
plot(loss(bag,Xtest,Ytest,'mode','cumulative'))
hold on
plot(kfoldLoss(cv,'mode','cumulative'),'r.')
plot(oobLoss(bag,'mode','cumulative'),'k--')
hold off
xlabel('Number of trees')
ylabel('Classification error')
legend('Test','Cross-validation','Out of bag','Location','NE')

Оценки из сумки снова сопоставимы с теми из других методов.

Смотрите также

| | | | | | | |

Похожие темы