bayesopt
В этом примере показано, как оптимизировать классификацию SVM с помощью bayesopt
функция. Классификация работает над местоположениями точек от смешанной гауссовской модели. В Элементах Статистического Изучения, Hastie, Тибширэни и Фридмана (2009), страница 17 описывает модель. Модель начинается с генерации 10 базисных точек для "зеленого" класса, распределенного как 2D независимые нормали со средним значением (1,0) и модульное отклонение. Это также генерирует 10 базисных точек для "красного" класса, распределенного как 2D независимые нормали со средним значением (0,1) и модульное отклонение. Для каждого класса (зеленый и красный), сгенерируйте 100 случайных точек можно следующим образом:
Выберите базисную точку m соответствующего цвета однородно наугад.
Сгенерируйте независимую случайную точку с 2D нормальным распределением со средним значением m и отклонением I/5, где я - единичная матрица 2 на 2. В этом примере используйте отклонение I/50, чтобы показать преимущество оптимизации более ясно.
После генерации 100 зеленых и 100 красных точек, классифицируйте их использующий fitcsvm
. Затем используйте bayesopt
оптимизировать параметры получившейся модели SVM относительно перекрестной проверки.
Сгенерируйте эти 10 базисных точек для каждого класса.
rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);
Просмотрите базисные точки.
plot(grnpop(:,1),grnpop(:,2),'go') hold on plot(redpop(:,1),redpop(:,2),'ro') hold off
Поскольку некоторые красные базисные точки близко к зеленым базисным точкам, это может затруднить, чтобы классифицировать точки данных на основе одного только местоположения.
Сгенерируйте 100 точек данных каждого класса.
redpts = zeros(100,2);grnpts = redpts; for i = 1:100 grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02); redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02); end
Просмотрите точки данных.
figure plot(grnpts(:,1),grnpts(:,2),'go') hold on plot(redpts(:,1),redpts(:,2),'ro') hold off
Поместите данные в одну матрицу и сделайте векторный grp
это помечает класс каждой точки.
cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;
Настройте раздел для перекрестной проверки. Этот шаг фиксирует обучение и наборы тестов, которые оптимизация использует на каждом шаге.
c = cvpartition(200,'KFold',10);
Настройте функцию, которая берет вход z = [rbf_sigma,boxconstraint]
и возвращает значение перекрестной проверки потерь z
. Возьмите компоненты z
как положительные, преобразованные в журнал переменные между 1e-5
и 1e5
. Выберите широкий спектр, потому что вы не знаете, какие значения, вероятно, будут хороши.
sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log'); box = optimizableVariable('box',[1e-5,1e5],'Transform','log');
Этот указатель на функцию вычисляет потерю перекрестной проверки в параметрах [sigma,box]
. Для получения дополнительной информации смотрите kfoldLoss
.
bayesopt
передает переменную z
к целевой функции как таблица с одной строкой.
minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,... 'KernelFunction','rbf','BoxConstraint',z.box,... 'KernelScale',z.sigma));
Ищите лучшие параметры [sigma,box]
использование bayesopt
. Для воспроизводимости выберите 'expected-improvement-plus'
функция приобретения. Функция приобретения по умолчанию зависит от времени выполнения, и так может дать различные результаты.
results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,... 'AcquisitionFunctionName','expected-improvement-plus')
|=====================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | sigma | box | | | result | | runtime | (observed) | (estim.) | | | |=====================================================================================================| | 1 | Best | 0.61 | 0.29802 | 0.61 | 0.61 | 0.00013375 | 13929 | | 2 | Best | 0.345 | 0.14014 | 0.345 | 0.345 | 24526 | 1.936 | | 3 | Accept | 0.61 | 0.12459 | 0.345 | 0.345 | 0.0026459 | 0.00084929 | | 4 | Accept | 0.345 | 0.20949 | 0.345 | 0.345 | 3506.3 | 6.7427e-05 | | 5 | Accept | 0.345 | 0.16716 | 0.345 | 0.345 | 9135.2 | 571.87 | | 6 | Accept | 0.345 | 0.12516 | 0.345 | 0.345 | 99701 | 10223 | | 7 | Best | 0.295 | 0.13982 | 0.295 | 0.295 | 455.88 | 9957.4 | | 8 | Best | 0.24 | 1.6266 | 0.24 | 0.24 | 31.56 | 99389 | | 9 | Accept | 0.24 | 2.3481 | 0.24 | 0.24 | 10.451 | 64429 | | 10 | Accept | 0.35 | 0.19795 | 0.24 | 0.24 | 17.331 | 1.0264e-05 | | 11 | Best | 0.23 | 1.1703 | 0.23 | 0.23 | 16.005 | 90155 | | 12 | Best | 0.1 | 0.23914 | 0.1 | 0.1 | 0.36562 | 80878 | | 13 | Accept | 0.115 | 0.16647 | 0.1 | 0.1 | 0.1793 | 68459 | | 14 | Accept | 0.105 | 0.14722 | 0.1 | 0.1 | 0.2267 | 95421 | | 15 | Best | 0.095 | 0.14553 | 0.095 | 0.095 | 0.28999 | 0.0058227 | | 16 | Best | 0.075 | 0.1345 | 0.075 | 0.075 | 0.30554 | 8.9017 | | 17 | Accept | 0.085 | 0.13435 | 0.075 | 0.075 | 0.41122 | 4.4476 | | 18 | Accept | 0.085 | 0.13044 | 0.075 | 0.075 | 0.25565 | 7.8038 | | 19 | Accept | 0.075 | 0.15401 | 0.075 | 0.075 | 0.32869 | 18.076 | | 20 | Accept | 0.085 | 0.12961 | 0.075 | 0.075 | 0.32442 | 5.2118 | |=====================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | sigma | box | | | result | | runtime | (observed) | (estim.) | | | |=====================================================================================================| | 21 | Accept | 0.3 | 0.12742 | 0.075 | 0.075 | 1.3592 | 0.0098067 | | 22 | Accept | 0.12 | 0.12987 | 0.075 | 0.075 | 0.17515 | 0.00070913 | | 23 | Accept | 0.175 | 0.15997 | 0.075 | 0.075 | 0.1252 | 0.010749 | | 24 | Accept | 0.105 | 0.12983 | 0.075 | 0.075 | 1.1664 | 31.13 | | 25 | Accept | 0.1 | 0.14454 | 0.075 | 0.075 | 0.57465 | 2013.8 | | 26 | Accept | 0.12 | 0.14993 | 0.075 | 0.075 | 0.42922 | 1.1602e-05 | | 27 | Accept | 0.12 | 0.12991 | 0.075 | 0.075 | 0.42956 | 0.00027218 | | 28 | Accept | 0.095 | 0.16825 | 0.075 | 0.075 | 0.4806 | 13.452 | | 29 | Accept | 0.105 | 0.1884 | 0.075 | 0.075 | 0.19755 | 943.87 | | 30 | Accept | 0.205 | 0.21229 | 0.075 | 0.075 | 3.5051 | 93.492 |
__________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 30 reached. Total function evaluations: 30 Total elapsed time: 49.0545 seconds Total objective function evaluation time: 9.4691 Best observed feasible point: sigma box _______ ______ 0.30554 8.9017 Observed objective function value = 0.075 Estimated objective function value = 0.075 Function evaluation time = 0.1345 Best estimated feasible point (according to models): sigma box _______ ______ 0.32869 18.076 Estimated objective function value = 0.075 Estimated function evaluation time = 0.1449
results = BayesianOptimization with properties: ObjectiveFcn: [function_handle] VariableDescriptions: [1x2 optimizableVariable] Options: [1x1 struct] MinObjective: 0.0750 XAtMinObjective: [1x2 table] MinEstimatedObjective: 0.0750 XAtMinEstimatedObjective: [1x2 table] NumObjectiveEvaluations: 30 TotalElapsedTime: 49.0545 NextPoint: [1x2 table] XTrace: [30x2 table] ObjectiveTrace: [30x1 double] ConstraintsTrace: [] UserDataTrace: {30x1 cell} ObjectiveEvaluationTimeTrace: [30x1 double] IterationTimeTrace: [30x1 double] ErrorTrace: [30x1 double] FeasibilityTrace: [30x1 logical] FeasibilityProbabilityTrace: [30x1 double] IndexOfMinimumTrace: [30x1 double] ObjectiveMinimumTrace: [30x1 double] EstimatedObjectiveMinimumTrace: [30x1 double]
Используйте результаты обучить новое, оптимизировал классификатор SVM.
z(1) = results.XAtMinObjective.sigma; z(2) = results.XAtMinObjective.box; SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',... 'KernelScale',z(1),'BoxConstraint',z(2));
Постройте контуры классификации. Чтобы визуализировать классификатор вектора поддержки, предскажите баллы по сетке.
d = 0.02; [x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),... min(cdata(:,2)):d:max(cdata(:,2))); xGrid = [x1Grid(:),x2Grid(:)]; [~,scores] = predict(SVMModel,xGrid); h = nan(3,1); % Preallocation figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3) = plot(cdata(SVMModel.IsSupportVector,1),... cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h,{'-1','+1','Support Vectors'},'Location','Southeast'); axis equal hold off
Сгенерируйте и классифицируйте некоторые новые точки данных.
grnobj = gmdistribution(grnpop,.2*eye(2)); redobj = gmdistribution(redpop,.2*eye(2)); newData = random(grnobj,10); newData = [newData;random(redobj,10)]; grpData = ones(20,1); grpData(11:20) = -1; % red = -1 v = predict(SVMModel,newData); g = nan(7,1); figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); h(5) = plot(cdata(SVMModel.IsSupportVector,1),... cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',... '+1 (classified)','Support Vectors'},'Location','Southeast'); axis equal hold off
Смотрите, какие новые точки данных правильно классифицируются. Окружите правильно классифицированные точки красного цвета и неправильно классифицированные точки черного цвета цвета.
mydiff = (v == grpData); % Classified correctly figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); h(5) = plot(cdata(SVMModel.IsSupportVector,1),... cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); for ii = mydiff % Plot red squares around correct pts h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12); end for ii = not(mydiff) % Plot black squares around incorrect pts h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12); end legend(h,{'-1 (training)','+1 (training)','-1 (classified)',... '+1 (classified)','Support Vectors','Correctly Classified',... 'Misclassified'},'Location','Southeast'); hold off
bayesopt
| optimizableVariable