Регуляризация является процессом выбора меньшего количества слабых учеников для ансамбля способом, который не уменьшает прогнозирующую эффективность. В настоящее время можно упорядочить ансамбли регрессии. (Можно также упорядочить классификатор дискриминантного анализа в контексте неансамбля; смотрите Упорядочивают Классификатор Дискриминантного анализа.)
regularize
метод находит оптимальный набор весов ученика αt, которые минимизируют
Здесь
λ ≥ 0 является параметром, который вы обеспечиваете, названный параметром лассо.
ht является слабым учеником в ансамбле, обученном на наблюдениях N с предикторами xn, ответы yn и веса wn.
g (f, y) = (f – y) 2 является квадратичной невязкой.
Ансамбль упорядочен на том же самом (xn, yn, wn) данные, используемые для обучения, таким образом,
ошибка перезамены ансамбля. Ошибка измеряется среднеквадратической ошибкой (MSE).
Если вы используете λ = 0, regularize
находит слабые веса ученика путем минимизации перезамены MSE. Ансамбли склонны перетренироваться. Другими словами, ошибка перезамены обычно меньше, чем истинная ошибка обобщения. Путем совершения еще меньшей ошибки перезамены вы, вероятно, сделаете точность ансамбля хуже вместо того, чтобы улучшить его. С другой стороны, положительные значения λ продвигают величину коэффициентов αt к 0. Это часто улучшает ошибку обобщения. Конечно, если вы выбираете слишком большой λ, все оптимальные коэффициенты 0, и у ансамбля нет точности. Обычно можно найти оптимальную область значений для λ, в котором точность упорядоченного ансамбля лучше или сопоставима с тем из полного ансамбля без регуляризации.
Хорошей функцией регуляризации лассо является своя способность управлять оптимизированными коэффициентами точно к 0. Если вес ученика, который αt 0, этот ученик, может быть исключен из упорядоченного ансамбля. В конце вы получаете ансамбль с улучшенной точностью и меньшим количеством учеников.
Этот пример использует данные для предсказания страхового риска автомобиля на основе его многих атрибутов.
Загрузите imports-85
данные в рабочее пространство MATLAB.
load imports-85;
Посмотрите на описание данных, чтобы найти категориальные переменные и имена предиктора.
Description
Description = 9x79 char array
'1985 Auto Imports Database from the UCI repository '
'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
'Variables have been reordered to place variables with numeric values (referred '
'to as "continuous" on the UCI site) to the left and categorical values to the '
'right. Specifically, variables 1:16 are: symboling, normalized-losses, '
'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke, '
'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price. '
'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style, '
'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '
Цель этого процесса состоит в том, чтобы предсказать "symboling", первую переменную в данных, от других предикторов. "symboling" является целым числом от -3
(хороший страховой риск) к 3
(плохой страховой риск). Вы могли использовать ансамбль классификации, чтобы предсказать этот риск вместо ансамбля регрессии. Когда у вас есть выбор между регрессией и классификацией, необходимо попробовать регрессию сначала.
Подготовьте данные к подбору кривой ансамбля.
Y = X(:,1); X(:,1) = []; VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ... 'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ... 'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ... 'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ... 'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'}; catidx = 16:25; % indices of categorical predictors
Создайте ансамбль регрессии из данных с помощью 300 деревьев.
ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ... 'LearnRate',0.1,'PredictorNames',VarNames, ... 'ResponseName','Symboling','CategoricalPredictors',catidx)
ls = RegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumObservations: 205 NumTrained: 300 Method: 'LSBoost' LearnerNames: {'Tree'} ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.' FitInfo: [300x1 double] FitInfoDescription: {2x1 cell} Regularization: [] Properties, Methods
Итоговая линия, Regularization
isempty. Чтобы упорядочить ансамбль, необходимо использовать regularize
метод.
cv = crossval(ls,'KFold',5); figure; plot(kfoldLoss(cv,'Mode','Cumulative')); xlabel('Number of trees'); ylabel('Cross-validated MSE'); ylim([0.2,2])
Кажется, что вы можете получить удовлетворительную эффективность из меньшего ансамбля, возможно, один содержащий от 50 до 100 деревьев.
Вызовите regularize
метод, чтобы попытаться найти деревья, которые можно удалить из ансамбля. По умолчанию, regularize
исследует 10 значений лассо (Lambda
) параметр расположен с интервалами экспоненциально.
ls = regularize(ls)
ls = RegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumObservations: 205 NumTrained: 300 Method: 'LSBoost' LearnerNames: {'Tree'} ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.' FitInfo: [300x1 double] FitInfoDescription: {2x1 cell} Regularization: [1x1 struct] Properties, Methods
Regularization
свойство более не не пусто.
Постройте среднеквадратическую ошибку (MSE) перезамены и количество учеников с ненулевыми весами против параметра лассо. Отдельно постройте значение в Lambda = 0
. Используйте логарифмический масштаб потому что значения Lambda
экспоненциально расположены с интервалами.
figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ... 'bx-','Markersize',10); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ... ls.Regularization.ResubstitutionMSE(1)],... 'Marker','x','Markersize',10,'Color','b'); r0 = resubLoss(ls); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],... [r0 r0],'Color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Resubstitution MSE'); annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ... 'Color','r','FontSize',14,'LineStyle','none');
figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); line([1e-3 1e-3],... [sum(ls.Regularization.TrainedWeights(:,1)>0) ... sum(ls.Regularization.TrainedWeights(:,1)>0)],... 'marker','x','markersize',10,'color','b'); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],... [ls.NTrained ls.NTrained],... 'color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Number of learners'); annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',... 'color','r','FontSize',14,'LineStyle','none');
Перезамена значения MSE, вероятно, будет чрезмерно оптимистична. Получить более надежные оценки ошибки, сопоставленной с различными значениями Lambda
, крест подтверждает ансамбль, использующий cvshrink
. Постройте получившуюся потерю перекрестной проверки (MSE) и количество учеников против Lambda
.
rng(0,'Twister') % for reproducibility [mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners.
figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ... 'bx-','Markersize',10); hold on; semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10); hold off; xlabel('Lambda'); ylabel('Mean squared error'); legend('resubstitution','cross-validation','Location','NW'); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ... ls.Regularization.ResubstitutionMSE(1)],... 'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',... 'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');
figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); hold;
Current plot held
loglog(ls.Regularization.Lambda,nlearn,'r--'); hold off; xlabel('Lambda'); ylabel('Number of learners'); legend('resubstitution','cross-validation','Location','NE'); line([1e-3 1e-3],... [sum(ls.Regularization.TrainedWeights(:,1)>0) ... sum(ls.Regularization.TrainedWeights(:,1)>0)],... 'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',... 'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');
Исследование перекрестной подтвержденной ошибки показывает, что перекрестная проверка MSE является почти плоской для Lambda
до немногим выше 1e-2
.
Исследуйте ls.Regularization.Lambda
найти самое высокое значение, которое дает MSE в плоской области (до немногим выше 1e-2
).
jj = 1:length(ls.Regularization.Lambda); [jj;ls.Regularization.Lambda]
ans = 2×10
1.0000 2.0000 3.0000 4.0000 5.0000 6.0000 7.0000 8.0000 9.0000 10.0000
0 0.0019 0.0045 0.0107 0.0254 0.0602 0.1428 0.3387 0.8033 1.9048
Элемент 5
из ls.Regularization.Lambda
имеет значение 0.0254
, самое большое в плоской области значений.
Уменьшайте размер ансамбля с помощью shrink
метод. shrink
возвращает компактный ансамбль без обучающих данных. Ошибка обобщения для нового компактного ансамбля была уже оценена перекрестной проверкой в mse(5)
.
cmp = shrink(ls,'weightcolumn',5)
cmp = CompactRegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumTrained: 8 Properties, Methods
Количество деревьев в новом ансамбле особенно уменьшало от 300 в ls
.
Сравните размеры ансамблей.
sz(1) = whos('cmp'); sz(2) = whos('ls'); [sz(1).bytes sz(2).bytes]
ans = 1×2
91001 3225515
Размер уменьшаемого ансамбля является частью размера оригинала. Обратите внимание на то, что ваши размеры ансамбля могут варьироваться в зависимости от вашей операционной системы.
Сравните MSE уменьшаемого ансамбля тому из исходного ансамбля.
figure; plot(kfoldLoss(cv,'mode','cumulative')); hold on plot(cmp.NTrained,mse(5),'ro','MarkerSize',10); xlabel('Number of trees'); ylabel('Cross-validated MSE'); legend('unregularized ensemble','regularized ensemble',... 'Location','NE'); hold off
Уменьшаемый ансамбль дает низкую потерю при использовании многих меньше деревьев.
crossval
| cvshrink
| fitrensemble
| kfoldLoss
| regularize
| resubLoss
| shrink