exponenta event banner

Оптимизация кросс-проверенного классификатора SVM с помощью bayesopt

В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt функция. Классификация работает на местах точек из гауссовой модели смеси. В «Элементах статистического обучения» Hastie, Tibshirani и Friedman (2009) на странице 17 описывается модель. Модель начинается с генерации 10 базовых точек для «зеленого» класса, распределенного как 2-D независимые нормали со средним значением (1,0) и единичной дисперсией. Он также генерирует 10 базовых точек для «красного» класса, распределенного как 2-D независимые нормали со средним значением (0,1) и единичной дисперсией. Для каждого класса (зеленого и красного) создайте 100 случайных точек следующим образом:

  1. Выберите базовую точку m соответствующего цвета равномерно случайным образом.

  2. Создайте независимую случайную точку с 2-D нормальным распределением со средним m и дисперсией I/5, где I - единичная матрица 2 на 2. В этом примере используйте I/50 дисперсии, чтобы более четко показать преимущество оптимизации.

После генерации 100 зеленых и 100 красных точек классифицируйте их с помощью fitcsvm. Затем использовать bayesopt для оптимизации параметров результирующей модели SVM в отношении перекрестной проверки.

Создание точек и классификатора

Создайте 10 базовых точек для каждого класса.

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базовые точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Поскольку некоторые красные базовые точки близки к зеленым базовым точкам, может быть трудно классифицировать точки данных только на основе местоположения.

Создайте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотр точек данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes. The axes contains 2 objects of type line.

Подготовка данных для классификации

Поместите данные в одну матрицу и создайте вектор grp маркирует класс каждой точки.

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

Подготовка перекрестной проверки

Настройте раздел для перекрестной проверки. На этом шаге фиксируются наборы каналов и тестов, используемые при оптимизации на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовка переменных для байесовской оптимизации

Настройка функции, которая принимает входные данные z = [rbf_sigma,boxconstraint] и возвращает значение потери перекрестной проверки z. Возьмите компоненты z как положительные, преобразованные в журнал переменные между 1e-5 и 1e5. Выберите широкий диапазон, потому что вы не знаете, какие значения могут быть хорошими.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот дескриптор функции вычисляет потери перекрестной проверки в параметрах [sigma,box]. Для получения более подробной информации см. kfoldLoss.

bayesopt передает переменную z целевой функции в виде однострочной таблицы.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизация классификатора

Поиск наилучших параметров [sigma,box] использование bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция приобретения. Функция сбора данных по умолчанию зависит от времени выполнения и может давать различные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.29802 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.14014 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.12459 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.20949 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.16716 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.12516 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.13982 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      1.6266 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      2.3481 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.19795 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      1.1703 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.23914 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.16647 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.14722 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.14553 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |      0.1345 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |     0.13435 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.13044 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.15401 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.12961 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.12742 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.12987 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.15997 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.12983 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.14454 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.14993 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.12991 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.16825 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |      0.1884 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.21229 |       0.075 |       0.075 |       3.5051 |       93.492 |

Figure contains an axes. The axes with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes. The axes with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 49.0545 seconds
Total objective function evaluation time: 9.4691

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.1345

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.1449
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 49.0545
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

Используйте результаты для обучения нового оптимизированного классификатора SVM.

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

Постройте график границ классификации. Чтобы визуализировать классификатор вектора поддержки, спрогнозируйте баллы по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

Оценка точности новых данных

Создайте и классифицируйте некоторые новые точки данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

Figure contains an axes. The axes contains 6 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors.

Посмотрите, какие новые точки данных правильно классифицированы. Обведите правильно классифицированные точки красным цветом, а неправильно классифицированные точки черным цветом.

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

Figure contains an axes. The axes contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), -1 (classified), +1 (classified), Support Vectors, Correctly Classified, Misclassified.

См. также

|

Связанные темы