Оптимизируйте перекрестный подтвержденный классификатор SVM Используя bayesopt

В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt функция. Классификация работает над местоположениями точек от смешанной гауссовской модели. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:

  1. Выберите базисную точку m соответствующего цвета однородно наугад.

  2. Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.

После генерации 100 зеленых и 100 красных точек, классифицируйте их использующий fitcsvm. Затем используйте bayesopt оптимизировать параметры получившейся модели SVM относительно перекрестной проверки.

Сгенерируйте точки и классификатор

Сгенерируйте эти 10 базисных точек для каждого класса.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базисные точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, это может затруднить, чтобы классифицировать точки данных на основе одного только местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Подготовка данных для классификации

Поместите данные в одну матрицу и сделайте векторный grp это помечает класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовьте перекрестную проверку

Настройте раздел для перекрестной проверки. Этот шаг фиксирует обучение и наборы тестов, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовьте переменные к байесовой оптимизации

Настройте функцию, которая берет вход z = [rbf_sigma,boxconstraint] и возвращает значение перекрестной проверки потерь z. Возьмите компоненты z как положительные, преобразованные в журнал переменные между 1e-5 и 1e5. Выберите широкий спектр, потому что вы не знаете, какие значения, вероятно, будут хороши.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот указатель на функцию вычисляет потерю перекрестной проверки в параметрах [sigma,box]. Для получения дополнительной информации смотрите kfoldLoss.

bayesopt передает переменную z к целевой функции как таблица с одной строкой.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизируйте классификатор

Ищите лучшие параметры [sigma,box] использование bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция приобретения по умолчанию зависит от времени выполнения, и так может дать различные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.18209 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.12941 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |      0.1123 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.11921 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.10369 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |    0.099777 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |    0.099025 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      1.6446 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      1.9935 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.11579 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.2661 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.19038 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |      0.1127 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.11322 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.10048 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |    0.095898 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |    0.093368 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |    0.094842 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |    0.096679 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |    0.095039 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.10033 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.11638 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.10126 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.09854 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |      0.1091 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |    0.094982 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |    0.095466 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |    0.092353 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |     0.10654 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |      0.1158 |       0.075 |       0.075 |       3.5051 |       93.492 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 26.9418 seconds.
Total objective function evaluation time: 7.8888

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.095898

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.095303
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf','BoxConstraint',z.box,'KernelScale',z.sigma))
              VariableDescriptions: [1×2 optimizableVariable]
                           Options: [1×1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1×2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1×2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 26.9418
                         NextPoint: [1×2 table]
                            XTrace: [30×2 table]
                    ObjectiveTrace: [30×1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30×1 cell}
      ObjectiveEvaluationTimeTrace: [30×1 double]
                IterationTimeTrace: [30×1 double]
                        ErrorTrace: [30×1 double]
                  FeasibilityTrace: [30×1 logical]
       FeasibilityProbabilityTrace: [30×1 double]
               IndexOfMinimumTrace: [30×1 double]
             ObjectiveMinimumTrace: [30×1 double]
    EstimatedObjectiveMinimumTrace: [30×1 double]

Используйте результаты обучить новое, оптимизировал классификатор SVM.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

Постройте контуры классификации. Чтобы визуализировать классификатор вектора поддержки, предскажите баллы по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Оцените точность на новых данных

Сгенерируйте и классифицируйте некоторые новые точки данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Смотрите, какие новые точки данных правильно классифицируются. Окружите правильно классифицированные точки красного цвета и неправильно классифицированные точки черного цвета цвета.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Смотрите также

|

Похожие темы