Оптимизируйте подгонку классификатора Используя байесовую оптимизацию

В этом примере показано, как оптимизировать классификацию SVM с помощью fitcsvm функционируйте и OptimizeHyperparameters аргумент значения имени.

Сгенерируйте данные

Классификация работает над местоположениями точек от смешанной гауссовской модели. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:

  1. Выберите базисную точку m соответствующего цвета однородно наугад.

  2. Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.

Сгенерируйте эти 10 базисных точек для каждого класса.

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базисные точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, это может затруднить, чтобы классифицировать точки данных на основе одного только местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);
grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

Подготовка данных для классификации

Поместите данные в одну матрицу и сделайте векторный grp это помечает класс каждой точки. 1 указывает, что зеленый класс, и –1 указывает на красный класс.

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

Подготовьте перекрестную проверку

Настройте раздел для перекрестной проверки.

c = cvpartition(200,'KFold',10);

Этот шаг является дополнительным. Если вы задаете раздел для оптимизации, то можно вычислить фактическую потерю перекрестной проверки для возвращенной модели.

Оптимизируйте подгонку

Найти хорошую подгонку, означая один оптимальными гиперпараметрами, которые минимизируют потерю перекрестной проверки, Байесовую оптимизацию использования. Задайте список гиперпараметров, чтобы оптимизировать при помощи OptimizeHyperparameters аргумент значения имени, и задает опции оптимизации при помощи HyperparameterOptimizationOptions аргумент значения имени.

Задайте 'OptimizeHyperparameters' как 'auto'. 'auto' опция включает типичный набор гиперпараметров, чтобы оптимизировать. fitcsvm находит оптимальные значения BoxConstraint и KernelScale. Установите опции гипероптимизации параметров управления использовать раздел перекрестной проверки c и выбрать 'expected-improvement-plus' функция захвата для воспроизводимости. Функция захвата по умолчанию зависит от времени выполнения и, поэтому, может дать различные результаты.

opts = struct('CVPartition',c,'AcquisitionFunctionName','expected-improvement-plus');
Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.22192 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.16686 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |     0.14285 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |     0.14104 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.14998 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.14041 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.13736 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.13978 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |     0.16771 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |     0.15169 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |     0.13653 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.12055 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.12305 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.12852 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |      0.1381 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |      0.1715 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.18371 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.14336 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.15007 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.13721 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.12409 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |     0.12674 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |     0.12977 |       0.065 |    0.071959 |    0.0011459 |       995.89 |
|   24 | Accept |        0.35 |     0.12457 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |     0.18461 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |     0.14087 |       0.065 |    0.072067 |       994.71 |    0.0010063 |
|   27 | Accept |        0.47 |     0.13224 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.12408 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |     0.27101 |       0.065 |    0.072103 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |     0.12737 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes object. The axes object with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes object. The axes object with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 24.9814 seconds
Total objective function evaluation time: 4.4776

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.1715

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.15739
Mdl = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

fitcsvm возвращает ClassificationSVM объект модели, который использует лучшую предполагаемую допустимую точку. Лучшая предполагаемая допустимая точка является набором гиперпараметров, который минимизирует верхнюю доверительную границу потери перекрестной проверки на основе базовой Гауссовой модели процесса Байесового процесса оптимизации.

Байесов процесс оптимизации внутренне обеспечивает Гауссову модель процесса целевой функции. Целевая функция является перекрестным подтвержденным misclassification уровнем для классификации. Для каждой итерации процесс оптимизации обновляет Гауссову модель процесса и использует модель, чтобы найти новый набор гиперпараметров. Каждая линия итеративного отображения показывает новый набор гиперпараметров и этих значений столбцов:

  • Objective — Значение целевой функции вычисляется в новом наборе гиперпараметров.

  • Objective runtime — Время оценки целевой функции.

  • Eval result — Отчет результата в виде Accept, Best, или Error. Accept указывает, что целевая функция возвращает конечное значение и Error указывает, что целевая функция возвращает значение, которое не является конечным действительным скаляром. Best указывает, что целевая функция возвращает конечное значение, которое ниже, чем ранее вычисленные значения целевой функции.

  • BestSoFar(observed) — Минимальное значение целевой функции вычисляется до сих пор. Это значение является любой значением целевой функции текущей итерации (если Eval result значением для текущей итерации является Best) или значение предыдущего Best итерация.

  • BestSoFar(estim.) — В каждой итерации программное обеспечение оценивает верхние доверительные границы значений целевой функции, с помощью обновленной Гауссовой модели процесса, во всех наборах гиперпараметров, которые попробовали до сих пор. Затем программное обеспечение выбирает точку с минимальной верхней доверительной границей. BestSoFar(estim.) значение является значением целевой функции, возвращенным predictObjective функция в минимальной точке.

График ниже итеративного отображения показывает BestSoFar(observed) и BestSoFar(estim.) значения синего и зеленого цвета, соответственно.

Возвращенный объект Mdl использует лучшую предполагаемую допустимую точку, то есть, набор гиперпараметров, который производит BestSoFar(estim.) значение в итоговой итерации на основе итоговой Гауссовой модели процесса.

Можно получить лучшую точку из HyperparameterOptimizationResults свойство или при помощи bestPoint функция.

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

CriterionValue = 0.0888
iteration = 19

По умолчанию, bestPoint функционируйте использует 'min-visited-upper-confidence-interval' критерий. Этот критерий выбирает гиперпараметры, полученные из 19-й итерации как лучшая точка. CriterionValue верхняя граница перекрестной подтвержденной потери, вычисленной итоговой Гауссовой моделью процесса. Вычислите фактическую перекрестную подтвержденную потерю при помощи раздела c.

L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf', ...
    'BoxConstraint',x.BoxConstraint,'KernelScale',x.KernelScale))
L_MinEstimated = 0.0700

Фактическая перекрестная подтвержденная потеря близко к ориентировочной стоимости. Estimated objective function value отображен ниже графиков результатов оптимизации.

Можно также извлечь лучшую наблюдаемую допустимую точку (то есть, последний Best укажите в итеративном отображении) от HyperparameterOptimizationResults свойство или путем определения Criterion как 'min-observed'.

Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

[x_observed,CriterionValue_observed,iteration_observed] = bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

CriterionValue_observed = 0.0650
iteration_observed = 16

'min-observed' критерий выбирает гиперпараметры, полученные из 16-й итерации как лучшая точка. CriterionValue_observed вычисленное использование фактической перекрестной подтвержденной потери выбранных гиперпараметров. Для получения дополнительной информации смотрите аргумент значения имени Критерия bestPoint.

Визуализируйте оптимизированный классификатор.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(Mdl,xGrid);

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(Mdl.IsSupportVector,1), ...
    cdata(Mdl.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure contains an axes object. The axes object contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

Оцените точность на новых данных

Сгенерируйте и классифицируйте новые точки тестовых данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1); % green = 1
grpData(11:20) = -1; % red = -1

v = predict(Mdl,newData);

Вычислите misclassification уровни на наборе тестовых данных.

L_Test = loss(Mdl,newData,grpData)
L_Test = 0.3500

Определите, какие новые точки данных классифицируются правильно. Отформатируйте правильно классифицированные точки в красных квадратах и неправильно классифицированные точки в черных квадратах.

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

mydiff = (v == grpData); % Classified correctly

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure contains an axes object. The axes object contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), Support Vectors, -1 (classified), +1 (classified), Correctly Classified, Misclassified.

Смотрите также

|

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте