exponenta event banner

Регуляризация ансамбля

Регуляризация - это процесс выбора меньшего количества слабых учеников для ансамбля таким образом, чтобы не снижать прогностическую работоспособность. В настоящее время можно упорядочить регрессионные ансамбли. (Можно также упорядочить классификатор дискриминантного анализа в контексте, не относящемся к ансамблю; см. «Упорядочить классификатор дискриминантного анализа».)

regularize метод находит оптимальный набор весов учащихся αt, которые минимизируют

∑n=1Nwng ((∑t=1Tαtht (xn)), йн) +λ∑t=1T'αt|.

Здесь

  • λ ≥ 0 - предоставленный параметр, называемый параметром lasso.

  • ht - слабый ученик в ансамбле, обученный наблюдениям N с предикторами xn, ответами yn и весами wn.

  • g (f, y) = (f-y) 2 - квадрат ошибки.

Ансамбль упорядочен на тех же (xn, yn, wn) данных, используемых для обучения, поэтому

∑n=1Nwng ((∑t=1Tαtht (xn)), йн)

ошибка повторного замещения ансамбля. Погрешность измеряется среднеквадратичной погрешностью (MSE).

Если используется λ = 0, regularize находит слабые веса обучающегося путем минимизации MSE повторного замещения. Ансамбли склонны перетренироваться. Другими словами, ошибка повторного замещения обычно меньше, чем истинная ошибка обобщения. Делая ошибку повторного замещения еще меньше, вы, скорее всего, сделаете точность ансамбля хуже, а не улучшите ее. С другой стороны, положительные значения λ толкают величину α t коэффициентов на 0. Это часто улучшает ошибку обобщения. Конечно, если выбрать λ слишком большой, все оптимальные коэффициенты равны 0, и ансамбль не имеет никакой точности. Обычно можно найти оптимальный диапазон для λ, в котором точность регуляризованного ансамбля лучше или сравнима с точностью полного ансамбля без регуляризации.

Хорошей особенностью регуляризации лассо является его способность управлять оптимизированными коэффициентами точно до 0. Если вес ученика αt равен 0, этот ученик может быть исключен из упорядоченного ансамбля. В итоге получается ансамбль с улучшенной точностью и меньшим количеством учеников.

Упорядочить регрессионный ансамбль

В этом примере используются данные для прогнозирования страхового риска автомобиля на основе его многочисленных атрибутов.

Загрузить imports-85 данные в рабочую область MATLAB.

load imports-85;

Посмотрите на описание данных, чтобы найти категориальные переменные и имена предикторов.

Description
Description = 9x79 char array
    '1985 Auto Imports Database from the UCI repository                             '
    'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
    'Variables have been reordered to place variables with numeric values (referred '
    'to as "continuous" on the UCI site) to the left and categorical values to the  '
    'right. Specifically, variables 1:16 are: symboling, normalized-losses,         '
    'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke,     '
    'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price.     '
    'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style,    '
    'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '

Цель этого процесса состоит в том, чтобы предсказать «символизацию», первую переменную в данных, из других предикторов. «symboling» - целое число от -3 (хороший страховой риск) 3 (низкий страховой риск). Можно использовать классификационный ансамбль для прогнозирования этого риска вместо регрессионного ансамбля. При выборе между регрессией и классификацией сначала следует попробовать регрессию.

Подготовьте данные для монтажа ансамбля.

Y = X(:,1);
X(:,1) = [];
VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ...
  'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ...
  'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ...
  'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ...
  'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'};
catidx = 16:25; % indices of categorical predictors

Создайте ансамбль регрессии из данных, используя 300 деревьев.

ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ...
    'LearnRate',0.1,'PredictorNames',VarNames, ...
    'ResponseName','Symboling','CategoricalPredictors',catidx)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: []


  Properties, Methods

Последняя строка, Regularization, пуст ([]). Чтобы упорядочить ансамбль, необходимо использовать regularize способ.

cv = crossval(ls,'KFold',5);
figure;
plot(kfoldLoss(cv,'Mode','Cumulative'));
xlabel('Number of trees');
ylabel('Cross-validated MSE');
ylim([0.2,2])

Figure contains an axes. The axes contains an object of type line.

Кажется, вы можете получить удовлетворительную производительность от меньшего ансамбля, возможно, содержащего от 50 до 100 деревьев.

Позвоните в regularize метод, чтобы попытаться найти деревья, которые можно удалить из ансамбля. По умолчанию regularize исследует 10 значений лассо (Lambda) параметр, разнесенный в геометрической прогрессии.

ls = regularize(ls)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: [1x1 struct]


  Properties, Methods

Regularization свойство больше не пусто.

Постройте график среднеквадратичной ошибки (MSE) и количества учеников с ненулевыми весами по отношению к параметру lasso. Отдельно постройте график значения в Lambda = 0. Использовать логарифмическую шкалу, поскольку значения Lambda экспоненциально разнесены.

figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b');
r0 = resubLoss(ls);
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
     [r0 r0],'Color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Resubstitution MSE');
annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ...
    'Color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'marker','x','markersize',10,'color','b');
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
    [ls.NTrained ls.NTrained],...
    'color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Number of learners');
annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',...
    'color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

Значения MSE повторного замещения, вероятно, будут чрезмерно оптимистичными. Для получения более достоверных оценок ошибки, связанной с различными значениями Lambda, перекрестная проверка ансамбля с помощью cvshrink. Постройте график результирующей потери при перекрестной проверке (MSE) и числа учащихся Lambda.

rng(0,'Twister') % for reproducibility
[mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners.
figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
hold on;
semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10);
hold off;
xlabel('Lambda');
ylabel('Mean squared error');
legend('resubstitution','cross-validation','Location','NW');
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
hold;
Current plot held
loglog(ls.Regularization.Lambda,nlearn,'r--');
hold off;
xlabel('Lambda');
ylabel('Number of learners');
legend('resubstitution','cross-validation','Location','NE');
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

Анализ кросс-проверенной ошибки показывает, что MSE перекрестной проверки почти плоский для Lambda до немного больше 1e-2.

Исследовать ls.Regularization.Lambda найти наибольшее значение, которое дает MSE в плоской области (до бита 1e-2).

jj = 1:length(ls.Regularization.Lambda);
[jj;ls.Regularization.Lambda]
ans = 2×10

    1.0000    2.0000    3.0000    4.0000    5.0000    6.0000    7.0000    8.0000    9.0000   10.0000
         0    0.0019    0.0045    0.0107    0.0254    0.0602    0.1428    0.3387    0.8033    1.9048

Элемент 5 из ls.Regularization.Lambda имеет значение 0.0254, самый большой в плоском диапазоне.

Уменьшите размер ансамбля с помощью shrink способ. shrink возвращает компактный ансамбль без обучающих данных. Ошибка обобщения для нового компактного ансамбля уже была оценена перекрестной проверкой в mse(5).

cmp = shrink(ls,'weightcolumn',5)
cmp = 
  CompactRegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
               NumTrained: 8


  Properties, Methods

Количество деревьев в новом ансамбле заметно сократилось с 300 в ls.

Сравните размеры ансамблей.

sz(1) = whos('cmp'); sz(2) = whos('ls');
[sz(1).bytes sz(2).bytes]
ans = 1×2

       91209     3226892

Размер уменьшенного ансамбля - это доля размера оригинала. Обратите внимание, что размер ансамбля может варьироваться в зависимости от операционной системы.

Сравните MSE сокращенного ансамбля с MSE оригинального ансамбля.

figure;
plot(kfoldLoss(cv,'mode','cumulative'));
hold on
plot(cmp.NTrained,mse(5),'ro','MarkerSize',10);
xlabel('Number of trees');
ylabel('Cross-validated MSE');
legend('unregularized ensemble','regularized ensemble',...
    'Location','NE');
hold off

Figure contains an axes. The axes contains 2 objects of type line. These objects represent unregularized ensemble, regularized ensemble.

Уменьшенный ансамбль дает низкие потери при использовании гораздо меньшего количества деревьев.

См. также

| | | | | |

Связанные темы