Регуляризация - это процесс выбора меньшего количества слабых учащихся для ансамбля таким образом, чтобы это не снижало прогнозирующую эффективность. В настоящее время вы можете регулировать регрессионые ансамбли. (Можно также упорядочить классификатор дискриминантного анализа в контексте без ансамбля; см. «Упорядочение классификатора дискриминантного анализа».)
The regularize
метод находит оптимальный набор весов учащихся αt которые минимизируют
Здесь
λ ≥ 0 является параметром, который вы предоставляете, называемым параметром lasso.
ht является слабым учеником в ансамбле, обученным наблюдениям за N с xn предикторов, yn откликов и wn весов.
g (f, y) = (f – y)2 является квадратичная невязка.
Ансамбль упорядочен на тех же (xn, yn, wn) данных, используемых для обучения, так
- ошибка повторной замещения ансамбля. Ошибка измеряется средней квадратичной невязкой (MSE).
Если вы используете λ = 0, regularize
находит слабые веса учащихся путем минимизации resubstitution MSE. Ансамбли, как правило, переучиваются. Другими словами, ошибка восстановления обычно меньше, чем истинная ошибка обобщения. Делая ошибку реституции еще меньше, вы, вероятно, сделаете точность ансамбля хуже, чем улучшить ее. С другой стороны, положительные значения λ приводят величину коэффициентов αt к 0. Это часто улучшает ошибку обобщения. Конечно, если вы выбираете λ слишком большие, все оптимальные коэффициенты равны 0, и ансамбль не имеет никакой точности. Обычно можно найти оптимальную область значений для λ, в котором точность регуляризованного ансамбля лучше или сопоставима с точностью полного ансамбля без регуляризации.
Приятной функцией регуляризации лассо является его способность приводить оптимизированные коэффициенты точно в 0. Если αt веса учащегося равен 0, этот ученик может быть исключен из регуляризованного ансамбля. В конце концов, вы получаете ансамбль с улучшенной точностью и меньшим количеством учащихся.
Этот пример использует данные для прогнозирования страхового риска автомобиля на основе его многочисленных атрибутов.
Загрузите imports-85
данные в рабочее пространство MATLAB.
load imports-85;
Проверьте описание данных, чтобы найти категориальные переменные и имена предикторов.
Description
Description = 9x79 char array
'1985 Auto Imports Database from the UCI repository '
'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
'Variables have been reordered to place variables with numeric values (referred '
'to as "continuous" on the UCI site) to the left and categorical values to the '
'right. Specifically, variables 1:16 are: symboling, normalized-losses, '
'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke, '
'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price. '
'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style, '
'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '
Цель этого процесса состоит в том, чтобы предсказать «символизацию», первую переменную в данных, от других предикторов. «symboling» является целым числом от -3
(хороший страховой риск), чтобы 3
(низкий страховой риск). Можно использовать классификационный ансамбль, чтобы предсказать этот риск вместо регрессионного ансамбля. Когда у вас есть выбор между регрессией и классификацией, вы должны сначала попробовать регрессию.
Подготовим данные для ансамбля подборы кривой.
Y = X(:,1); X(:,1) = []; VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ... 'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ... 'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ... 'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ... 'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'}; catidx = 16:25; % indices of categorical predictors
Создайте регрессионный ансамбль из данных с помощью 300 деревьев.
ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ... 'LearnRate',0.1,'PredictorNames',VarNames, ... 'ResponseName','Symboling','CategoricalPredictors',catidx)
ls = RegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumObservations: 205 NumTrained: 300 Method: 'LSBoost' LearnerNames: {'Tree'} ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.' FitInfo: [300x1 double] FitInfoDescription: {2x1 cell} Regularization: [] Properties, Methods
Конечная линия, Regularization
, пуст ([]). Чтобы упорядочить ансамбль, необходимо использовать regularize
способ.
cv = crossval(ls,'KFold',5); figure; plot(kfoldLoss(cv,'Mode','Cumulative')); xlabel('Number of trees'); ylabel('Cross-validated MSE'); ylim([0.2,2])
Похоже, вы можете получить удовлетворительную эффективность от меньшего ансамбля, возможно, который содержит от 50 до 100 деревьев.
Вызовите regularize
метод, чтобы попытаться найти деревья, которые можно удалить из ансамбля. По умолчанию regularize
исследует 10 значений лассо (Lambda
) параметр разнесен в геометрической прогрессии.
ls = regularize(ls)
ls = RegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumObservations: 205 NumTrained: 300 Method: 'LSBoost' LearnerNames: {'Tree'} ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.' FitInfo: [300x1 double] FitInfoDescription: {2x1 cell} Regularization: [1x1 struct] Properties, Methods
The Regularization
свойство больше не пустое.
Постройте график среднеквадратичной ошибки (MSE) и количества учащихся с ненулевыми весами относительно параметра lasso. Отдельно постройте график значения в Lambda = 0
. Используйте логарифмическую шкалу, потому что значения Lambda
разнесены в геометрической прогрессии.
figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ... 'bx-','Markersize',10); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ... ls.Regularization.ResubstitutionMSE(1)],... 'Marker','x','Markersize',10,'Color','b'); r0 = resubLoss(ls); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],... [r0 r0],'Color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Resubstitution MSE'); annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ... 'Color','r','FontSize',14,'LineStyle','none');
figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); line([1e-3 1e-3],... [sum(ls.Regularization.TrainedWeights(:,1)>0) ... sum(ls.Regularization.TrainedWeights(:,1)>0)],... 'marker','x','markersize',10,'color','b'); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],... [ls.NTrained ls.NTrained],... 'color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Number of learners'); annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',... 'color','r','FontSize',14,'LineStyle','none');
Значения MSE реституции, вероятно, будут слишком оптимистичными. Чтобы получить более достоверные оценки ошибки, связанной с различными значениями Lambda
, перекрестно проверьте ансамбль используя cvshrink
. Постройте график полученной потери перекрестной валидации (MSE) и количества учащихся против Lambda
.
rng(0,'Twister') % for reproducibility [mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners.
figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ... 'bx-','Markersize',10); hold on; semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10); hold off; xlabel('Lambda'); ylabel('Mean squared error'); legend('resubstitution','cross-validation','Location','NW'); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ... ls.Regularization.ResubstitutionMSE(1)],... 'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',... 'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');
figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); hold;
Current plot held
loglog(ls.Regularization.Lambda,nlearn,'r--'); hold off; xlabel('Lambda'); ylabel('Number of learners'); legend('resubstitution','cross-validation','Location','NE'); line([1e-3 1e-3],... [sum(ls.Regularization.TrainedWeights(:,1)>0) ... sum(ls.Regularization.TrainedWeights(:,1)>0)],... 'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',... 'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');
Изучение перекрестной проверенной ошибки показывает, что MSE перекрестной валидации почти плоский для Lambda
до бит 1e-2
.
Исследуйте ls.Regularization.Lambda
чтобы найти наивысшее значение, которое дает MSE в плоской области (до немного выше 1e-2
).
jj = 1:length(ls.Regularization.Lambda); [jj;ls.Regularization.Lambda]
ans = 2×10
1.0000 2.0000 3.0000 4.0000 5.0000 6.0000 7.0000 8.0000 9.0000 10.0000
0 0.0019 0.0045 0.0107 0.0254 0.0602 0.1428 0.3387 0.8033 1.9048
Элементный 5
от ls.Regularization.Lambda
имеет значение 0.0254
, самый большой в плоской области значений.
Уменьшите размер ансамбля, используя shrink
способ. shrink
возвращает компактный ансамбль без обучающих данных. Ошибка обобщения для нового компактного ансамбля уже была оценена путем перекрестной валидации в mse(5)
.
cmp = shrink(ls,'weightcolumn',5)
cmp = CompactRegressionEnsemble PredictorNames: {1x25 cell} ResponseName: 'Symboling' CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25] ResponseTransform: 'none' NumTrained: 8 Properties, Methods
Число деревьев в новом ансамбле значительно сократилось с 300 в ls
.
Сравните размеры ансамблей.
sz(1) = whos('cmp'); sz(2) = whos('ls'); [sz(1).bytes sz(2).bytes]
ans = 1×2
91209 3226892
Размер редуцированного ансамбля составляет часть размера оригинала. Обратите внимание, что размеры ансамбля могут варьироваться в зависимости от операционной системы.
Сравните MSE редуцированного ансамбля с MSE оригинального ансамбля.
figure; plot(kfoldLoss(cv,'mode','cumulative')); hold on plot(cmp.NTrained,mse(5),'ro','MarkerSize',10); xlabel('Number of trees'); ylabel('Cross-validated MSE'); legend('unregularized ensemble','regularized ensemble',... 'Location','NE'); hold off
Уменьшенный ансамбль дает низкие потери при использовании гораздо меньшего количества деревьев.
crossval
| cvshrink
| fitrensemble
| kfoldLoss
| regularize
| resubLoss
| shrink