Оптимизируйте перекрестный подтвержденный классификатор SVM Используя bayesopt

В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt функция. Классификация работает над местоположениями точек от смешанной гауссовской модели. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:

  1. Выберите базисную точку m соответствующего цвета однородно наугад.

  2. Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.

После генерации 100 зеленых и 100 красных точек, классифицируйте их использующий fitcsvm. Затем используйте bayesopt оптимизировать параметры получившейся модели SVM относительно перекрестной проверки.

Сгенерируйте точки и классификатор

Сгенерируйте эти 10 базисных точек для каждого класса.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базисные точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, это может затруднить, чтобы классифицировать точки данных на основе одного только местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Подготовка данных для классификации

Поместите данные в одну матрицу и сделайте векторный grp это помечает класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовьте перекрестную проверку

Настройте раздел для перекрестной проверки. Этот шаг фиксирует обучение и наборы тестов, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовьте переменные к байесовой оптимизации

Настройте функцию, которая берет вход z = [rbf_sigma,boxconstraint] и возвращает значение перекрестной проверки потерь z. Возьмите компоненты z как положительные, преобразованные в журнал переменные между 1e-5 и 1e5. Выберите широкий спектр, потому что вы не знаете, какие значения, вероятно, будут хороши.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот указатель на функцию вычисляет потерю перекрестной проверки в параметрах [sigma,box]. Для получения дополнительной информации смотрите kfoldLoss.

bayesopt передает переменную z к целевой функции как таблица с одной строкой.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизируйте классификатор

Ищите лучшие параметры [sigma,box] использование bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция приобретения по умолчанию зависит от времени выполнения, и так может дать различные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.29802 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.14014 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.12459 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.20949 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.16716 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.12516 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.13982 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      1.6266 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      2.3481 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.19795 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.1703 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.23914 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.16647 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.14722 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.14553 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |      0.1345 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.13435 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.13044 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.15401 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.12961 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.12742 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.12987 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.15997 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.12983 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.14454 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.14993 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.12991 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.16825 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |      0.1884 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.21229 |       0.075 |       0.075 |       3.5051 |       93.492 |

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 49.0545 seconds
Total objective function evaluation time: 9.4691

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.1345

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.1449
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 49.0545
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

Используйте результаты обучить новое, оптимизировал классификатор SVM.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

Постройте контуры классификации. Чтобы визуализировать классификатор вектора поддержки, предскажите баллы по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

Оцените точность на новых данных

Сгенерируйте и классифицируйте некоторые новые точки данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 6 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors.

Смотрите, какие новые точки данных правильно классифицируются. Окружите правильно классифицированные точки красного цвета и неправильно классифицированные точки черного цвета цвета.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Figure contains an axes. The axes contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors, Correctly Classified, Misclassified.

Смотрите также

|

Похожие темы