Регуляризация ансамбля

Регуляризация - это процесс выбора меньшего количества слабых учащихся для ансамбля таким образом, чтобы это не снижало прогнозирующую эффективность. В настоящее время вы можете регулировать регрессионые ансамбли. (Можно также упорядочить классификатор дискриминантного анализа в контексте без ансамбля; см. «Упорядочение классификатора дискриминантного анализа».)

The regularize метод находит оптимальный набор весов учащихся αt которые минимизируют

n=1Nwng((t=1Tαtht(xn)),yn)+λt=1T|αt|.

Здесь

  • λ ≥ 0 является параметром, который вы предоставляете, называемым параметром lasso.

  • ht является слабым учеником в ансамбле, обученным наблюдениям за N с xn предикторов, yn откликов и wn весов.

  • g (f, y) = (fy)2 является квадратичная невязка.

Ансамбль упорядочен на тех же (xn, yn, wn) данных, используемых для обучения, так

n=1Nwng((t=1Tαtht(xn)),yn)

- ошибка повторной замещения ансамбля. Ошибка измеряется средней квадратичной невязкой (MSE).

Если вы используете λ = 0, regularize находит слабые веса учащихся путем минимизации resubstitution MSE. Ансамбли, как правило, переучиваются. Другими словами, ошибка восстановления обычно меньше, чем истинная ошибка обобщения. Делая ошибку реституции еще меньше, вы, вероятно, сделаете точность ансамбля хуже, чем улучшить ее. С другой стороны, положительные значения λ приводят величину коэффициентов αt к 0. Это часто улучшает ошибку обобщения. Конечно, если вы выбираете λ слишком большие, все оптимальные коэффициенты равны 0, и ансамбль не имеет никакой точности. Обычно можно найти оптимальную область значений для λ, в котором точность регуляризованного ансамбля лучше или сопоставима с точностью полного ансамбля без регуляризации.

Приятной функцией регуляризации лассо является его способность приводить оптимизированные коэффициенты точно в 0. Если αt веса учащегося равен 0, этот ученик может быть исключен из регуляризованного ансамбля. В конце концов, вы получаете ансамбль с улучшенной точностью и меньшим количеством учащихся.

Упорядочение регрессионного ансамбля

Этот пример использует данные для прогнозирования страхового риска автомобиля на основе его многочисленных атрибутов.

Загрузите imports-85 данные в рабочее пространство MATLAB.

load imports-85;

Проверьте описание данных, чтобы найти категориальные переменные и имена предикторов.

Description
Description = 9x79 char array
    '1985 Auto Imports Database from the UCI repository                             '
    'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
    'Variables have been reordered to place variables with numeric values (referred '
    'to as "continuous" on the UCI site) to the left and categorical values to the  '
    'right. Specifically, variables 1:16 are: symboling, normalized-losses,         '
    'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke,     '
    'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price.     '
    'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style,    '
    'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '

Цель этого процесса состоит в том, чтобы предсказать «символизацию», первую переменную в данных, от других предикторов. «symboling» является целым числом от -3 (хороший страховой риск), чтобы 3 (низкий страховой риск). Можно использовать классификационный ансамбль, чтобы предсказать этот риск вместо регрессионного ансамбля. Когда у вас есть выбор между регрессией и классификацией, вы должны сначала попробовать регрессию.

Подготовим данные для ансамбля подборы кривой.

Y = X(:,1);
X(:,1) = [];
VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ...
  'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ...
  'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ...
  'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ...
  'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'};
catidx = 16:25; % indices of categorical predictors

Создайте регрессионный ансамбль из данных с помощью 300 деревьев.

ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ...
    'LearnRate',0.1,'PredictorNames',VarNames, ...
    'ResponseName','Symboling','CategoricalPredictors',catidx)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: []


  Properties, Methods

Конечная линия, Regularization, пуст ([]). Чтобы упорядочить ансамбль, необходимо использовать regularize способ.

cv = crossval(ls,'KFold',5);
figure;
plot(kfoldLoss(cv,'Mode','Cumulative'));
xlabel('Number of trees');
ylabel('Cross-validated MSE');
ylim([0.2,2])

Figure contains an axes. The axes contains an object of type line.

Похоже, вы можете получить удовлетворительную эффективность от меньшего ансамбля, возможно, который содержит от 50 до 100 деревьев.

Вызовите regularize метод, чтобы попытаться найти деревья, которые можно удалить из ансамбля. По умолчанию regularize исследует 10 значений лассо (Lambda) параметр разнесен в геометрической прогрессии.

ls = regularize(ls)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: [1x1 struct]


  Properties, Methods

The Regularization свойство больше не пустое.

Постройте график среднеквадратичной ошибки (MSE) и количества учащихся с ненулевыми весами относительно параметра lasso. Отдельно постройте график значения в Lambda = 0. Используйте логарифмическую шкалу, потому что значения Lambda разнесены в геометрической прогрессии.

figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b');
r0 = resubLoss(ls);
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
     [r0 r0],'Color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Resubstitution MSE');
annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ...
    'Color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'marker','x','markersize',10,'color','b');
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
    [ls.NTrained ls.NTrained],...
    'color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Number of learners');
annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',...
    'color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

Значения MSE реституции, вероятно, будут слишком оптимистичными. Чтобы получить более достоверные оценки ошибки, связанной с различными значениями Lambda, перекрестно проверьте ансамбль используя cvshrink. Постройте график полученной потери перекрестной валидации (MSE) и количества учащихся против Lambda.

rng(0,'Twister') % for reproducibility
[mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners.
figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
hold on;
semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10);
hold off;
xlabel('Lambda');
ylabel('Mean squared error');
legend('resubstitution','cross-validation','Location','NW');
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
hold;
Current plot held
loglog(ls.Regularization.Lambda,nlearn,'r--');
hold off;
xlabel('Lambda');
ylabel('Number of learners');
legend('resubstitution','cross-validation','Location','NE');
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

Изучение перекрестной проверенной ошибки показывает, что MSE перекрестной валидации почти плоский для Lambda до бит 1e-2.

Исследуйте ls.Regularization.Lambda чтобы найти наивысшее значение, которое дает MSE в плоской области (до немного выше 1e-2).

jj = 1:length(ls.Regularization.Lambda);
[jj;ls.Regularization.Lambda]
ans = 2×10

    1.0000    2.0000    3.0000    4.0000    5.0000    6.0000    7.0000    8.0000    9.0000   10.0000
         0    0.0019    0.0045    0.0107    0.0254    0.0602    0.1428    0.3387    0.8033    1.9048

Элементный 5 от ls.Regularization.Lambda имеет значение 0.0254, самый большой в плоской области значений.

Уменьшите размер ансамбля, используя shrink способ. shrink возвращает компактный ансамбль без обучающих данных. Ошибка обобщения для нового компактного ансамбля уже была оценена путем перекрестной валидации в mse(5).

cmp = shrink(ls,'weightcolumn',5)
cmp = 
  CompactRegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
               NumTrained: 8


  Properties, Methods

Число деревьев в новом ансамбле значительно сократилось с 300 в ls.

Сравните размеры ансамблей.

sz(1) = whos('cmp'); sz(2) = whos('ls');
[sz(1).bytes sz(2).bytes]
ans = 1×2

       91209     3226892

Размер редуцированного ансамбля составляет часть размера оригинала. Обратите внимание, что размеры ансамбля могут варьироваться в зависимости от операционной системы.

Сравните MSE редуцированного ансамбля с MSE оригинального ансамбля.

figure;
plot(kfoldLoss(cv,'mode','cumulative'));
hold on
plot(cmp.NTrained,mse(5),'ro','MarkerSize',10);
xlabel('Number of trees');
ylabel('Cross-validated MSE');
legend('unregularized ensemble','regularized ensemble',...
    'Location','NE');
hold off

Figure contains an axes. The axes contains 2 objects of type line. These objects represent unregularized ensemble, regularized ensemble.

Уменьшенный ансамбль дает низкие потери при использовании гораздо меньшего количества деревьев.

См. также

| | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте