Оптимизация классификатора SVM с перекрестной проверкой с помощью bayesopt

В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt функция. Классификация работает с местоположениями точек из смешанной гауссовской модели. В The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009), страница 17 описывает модель. Модель начинается с генерации 10 базовых точек для «зеленого» класса, распределенных как 2-D независимых нормалей со средним значением (1,0) и единичным отклонением. Это также генерирует 10 базовых точек для «красного» класса, распределенных как 2-D независимых нормалей со средней (0,1) и единичным отклонением. Для каждого класса (зеленого и красного) сгенерируйте 100 случайных точек следующим образом:

  1. Выберите базовую точку m соответствующего цвета случайным образом.

  2. Сгенерируйте независимую случайную точку с 2-D нормальным распределением со средним m и I/5 отклонения, где I является единичной матрицей 2 на 2. В этом примере используйте I/50 отклонений, чтобы более четко показать преимущество оптимизации.

После генерации 100 зеленых и 100 красных точек классифицируйте их с помощью fitcsvm. Затем используйте bayesopt оптимизировать параметры полученной модели SVM относительно перекрестной валидации.

Сгенерируйте точки и классификатор

Сгенерируйте 10 базовых точек для каждого класса.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базовые точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Поскольку некоторые красные базовые точки близки к зеленым базовым точкам, может быть трудно классифицировать точки данных на основе одного местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Подготовка данных к классификации

Поместите данные в одну матрицу и сделайте вектор grp который помечает класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовка перекрестной проверки

Настройте раздел для перекрестной проверки. Этот шаг исправляет train и тестовые наборы, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовьте переменные для байесовской оптимизации

Настройте функцию, которая принимает вход z = [rbf_sigma,boxconstraint] и возвращает значение потерь перекрестной валидации z. Примите компоненты z как положительные, логарифмически преобразованные переменные между 1e-5 и 1e5. Выберите широкую область значений, потому что вы не знаете, какие значения могут быть хорошими.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот указатель на функцию вычисляет потери перекрестной валидации в параметрах [sigma,box]. Для получения дополнительной информации смотрите kfoldLoss.

bayesopt передает переменную z в целевую функцию как таблицу с одной строкой.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизируйте классификатор

Поиск оптимальных параметров [sigma,box] использование bayesopt. Для повторяемости выберите 'expected-improvement-plus' функция сбора. Функция сбора по умолчанию зависит от времени выполнения, и поэтому может давать различные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.29802 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.14014 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.12459 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.20949 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.16716 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.12516 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.13982 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      1.6266 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      2.3481 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.19795 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.1703 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.23914 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.16647 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.14722 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.14553 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |      0.1345 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.13435 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.13044 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.15401 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.12961 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.12742 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.12987 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.15997 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.12983 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.14454 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.14993 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.12991 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.16825 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |      0.1884 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.21229 |       0.075 |       0.075 |       3.5051 |       93.492 |

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 49.0545 seconds
Total objective function evaluation time: 9.4691

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.1345

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.1449
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 49.0545
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

Используйте результаты для обучения нового оптимизированного классификатора SVM.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

Постройте график контуров классификации. Чтобы визуализировать классификатор векторов поддержки, спрогнозируйте счета по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

Оцените точность новых данных

Сгенерируйте и классифицируйте некоторые новые точки данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 6 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors.

Проверьте, какие новые точки данных правильно классифицированы. Округлить правильно классифицированные точки красным цветом, и неправильно классифицированные точки черным цветом.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Figure contains an axes. The axes contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors, Correctly Classified, Misclassified.

См. также

|

Похожие темы