Обучите обобщенную аддитивную модель регрессии

В этом примере показано, как обучить Обобщенную аддитивную модель (GAM) Регрессии оптимальными параметрами и как оценить прогнозирующую эффективность обученной модели. Пример сначала находит оптимальные значения параметров для одномерного GAM (параметры для линейных членов) и затем находит значения для двумерного GAM (параметры в течение многих периодов взаимодействия). Кроме того, пример объясняет, как интерпретировать обученную модель путем исследования локальных эффектов терминов на определенном предсказании и путем вычисления частичной зависимости предсказаний на предикторах.

Загрузка демонстрационных данных

Загрузите набор выборочных данных NYCHousing2015.

load NYCHousing2015

Набор данных включает 10 переменных с информацией о продажах свойств в Нью-Йорке в 2 015. Этот пример использует эти переменные, чтобы анализировать отпускные цены (SALEPRICE).

Предварительно обработайте набор данных. Примите что SALEPRICE меньше чем или равный 1 000$ указывает на передачу владения без суммы. Удалите выборки, которые имеют этот SALEPRICE. Кроме того, удалите выбросы, идентифицированные isoutlier функция. Затем преобразуйте datetime массив (SALEDATE) к числам месяца и перемещению переменная отклика (SALEPRICE) к последнему столбцу. Измените нули в LANDSQUAREFEET, GROSSSQUAREFEET, и YEARBUILT к NaNs.

idx1 = NYCHousing2015.SALEPRICE <= 1000;
idx2 = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx1|idx2,:) = [];
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');
NYCHousing2015.LANDSQUAREFEET(NYCHousing2015.LANDSQUAREFEET == 0) = NaN; 
NYCHousing2015.GROSSSQUAREFEET(NYCHousing2015.GROSSSQUAREFEET == 0) = NaN; 
NYCHousing2015.YEARBUILT(NYCHousing2015.YEARBUILT == 0) = NaN; 

Отобразите первые три строки таблицы.

head(NYCHousing2015,3)
ans=3×10 table
    BOROUGH    NEIGHBORHOOD       BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   0                1103              1290            1910          2           3e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   1                2500              2452            1910          7           4e+05 
       2       {'BATHGATE'}    {'01  ONE FAMILY DWELLINGS'}           1                   2                1911              4080            1931          1         5.1e+05 

Случайным образом выберите 1 000 выборок при помощи datasample функция и наблюдения раздела в набор обучающих данных и набор тестов при помощи cvpartition функция. Задайте 10%-ю выборку затяжки для тестирования.

rng('default') % For reproducibility
NumSamples = 1e3;
NYCHousing2015 = datasample(NYCHousing2015,NumSamples,'Replace',false);
cv = cvpartition(size(NYCHousing2015,1),'HoldOut',0.10);

Извлеките обучение и протестируйте индексы и составьте таблицы для наборов тестовых данных и обучения.

tbl_training = NYCHousing2015(training(cv),:);
tbl_test = NYCHousing2015(test(cv),:);

Обучите GAM оптимальными гиперпараметрами

Обучите GAM гиперпараметрами, которые минимизируют потерю перекрестной проверки при помощи аргумента значения имени OptimizeHyperparameters.

Можно задать OptimizeHyperparameters как 'auto' или 'all' найти оптимальные гиперзначения параметров и для одномерных и для двумерных параметров. В качестве альтернативы можно найти оптимальные значения для одномерных параметров с помощью 'auto-univariate' или 'all-univariate' опция, и затем находит оптимальные значения для двумерных параметров с помощью 'auto-bivariate' или 'all-bivariate' опция. Этот пример использует 'all-univariate' и 'all-bivariate'.

Обучите одномерный GAM. Задайте FitStandardDeviation как true подбирать модель для стандартного отклонения переменной отклика также. Методические рекомендации должны использовать оптимальные гиперпараметры, когда вы подбираете модель стандартного отклонения для точности оценок стандартного отклонения. Задайте OptimizeHyperparameters как 'all-univariate' так, чтобы fitrgam находит оптимальные значения InitialLearnRateForPredictors, MaxNumSplitsPerPredictor, и NumTreesPerPredictor аргументы name-value. Для воспроизводимости используйте 'expected-improvement-plus' функция захвата. Задайте ShowPlots как false и Verbose как 0, чтобы отключить график и индикаторы сообщения, соответственно.

Mdl_univariate = fitrgam(tbl_training,'SALEPRICE','FitStandardDeviation',true, ...
    'OptimizeHyperparameters','all-univariate', ...
    'HyperparameterOptimizationOptions',struct('AcquisitionFunctionName','expected-improvement-plus', ...
    'ShowPlots',false,'Verbose',0))
Mdl_univariate = 
  RegressionGAM
                       PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
                         ResponseName: 'SALEPRICE'
                CategoricalPredictors: [2 3]
                    ResponseTransform: 'none'
                            Intercept: 5.1868e+05
               IsStandardDeviationFit: 1
                      NumObservations: 900
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]


  Properties, Methods

fitrgam возвращает RegressionGAM объект модели, который использует лучшую предполагаемую допустимую точку. Лучшая предполагаемая допустимая точка указывает на набор гиперпараметров, который минимизирует верхнюю доверительную границу значения целевой функции на основе базовой модели целевой функции Байесового процесса оптимизации. Можно получить лучшую точку из HyperparameterOptimizationResults свойство или при помощи bestPoint функция.

x = Mdl_univariate.HyperparameterOptimizationResults.XAtMinEstimatedObjective
x=1×3 table
    InitialLearnRateForPredictors    MaxNumSplitsPerPredictor    NumTreesPerPredictor
    _____________________________    ________________________    ____________________

              0.063687                          1                         61         

bestPoint(Mdl_univariate.HyperparameterOptimizationResults)
ans=1×3 table
    InitialLearnRateForPredictors    MaxNumSplitsPerPredictor    NumTreesPerPredictor
    _____________________________    ________________________    ____________________

              0.063687                          1                         61         

Для получения дополнительной информации о процессе оптимизации смотрите, Оптимизируют GAM Используя OptimizeHyperparameters.

Обучите двумерный GAM. Задайте OptimizeHyperparameters как 'all-bivariate' так, чтобы fitrgam находит оптимальные значения Interactions, InitialLearnRateForInteractions, MaxNumSplitsPerInteraction, и NumTreesPerInteraction аргументы name-value. Используйте одномерные значения параметров в x так, чтобы программное обеспечение нашло оптимальные значения параметров в течение многих периодов взаимодействия на основе x значений.

Mdl = fitrgam(tbl_training,'SALEPRICE','FitStandardDeviation',true, ...
    'InitialLearnRateForPredictors',x.InitialLearnRateForPredictors, ...
    'MaxNumSplitsPerPredictor',x.MaxNumSplitsPerPredictor, ...
    'NumTreesPerPredictor',x.NumTreesPerPredictor, ...
    'OptimizeHyperparameters','all-bivariate', ...
    'HyperparameterOptimizationOptions',struct('AcquisitionFunctionName','expected-improvement-plus', ...
    'ShowPlots',false,'Verbose',0))
Warning: Model does not include interaction terms because all interaction terms have p-values greater than the 'MaxPValue' value, or the software was unable to improve the model fit.
Mdl = 
  RegressionGAM
                       PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
                         ResponseName: 'SALEPRICE'
                CategoricalPredictors: [2 3]
                    ResponseTransform: 'none'
                            Intercept: 5.1868e+05
               IsStandardDeviationFit: 1
                      NumObservations: 900
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]


  Properties, Methods

Отобразите оптимальные двумерные гиперпараметры.

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×4 table
    Interactions    InitialLearnRateForInteractions    MaxNumSplitsPerInteraction    NumTreesPerInteraction
    ____________    _______________________________    __________________________    ______________________

         3                      0.85621                            1                           14          

Отображение модели Mdl показывает частичный список свойств модели. Чтобы просмотреть полный список свойств модели, дважды кликните имя переменной Mdl в Рабочей области. Редактор Переменных открывается для Mdl. В качестве альтернативы можно отобразить свойства в Командном окне при помощи записи через точку. Например, отобразите ReasonForTermination свойство.

Mdl.ReasonForTermination
ans = struct with fields:
      PredictorTrees: 'Terminated after training the requested number of trees.'
    InteractionTrees: 'Unable to improve the model fit.'

Можно использовать ReasonForTermination свойство определить, содержит ли обученная модель конкретное количество деревьев для каждого линейного члена и каждый период взаимодействия.

Отобразите периоды взаимодействия в Mdl.

Mdl.Interactions
ans =

  0×2 empty double matrix

Каждая строка Interactions представляет один период взаимодействия и содержит индексы столбца переменных предикторов в течение периода взаимодействия. Можно использовать Interactions свойство проверять периоды взаимодействия в модель и порядок, в который fitrgam добавляет их в модель.

Отобразите периоды взаимодействия в Mdl использование имен предиктора.

Mdl.PredictorNames(Mdl.Interactions)
ans =

  0×2 empty cell array

Оцените прогнозирующую эффективность на новых наблюдениях

Оцените эффективность обученной модели при помощи тестовой выборки tbl_test и объект функционирует predict и loss. Можно использовать полную или компактную модель с этими функциями.

  • predict — Предскажите ответы

  • loss — Вычислите потерю регрессии (среднеквадратическая ошибка, по умолчанию)

Если вы хотите оценить эффективность обучающего набора данных, используйте функции объекта перезамены: resubPredict и resubLoss. Чтобы использовать эти функции, необходимо использовать полную модель, которая содержит обучающие данные.

Создайте компактную модель, чтобы уменьшать размер обученной модели.

CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size              Bytes  Class                                          Attributes

  CMdl      1x1              854780  classreg.learning.regr.CompactRegressionGAM              
  Mdl       1x1             1048879  RegressionGAM                                            

Сравните результаты, полученные включением и линейного к периодам взаимодействия и результатам, полученным включением только линейных членов.

Предскажите ответы и вычислите среднеквадратические ошибки для набора тестовых данных tbl_test.

[yFit,ySD,yInt] = predict(CMdl,tbl_test);
L = loss(CMdl,tbl_test)
L = 1.2531e+11

Найдите предсказанные ответы и ошибки без включения периодов взаимодействия в обученной модели.

[yFit_nointeraction,ySD_nointeraction,yInt__nointeraction] = predict(CMdl,tbl_test,'IncludeInteractions',false);
L_nointeractions = loss(CMdl,tbl_test,'IncludeInteractions',false)
L_nointeractions = 1.2531e+11

Модель достигает меньшей ошибки для набора тестовых данных, когда и линейные члены и периоды взаимодействия включены.

Постройте отсортированные истинные ответы вместе с предсказанными ответами и интервалами предсказания.

yTrue = tbl_test.SALEPRICE;
[sortedYTrue,I] = sort(yTrue);

figure
ax = nexttile;
plot(sortedYTrue,'o')
hold on
plot(yFit(I))
plot(yInt(I,1),'k:')
plot(yInt(I,2),'k:')
legend('True responses','Predicted responses', ...
    '95% Prediction interval limits','Location','best')
title('Linear and interaction terms')
hold off

nexttile
plot(sortedYTrue,'o')
hold on
plot(yFit_nointeraction(I))
plot(yInt__nointeraction(I,1),'k:')
plot(yInt__nointeraction(I,2),'k:')
ylim(ax.YLim)
title('Linear terms only')
hold off

Интервалы предсказания являются более узкими, когда и линейные члены и периоды взаимодействия включены.

Интерпретируйте предсказание

Интерпретируйте предсказание для первого тестового наблюдения при помощи plotLocalEffects функция. Кроме того, создайте частичные графики зависимости для некоторых важных членов в модели при помощи plotPartialDependence функция.

Предскажите значение отклика для первого наблюдения за тестовыми данными и постройте локальные эффекты терминов в CMdl на предсказании. Задайте 'IncludeIntercept',true включать термин точки пересечения в график.

yFit = predict(CMdl,tbl_test(1,:))
yFit = 5.1239e+05
figure
plotLocalEffects(CMdl,tbl_test(1,:),'IncludeIntercept',true)

predict функция возвращает отпускную цену за первое наблюдение tbl_test(1,:). plotLocalEffects функция создает горизонтальный столбчатый график, который показывает локальные эффекты терминов в CMdl на предсказании. Каждое локальное значение эффекта показывает вклад каждого термина к предсказанной отпускной цене за tbl_test(1,:).

Вычислите частичные значения зависимости для BUILDINGCLASSCATEGORY и постройте отсортированные значения. Задайте и обучение и наборы тестовых данных, чтобы вычислить частичные значения зависимости с помощью обоих наборов.

[pd,x,y] = partialDependence(CMdl,'BUILDINGCLASSCATEGORY',[tbl_training; tbl_test]);
[pd_sorted,I] = sort(pd);
x_sorted = x(I);
x_sorted = reordercats(x_sorted,I);
figure
plot(x_sorted,pd_sorted,'o:')
xlabel('BUILDINGCLASSCATEGORY')
ylabel('SALEPRICE')
title('Patial Dependence Plot')

Построенная линия представляет усредненные частичные отношения между предиктором BUILDINGCLASSCATEGORY и ответ SALEPRICE в обученной модели.

Создайте частичный график зависимости для терминов RESIDENTIALUNITS и LANDSQUAREFEET использование тестовых данных установлено.

figure
plotPartialDependence(CMdl,["RESIDENTIALUNITS","LANDSQUAREFEET"],tbl_test)

Незначительные метки деления в оси X (RESIDENTIALUNITS) и ось Y (LANDSQUAREFEET) представляйте уникальные значения предикторов в заданных данных. Значения предиктора включают несколько выбросов и большую часть RESIDENTIALUNITS и LANDSQUAREFEET значения меньше 5 и 5000, соответственно. График показывает что SALEPRICE значения значительно не варьируются когда RESIDENTIALUNITS значение больше 5.

Смотрите также

| | | | | |

Похожие темы