exponenta event banner

Переход к автоматизации выбора модели с использованием байесовской оптимизации

В этом примере показано, как построить несколько моделей классификации для данного набора обучающих данных, оптимизировать их гиперпараметры с помощью байесовской оптимизации и выбрать модель, которая лучше всего работает с набором тестовых данных.

Обучение нескольких моделей и настройка их гиперпараметров часто могут занять дни или недели. Создание сценария для автоматической разработки и сравнения нескольких моделей может быть намного быстрее. Для ускорения процесса можно также использовать байесовскую оптимизацию. Вместо обучения каждой модели различным наборам гиперпараметров следует выбрать несколько различных моделей и настроить их гиперпараметры по умолчанию с помощью байесовской оптимизации. Байесовская оптимизация находит оптимальный набор гиперпараметров для данной модели путём минимизации целевой функции модели. Этот алгоритм оптимизации стратегически выбирает новые гиперпараметры в каждой итерации и обычно достигает оптимального набора гиперпараметров быстрее, чем простой поиск по сетке. Сценарий в этом примере можно использовать для обучения нескольких моделей классификации с использованием байесовской оптимизации для данного набора обучающих данных и определения модели, которая лучше всего работает с набором тестовых данных.

Кроме того, чтобы выбрать классификационную модель автоматически для выбора типов классификаторов и значений гиперпараметров, используйте fitcauto. Пример см. в разделе Автоматический выбор классификатора с байесовской оптимизацией.

Загрузить данные образца

В этом примере используются данные переписи 1994 года, хранящиеся в census1994.mat. Набор данных состоит из демографических данных Бюро переписи населения США для прогнозирования того, составляет ли человек более 50 000 долларов в год. Задача классификации состоит в том, чтобы соответствовать модели, которая предсказывает категорию зарплаты людей с учетом их возраста, рабочего класса, уровня образования, семейного положения, расы и так далее.

Загрузка данных образца census1994 и отображение переменных в наборе данных.

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1872567  table              
  adulttest        16281x15             944467  table              

census1994 содержит набор данных обучения adultdata и набор тестовых данных adulttest. Для этого примера, чтобы сократить время работы, выполните выборку 5000 учебных и тестовых наблюдений из исходных таблиц. adultdata и adulttest, с помощью datasample функция. (Этот шаг можно пропустить, если требуется использовать полные наборы данных.)

NumSamples = 5000;
s = RandStream('mlfg6331_64'); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

Просмотрите первые несколько строк набора данных обучения.

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и род занятий. Последний столбец salary показывает, имеет ли человек зарплату меньше или равную 50 000 долл. США в год или больше 50 000 долл. США в год.

Понимание данных и выбор классификационных моделей

Toolbox™ статистики и машинного обучения предоставляет несколько вариантов классификации, включая деревья классификации, дискриминантный анализ, наивный Байес, ближайшие соседи, вспомогательные векторные машины (SVM) и классификационные ансамбли. Полный список алгоритмов см. в разделе Классификация.

Прежде чем выбирать алгоритмы для решения проблемы, проверьте набор данных. Данные переписи имеют несколько заслуживающих внимания характеристик:

  • Данные являются табличными и содержат как числовые, так и категориальные переменные.

  • Данные содержат отсутствующие значения.

  • Переменная ответа (salary) имеет два класса (бинарная классификация).

Без каких-либо предположений или использования предварительного знания алгоритмов, которые вы ожидаете хорошо работать с вашими данными, вы просто тренируете все алгоритмы, которые поддерживают табличные данные и двоичную классификацию. Модели выходных кодов с исправлением ошибок (ECOC) используются для данных с более чем двумя классами. Дискриминантный анализ и алгоритмы ближайших соседей не анализируют данные, содержащие как числовые, так и категориальные переменные. Поэтому подходящими для этого примера алгоритмами являются SVM, дерево решений, ансамбль деревьев решений и наивная модель Байеса.

Построение моделей и настройка гиперпараметров

Чтобы ускорить процесс, настройте параметры оптимизации гиперпараметров. Определить 'ShowPlots' как false и 'Verbose' как 0, чтобы отключить отображение графика и сообщений соответственно. Также укажите 'UseParallel' как true для параллельного выполнения байесовской оптимизации, что требует Toolbox™ параллельных вычислений. Из-за непродуктивности параллельной синхронизации параллельная байесовская оптимизация не обязательно дает воспроизводимые результаты.

hypopts = struct('ShowPlots',false,'Verbose',0,'UseParallel',true);

Запустите параллельный пул.

poolobj = gcp;

Можно легко подобрать набор учебных данных и настроить параметры, вызвав каждую функцию подбора и установив ее 'OptimizeHyperparameters' аргумент пары имя-значение для 'auto'. Создайте классификационные модели.

% SVMs: SVM with polynomial kernel & SVM with Gaussian kernel
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','polynomial','Standardize','on', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','gaussian','Standardize','on', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Decision tree
mdls{3} = fitctree(adultdata,'salary', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Ensemble of Decision trees
mdls{4} = fitcensemble(adultdata,'salary','Learners','tree', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Naive Bayes
mdls{5} = fitcnb(adultdata,'salary', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.

График минимальных целевых кривых

Извлеките результаты байесовской оптимизации из каждой модели и постройте график минимального наблюдаемого значения целевой функции для каждой модели в каждой итерации оптимизации гиперпараметров. Значение целевой функции соответствует частоте неправильной классификации, измеренной пятикратной перекрестной проверкой с использованием набора обучающих данных. График сравнивает производительность каждой модели.

figure
hold on
N = length(mdls);
for i = 1:N
    mdl = mdls{i};
    results = mdls{i}.HyperparameterOptimizationResults;
    plot(results.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5);
end
names = {'SVM-Polynomial','SVM-Gaussian','Decision Tree','Ensemble-Trees','Naive Bayes'};
legend(names,'Location','northeast')
title('Bayesian Optimization')
xlabel('Number of Iterations')
ylabel('Minimum Objective Value')

Использование байесовской оптимизации для поиска лучших наборов гиперпараметров повышает производительность моделей в нескольких итерациях. В этом случае график указывает, что ансамбль деревьев решений имеет наилучшую точность прогнозирования для данных. Эта модель успешно работает в нескольких итерациях и различных наборах гиперпараметров байесовской оптимизации.

Проверка производительности с помощью тестового набора

Проверьте производительность классификатора с помощью набора тестовых данных с помощью матрицы путаницы и кривой рабочих характеристик приемника (ROC).

Найдите прогнозируемые метки и значения баллов набора тестовых данных.

label = cell(N,1);
score = cell(N,1);
for i = 1:N
    [label{i},score{i}] = predict(mdls{i},adulttest);
end

Матрица путаницы

Получение наиболее вероятного класса для каждого тестового наблюдения с помощью predict функция каждой модели. Затем вычислите матрицу путаницы с предсказанными классами и известными (истинными) классами набора тестовых данных, используя confusionchart функция.

figure
c = cell(N,1);
for i = 1:N
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label{i});
    title(names{i})
end

Диагональные элементы указывают количество правильно классифицированных экземпляров данного класса. Внедиагональные элементы являются экземплярами неправильно классифицированных наблюдений.

Кривая ROC

Более тщательно проверьте производительность классификатора, построив кривую ROC для каждого классификатора. Используйте perfcurve для получения функции X и Y координаты кривой ROC и площадь под значением кривой (AUC) для вычисленного X и Y.

Построение графиков ROC для значений баллов, соответствующих метке '<=50K', проверьте порядок столбцов значений баллов, возвращенных из predict функция. Порядок столбцов совпадает с порядком категорий переменной ответа в наборе данных обучения. Просмотрите порядок категорий.

c = categories(adultdata.salary)
c = 2×1 cell
    {'<=50K'}
    {'>50K' }

Постройте график кривых ROC.

figure
hold on
AUC = zeros(1,N);
for i = 1:N    
    [X,Y,~,AUC(i)] = perfcurve(adulttest.salary,score{i}(:,1),'<=50K');
    plot(X,Y)
end
title('ROC Curves')
xlabel('False Positive Rate')
ylabel('True Positive Rate')
legend(names,'Location','southeast')

Кривая ROC показывает истинную положительную скорость по сравнению с ложноположительной скоростью (или чувствительность по сравнению с 1-специфичностью) для различных пороговых значений выхода классификатора.

Теперь постройте график значений AUC с помощью гистограммы. Для идеального классификатора, чья истинная положительная скорость всегда равна 1 независимо от пороговых значений, AUC = 1. Для классификатора, который случайным образом присваивает наблюдения классам, AUC = 0,5. Большие значения AUC указывают на лучшую производительность классификатора.

figure
bar(AUC)
title('Area Under the Curve')
xlabel('Model')
ylabel('AUC')
xticklabels(names)
xtickangle(30)
ylim([0.85,0.925])

Основываясь на матрице путаницы и гистограмме AUC, ансамбль деревьев решений и моделей SVM достигает лучшей точности, чем дерево решений и наивные модели Байеса.

Возобновление оптимизации наиболее перспективных моделей

Выполнение байесовской оптимизации на всех моделях для дальнейших итераций может быть дорогостоящим с точки зрения вычислений. Вместо этого выберите подмножество моделей, которые до сих пор работали хорошо, и продолжите оптимизацию еще для 30 итераций с помощью resume функция. Постройте график минимальных наблюдаемых значений целевой функции для каждой итерации байесовской оптимизации.

figure
hold on
selectedMdls = mdls([1,2,4]);
newresults = cell(1,length(selectedMdls));
for i = 1:length(selectedMdls)
    newresults{i} = resume(selectedMdls{i}.HyperparameterOptimizationResults,'MaxObjectiveEvaluations',30);
    plot(newresults{i}.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5)
end
title('Bayesian Optimization with resume')
xlabel('Number of Iterations')
ylabel('Minimum Objective Value')
legend({'SVM-Polynomial','SVM-Gaussian','Ensemble-Trees'},'Location','northeast')

Первые 30 итераций соответствуют первому раунду байесовской оптимизации. Следующие 30 итераций соответствуют результатам resume функция. Возобновление оптимизации полезно, поскольку потери продолжают уменьшаться после первых 30 итераций.

См. также

| | |

Связанные темы