Регуляризация ансамбля

Регуляризация является процессом выбора меньшего количества слабых учеников для ансамбля способом, который не уменьшает прогнозирующую эффективность. В настоящее время можно упорядочить ансамбли регрессии. (Можно также упорядочить классификатор дискриминантного анализа в контексте неансамбля; смотрите Упорядочивают Классификатор Дискриминантного анализа.)

regularize метод находит оптимальный набор весов ученика αt, которые минимизируют

n=1Nwng((t=1Tαtht(xn)),yn)+λt=1T|αt|.

Здесь

  • λ ≥ 0 является параметром, который вы обеспечиваете, названный параметром лассо.

  • ht является слабым учеником в ансамбле, обученном на наблюдениях N с предикторами xn, ответы yn и веса wn.

  • g (f, y) = (fy) 2 является квадратичной невязкой.

Ансамбль упорядочен на том же самом (xn, yn, wn) данные, используемые для обучения, таким образом,

n=1Nwng((t=1Tαtht(xn)),yn)

ошибка перезамены ансамбля. Ошибка измеряется среднеквадратической ошибкой (MSE).

Если вы используете λ = 0, regularize находит слабые веса ученика путем минимизации перезамены MSE. Ансамбли склонны перетренироваться. Другими словами, ошибка перезамены обычно меньше, чем истинная ошибка обобщения. Путем совершения еще меньшей ошибки перезамены вы, вероятно, сделаете точность ансамбля хуже вместо того, чтобы улучшить его. С другой стороны, положительные значения λ продвигают величину коэффициентов αt к 0. Это часто улучшает ошибку обобщения. Конечно, если вы выбираете слишком большой λ, все оптимальные коэффициенты 0, и у ансамбля нет точности. Обычно можно найти оптимальную область значений для λ, в котором точность упорядоченного ансамбля лучше или сопоставима с тем из полного ансамбля без регуляризации.

Хорошей функцией регуляризации лассо является своя способность управлять оптимизированными коэффициентами точно к 0. Если вес ученика, который αt 0, этот ученик, может быть исключен из упорядоченного ансамбля. В конце вы получаете ансамбль с улучшенной точностью и меньшим количеством учеников.

Упорядочите ансамбль регрессии

Этот пример использует данные для предсказания страхового риска автомобиля на основе его многих атрибутов.

Загрузите imports-85 данные в рабочее пространство MATLAB.

load imports-85;

Посмотрите на описание данных, чтобы найти категориальные переменные и имена предиктора.

Description
Description = 9x79 char array
    '1985 Auto Imports Database from the UCI repository                             '
    'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
    'Variables have been reordered to place variables with numeric values (referred '
    'to as "continuous" on the UCI site) to the left and categorical values to the  '
    'right. Specifically, variables 1:16 are: symboling, normalized-losses,         '
    'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke,     '
    'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price.     '
    'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style,    '
    'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '

Цель этого процесса состоит в том, чтобы предсказать "symboling", первую переменную в данных, от других предикторов. "symboling" является целым числом от -3 (хороший страховой риск) к 3 (плохой страховой риск). Вы могли использовать ансамбль классификации, чтобы предсказать этот риск вместо ансамбля регрессии. Когда у вас есть выбор между регрессией и классификацией, необходимо попробовать регрессию сначала.

Подготовьте данные к подбору кривой ансамбля.

Y = X(:,1);
X(:,1) = [];
VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ...
  'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ...
  'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ...
  'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ...
  'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'};
catidx = 16:25; % indices of categorical predictors

Создайте ансамбль регрессии из данных с помощью 300 деревьев.

ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ...
    'LearnRate',0.1,'PredictorNames',VarNames, ...
    'ResponseName','Symboling','CategoricalPredictors',catidx)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: []


  Properties, Methods

Итоговая линия, Regularizationisempty. Чтобы упорядочить ансамбль, необходимо использовать regularize метод.

cv = crossval(ls,'KFold',5);
figure;
plot(kfoldLoss(cv,'Mode','Cumulative'));
xlabel('Number of trees');
ylabel('Cross-validated MSE');
ylim([0.2,2])

Figure contains an axes. The axes contains an object of type line.

Кажется, что вы можете получить удовлетворительную эффективность из меньшего ансамбля, возможно, один содержащий от 50 до 100 деревьев.

Вызовите regularize метод, чтобы попытаться найти деревья, которые можно удалить из ансамбля. По умолчанию, regularize исследует 10 значений лассо (Lambda) параметр расположен с интервалами экспоненциально.

ls = regularize(ls)
ls = 
  RegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
          NumObservations: 205
               NumTrained: 300
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [300x1 double]
       FitInfoDescription: {2x1 cell}
           Regularization: [1x1 struct]


  Properties, Methods

Regularization свойство более не не пусто.

Постройте среднеквадратическую ошибку (MSE) перезамены и количество учеников с ненулевыми весами против параметра лассо. Отдельно постройте значение в Lambda = 0. Используйте логарифмический масштаб потому что значения Lambda экспоненциально расположены с интервалами.

figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b');
r0 = resubLoss(ls);
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
     [r0 r0],'Color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Resubstitution MSE');
annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ...
    'Color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'marker','x','markersize',10,'color','b');
line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...
    [ls.NTrained ls.NTrained],...
    'color','r','LineStyle','--');
xlabel('Lambda');
ylabel('Number of learners');
annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',...
    'color','r','FontSize',14,'LineStyle','none');

Figure contains an axes. The axes contains 3 objects of type line.

Перезамена значения MSE, вероятно, будет чрезмерно оптимистична. Получить более надежные оценки ошибки, сопоставленной с различными значениями Lambda, крест подтверждает ансамбль, использующий cvshrink. Постройте получившуюся потерю перекрестной проверки (MSE) и количество учеников против Lambda.

rng(0,'Twister') % for reproducibility
[mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners.
figure;
semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...
    'bx-','Markersize',10);
hold on;
semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10);
hold off;
xlabel('Lambda');
ylabel('Mean squared error');
legend('resubstitution','cross-validation','Location','NW');
line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...
     ls.Regularization.ResubstitutionMSE(1)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

figure;
loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1));
hold;
Current plot held
loglog(ls.Regularization.Lambda,nlearn,'r--');
hold off;
xlabel('Lambda');
ylabel('Number of learners');
legend('resubstitution','cross-validation','Location','NE');
line([1e-3 1e-3],...
    [sum(ls.Regularization.TrainedWeights(:,1)>0) ...
    sum(ls.Regularization.TrainedWeights(:,1)>0)],...
    'Marker','x','Markersize',10,'Color','b','HandleVisibility','off');
line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',...
    'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

Figure contains an axes. The axes contains 2 objects of type line. These objects represent resubstitution, cross-validation.

Исследование перекрестной подтвержденной ошибки показывает, что перекрестная проверка MSE является почти плоской для Lambda до немногим выше 1e-2.

Исследуйте ls.Regularization.Lambda найти самое высокое значение, которое дает MSE в плоской области (до немногим выше 1e-2).

jj = 1:length(ls.Regularization.Lambda);
[jj;ls.Regularization.Lambda]
ans = 2×10

    1.0000    2.0000    3.0000    4.0000    5.0000    6.0000    7.0000    8.0000    9.0000   10.0000
         0    0.0019    0.0045    0.0107    0.0254    0.0602    0.1428    0.3387    0.8033    1.9048

Элемент 5 из ls.Regularization.Lambda имеет значение 0.0254, самое большое в плоской области значений.

Уменьшайте размер ансамбля с помощью shrink метод. shrink возвращает компактный ансамбль без обучающих данных. Ошибка обобщения для нового компактного ансамбля была уже оценена перекрестной проверкой в mse(5).

cmp = shrink(ls,'weightcolumn',5)
cmp = 
  CompactRegressionEnsemble
           PredictorNames: {1x25 cell}
             ResponseName: 'Symboling'
    CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]
        ResponseTransform: 'none'
               NumTrained: 8


  Properties, Methods

Количество деревьев в новом ансамбле особенно уменьшало от 300 в ls.

Сравните размеры ансамблей.

sz(1) = whos('cmp'); sz(2) = whos('ls');
[sz(1).bytes sz(2).bytes]
ans = 1×2

       91209     3226892

Размер уменьшаемого ансамбля является частью размера оригинала. Обратите внимание на то, что ваши размеры ансамбля могут варьироваться в зависимости от вашей операционной системы.

Сравните MSE уменьшаемого ансамбля тому из исходного ансамбля.

figure;
plot(kfoldLoss(cv,'mode','cumulative'));
hold on
plot(cmp.NTrained,mse(5),'ro','MarkerSize',10);
xlabel('Number of trees');
ylabel('Cross-validated MSE');
legend('unregularized ensemble','regularized ensemble',...
    'Location','NE');
hold off

Figure contains an axes. The axes contains 2 objects of type line. These objects represent unregularized ensemble, regularized ensemble.

Уменьшаемый ансамбль дает низкую потерю при использовании многих меньше деревьев.

Смотрите также

| | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте