Оптимизируйте перекрестный подтвержденный классификатор Используя bayesopt

В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt функция.

В качестве альтернативы можно оптимизировать классификатор при помощи OptimizeHyperparameters аргумент значения имени. Для примера смотрите, Оптимизируют Подгонку Классификатора Используя Байесовую Оптимизацию.

Сгенерируйте данные

Классификация работает над местоположениями точек от смешанной гауссовской модели. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:

  1. Выберите базисную точку m соответствующего цвета однородно наугад.

  2. Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.

Сгенерируйте эти 10 базисных точек для каждого класса.

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

Просмотрите базисные точки.

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, это может затруднить, чтобы классифицировать точки данных на основе одного только местоположения.

Сгенерируйте 100 точек данных каждого класса.

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

Просмотрите точки данных.

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

Подготовка данных для классификации

Поместите данные в одну матрицу и сделайте векторный grp это помечает класс каждой точки. 1 указывает, что зеленый класс, и-1 указывает на красный класс.

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

Подготовьте перекрестную проверку

Настройте раздел для перекрестной проверки. Этот шаг фиксирует обучение и наборы тестов, которые оптимизация использует на каждом шаге.

c = cvpartition(200,'KFold',10);

Подготовьте переменные к байесовой оптимизации

Настройте функцию, которая берет вход z = [rbf_sigma,boxconstraint] и возвращает значение перекрестной проверки потерь z. Возьмите компоненты z как положительные, преобразованные в журнал переменные между 1e-5 и 1e5. Выберите широкий спектр, потому что вы не знаете, какие значения, вероятно, будут хороши.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Целевая функция

Этот указатель на функцию вычисляет потерю перекрестной проверки в параметрах [sigma,box]. Для получения дополнительной информации смотрите kfoldLoss.

bayesopt передает переменную z к целевой функции как таблица с одной строкой.

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

Оптимизируйте классификатор

Ищите лучшие параметры [sigma,box] использование bayesopt. Для воспроизводимости выберите 'expected-improvement-plus' функция захвата. Функция захвата по умолчанию зависит от времени выполнения, и так может дать различные результаты.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.48985 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.40009 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |       0.283 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.27349 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.20833 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |     0.20963 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.21767 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      2.7117 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      3.2475 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.24964 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      2.1158 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |      0.3453 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.21464 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.28157 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.19414 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |     0.18644 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |      0.1926 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.19321 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.18904 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.20115 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.20615 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.23223 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.21296 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.20081 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.19406 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.20198 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.21041 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.19326 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |     0.22525 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.19009 |       0.075 |       0.075 |       3.5051 |       93.492 |

Figure contains an axes object. The axes object with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

Figure contains an axes object. The axes object with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 52.689 seconds
Total objective function evaluation time: 14.472

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.18644

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.19289
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 52.6890
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

Получите лучшую предполагаемую допустимую точку из XAtMinEstimatedObjective свойство или при помощи bestPoint функция. По умолчанию, bestPoint функционируйте использует 'min-visited-upper-confidence-interval' критерий. Для получения дополнительной информации смотрите аргумент значения имени Критерия bestPoint.

results.XAtMinEstimatedObjective
ans=1×2 table
     sigma      box  
    _______    ______

    0.32869    18.076

z = bestPoint(results)
z=1×2 table
     sigma      box  
    _______    ______

    0.32869    18.076

Используйте лучшую точку, чтобы обучить новое, оптимизировал классификатор SVM.

SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'KernelScale',z.sigma,'BoxConstraint',z.box);

Чтобы визуализировать классификатор вектора поддержки, предскажите баллы по сетке.

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

Постройте контуры классификации.

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1) ,...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure contains an axes object. The axes object contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

Оцените точность на новых данных

Сгенерируйте и классифицируйте новые точки тестовых данных.

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);  % green = 1
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

Вычислите misclassification уровни на наборе тестовых данных.

L = loss(SVMModel,newData,grpData)
L = 0.3500

Смотрите, какие новые точки данных правильно классифицируются. Окружите правильно классифицированные точки красного цвета и неправильно классифицированные точки черного цвета цвета.

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

mydiff = (v == grpData); % Classified correctly

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure contains an axes object. The axes object contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), Support Vectors, -1 (classified), +1 (classified), Correctly Classified, Misclassified.

Смотрите также

|

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте