Оптимизируйте классификатор SVM Подгонки используя байесовскую оптимизацию

В этом примере показано, как оптимизировать классификацию SVM с помощью fitcsvm функции и OptimizeHyperparameters Пара "имя-значение". Классификация работает с местоположениями точек из смешанной гауссовской модели. В The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009), страница 17 описывает модель. Модель начинается с генерации 10 базовых точек для «зеленого» класса, распределенных как 2-D независимых нормалей со средним значением (1,0) и единичным отклонением. Это также генерирует 10 базовых точек для «красного» класса, распределенных как 2-D независимых нормалей со средней (0,1) и единичным отклонением. Для каждого класса (зеленого и красного) сгенерируйте 100 случайных точек следующим образом:

  1. Выберите базовую точку m соответствующего цвета случайным образом.

  2. Сгенерируйте независимую случайную точку с 2-D нормальным распределением со средним m и I/5 отклонения, где I является единичной матрицей 2 на 2. В этом примере используйте I/50 отклонений, чтобы более четко показать преимущество оптимизации.

Сгенерируйте точки и классификатор

Сгенерируйте 10 базовых точек для каждого класса.

rng default % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базовые точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Поскольку некоторые красные базовые точки близки к зеленым базовым точкам, может быть трудно классифицировать точки данных на основе одного местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Подготовка данных к классификации

Поместите данные в одну матрицу и сделайте вектор grp который помечает класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовка перекрестной проверки

Настройте раздел для перекрестной проверки. Этот шаг исправляет train и тестовые наборы, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Оптимизация подгонки

Чтобы найти хорошую подгонку, то есть с низкой потерей перекрестной валидации, установите опции, чтобы использовать байесовскую оптимизацию. Используйте тот же раздел перекрестной проверки c во всех оптимизациях.

Для воспроизводимости используйте 'expected-improvement-plus' функция сбора.

opts = struct('Optimizer','bayesopt','ShowPlots',true,'CVPartition',c,...
    'AcquisitionFunctionName','expected-improvement-plus');
svmmod = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.20756 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.14872 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |     0.12556 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |     0.16264 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.14384 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.15049 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.12919 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.15841 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |     0.18577 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |      0.1428 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |     0.14182 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.12232 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.13074 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.15858 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |     0.15241 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |     0.15047 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.15512 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.14554 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.22102 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.14178 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.12463 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |     0.17579 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |      0.1675 |       0.065 |    0.071959 |    0.0010674 |       999.18 |
|   24 | Accept |        0.35 |     0.13478 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |     0.23822 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |     0.14478 |       0.065 |    0.072068 |       958.64 |    0.0010026 |
|   27 | Accept |        0.47 |     0.13262 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.15652 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |     0.31079 |       0.065 |    0.072104 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |      0.1375 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 48.9578 seconds
Total objective function evaluation time: 4.7979

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.15047

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.16248
svmmod = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

Найдите потерю оптимизированной модели.

lossnew = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf',...
    'BoxConstraint',svmmod.HyperparameterOptimizationResults.XAtMinObjective.BoxConstraint,...
    'KernelScale',svmmod.HyperparameterOptimizationResults.XAtMinObjective.KernelScale))
lossnew = 0.0650

Эта потеря аналогична потере, сообщенной в выходе оптимизации под «Наблюдаемое значение целевой функции».

Визуализируйте оптимизированный классификатор.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(svmmod,xGrid);
figure;
h = nan(3,1); % Preallocation
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(svmmod.IsSupportVector,1),...
    cdata(svmmod.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

См. также

|

Похожие темы