Оптимизируйте перекрестный подтвержденный классификатор SVM Используя bayesopt

Этот пример показывает, как оптимизировать классификацию SVM с помощью функции bayesopt. Классификация работает над местоположениями точек из Гауссовой модели смеси. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:

  1. Выберите базисную точку m соответствующего цвета однородно наугад.

  2. Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.

После генерации 100 зеленых и 100 красных точек, классифицируйте их использующий fitcsvm. Затем используйте bayesopt, чтобы оптимизировать параметры получившейся модели SVM относительно перекрестной проверки.

Сгенерируйте точки и классификатор

Сгенерируйте эти 10 базисных точек для каждого класса.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базисные точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, может быть трудно классифицировать точки данных на основе одного только местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Подготовка данных для классификации

Поместите данные в одну матрицу и сделайте векторный grp, который маркирует класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовьте перекрестную проверку

Настройте раздел для перекрестной проверки. Этот шаг фиксирует составы и наборы тестов, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовьте переменные к байесовой оптимизации

Настройте функцию, которая берет вход z = [rbf_sigma,boxconstraint] и возвращает значение перекрестной проверки потерь z. Возьмите компоненты z как положительные, преобразованные в журнал переменные между 1e-5 и 1e5. Выберите широкий спектр, потому что вы не знаете, какие значения, вероятно, будут хороши.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот указатель на функцию вычисляет потерю перекрестной проверки в параметрах [sigma,box]. Для получения дополнительной информации смотрите kfoldLoss.

bayesopt передает переменную z целевой функции как таблица с одной строкой.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизируйте классификатор

Ищите лучшие параметры [sigma,box] с помощью bayesopt. Для воспроизводимости выберите функцию приобретения 'expected-improvement-plus'. Функция приобретения по умолчанию зависит от времени выполнения, и так может дать переменные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')

|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.22192 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.23607 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.21265 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.17898 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.17441 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.17172 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.14105 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      2.3452 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      2.7846 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.15571 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.7504 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.28088 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.16181 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.17606 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.14088 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |     0.14291 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.13851 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.14844 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |      0.1417 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.13378 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.12395 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.14245 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.12691 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.13728 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.15283 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.13978 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.12669 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |      0.1319 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |      0.1478 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.16981 |       0.075 |       0.075 |       3.5051 |       93.492 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 40.3981 seconds.
Total objective function evaluation time: 11.2371

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.14291

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.14286
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 40.3981
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

Используйте результаты обучить новое, оптимизировал классификатор SVM.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

Постройте контуры классификации. Чтобы визуализировать классификатор вектора поддержки, предскажите очки по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Оцените точность на новых данных

Сгенерируйте и классифицируйте некоторые новые точки данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Смотрите, какие новые точки данных правильно классифицируются. Окружите правильно классифицированные точки красного цвета и неправильно классифицированные точки черного цвета цвета.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Смотрите также

|

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте