Движение автоматизирующий выбор модели Используя байесовую оптимизацию

В этом примере показано, как создать несколько моделей классификации для данного обучающего набора данных, оптимизируйте их гиперпараметры с помощью Байесовой оптимизации и выберите модель, которая выполняет лучшее на наборе тестовых данных.

Обучение несколько моделей и настройки их гиперпараметров может часто занимать дни или недели. Создание скрипта, чтобы разработать и сравнить многоуровневые модели автоматически может быть намного быстрее. Можно также использовать Байесовую оптимизацию, чтобы ускорить процесс. Вместо обучения каждая модель с различными наборами гиперпараметров вы выбираете несколько различных моделей и настраиваете их гиперпараметры по умолчанию с помощью Байесовой оптимизации. Байесова оптимизация находит оптимальный набор гиперпараметров для данной модели путем минимизации целевой функции модели. Этот алгоритм оптимизации стратегически выбирает новые гиперпараметры в каждой итерации и обычно прибывает в оптимальный набор гиперпараметров более быстро, чем простой поиск сетки. Можно использовать скрипт в этом примере, чтобы обучить несколько моделей классификации с помощью Байесовой оптимизации в данном обучающем наборе данных и идентифицировать модель, которая выполняет лучше всего на наборе тестовых данных.

Загрузка демонстрационных данных

Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat. Набор данных состоит из демографических данных Бюро переписи США, чтобы предсказать, передает ли индивидуум 50 000$ в год. Задача классификации состоит в том, чтобы подобрать модель, которая предсказывает категорию зарплаты людей, учитывая их возраст, рабочий класс, образовательный уровень, семейное положение, гонку, и так далее.

Загрузите выборочные данные census1994 и отобразите переменные в наборе данных.

load census1994
whos
  Name                 Size              Bytes  Class    Attributes

  Description         20x74               2960  char               
  adultdata        32561x15            1873655  table              
  adulttest        16281x15             945793  table              

census1994 содержит обучающий набор данных adultdata и тестовые данные устанавливают adulttest. В данном примере уменьшать время выполнения, поддемонстрационные 5 000 обучения и тестовых наблюдений каждый, из исходных таблиц adultdata и adulttest, при помощи datasample функция. (Можно пропустить этот шаг, если вы хотите использовать наборы полных данных.)

NumSamples = 5000;
s = RandStream('mlfg6331_64'); % For reproducibility
adultdata = datasample(s,adultdata,NumSamples,'Replace',false);
adulttest = datasample(s,adulttest,NumSamples,'Replace',false);

Предварительно просмотрите первые несколько строк обучающего набора данных.

head(adultdata)
ans=8×15 table
    age     workClass       fnlwgt       education      education_num      marital_status         occupation         relationship     race      sex      capital_gain    capital_loss    hours_per_week    native_country    salary
    ___    ___________    __________    ____________    _____________    __________________    _________________    ______________    _____    ______    ____________    ____________    ______________    ______________    ______

    39     Private          4.91e+05    Bachelors            13          Never-married         Exec-managerial      Other-relative    Black    Male           0               0                45          United-States     <=50K 
    25     Private        2.2022e+05    11th                  7          Never-married         Handlers-cleaners    Own-child         White    Male           0               0                45          United-States     <=50K 
    24     Private        2.2761e+05    10th                  6          Divorced              Handlers-cleaners    Unmarried         White    Female         0               0                58          United-States     <=50K 
    51     Private        1.7329e+05    HS-grad               9          Divorced              Other-service        Not-in-family     White    Female         0               0                40          United-States     <=50K 
    54     Private        2.8029e+05    Some-college         10          Married-civ-spouse    Sales                Husband           White    Male           0               0                32          United-States     <=50K 
    53     Federal-gov         39643    HS-grad               9          Widowed               Exec-managerial      Not-in-family     White    Female         0               0                58          United-States     <=50K 
    52     Private             81859    HS-grad               9          Married-civ-spouse    Machine-op-inspct    Husband           White    Male           0               0                48          United-States     >50K  
    37     Private        1.2429e+05    Some-college         10          Married-civ-spouse    Adm-clerical         Husband           White    Male           0               0                50          United-States     <=50K 

Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и размещение. Последний столбец salary показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.

Тестовые данные устанавливают adulttest содержит два ненужных пустых класса. Удалите их при помощи removecats функция.

adulttest.salary = removecats(adulttest.salary);

Изучите данные и выберите модели классификации

Statistics and Machine Learning Toolbox™ предоставляет несколько возможностей для классификации, включая деревья классификации, дискриминантный анализ, наивного Бейеса, самых близких соседей, машины опорных векторов (SVMs) и ансамбли классификации. Для полного списка алгоритмов смотрите Классификацию.

Прежде, чем выбрать алгоритмы, чтобы использовать в вашей проблеме, смотрите свой набор данных. Данные о переписи имеют несколько примечательных характеристик:

  • Данные являются табличными и содержат и числовые и категориальные переменные.

  • Данные содержат отсутствующие значения.

  • Переменная отклика (salary) имеет два класса (бинарная классификация).

Не делая предположений или с помощью предварительных знаний алгоритмов, что вы ожидаете работать хорошо над своими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и бинарную классификацию. Модели выходных кодов с коррекцией ошибок (ECOC) используются в данных больше чем с двумя классами. Дискриминантный анализ и самые близкие соседние алгоритмы не анализируют данные, которые содержат и числовые и категориальные переменные. Поэтому алгоритмы, подходящие для этого примера, являются SVMs, деревом решений, ансамблем деревьев решений и наивной моделью Bayes.

Создайте гиперпараметры мелодии и модели

Чтобы ускорить процесс, настройте опции гипероптимизации параметров управления. Задайте 'ShowPlots' как false и 'Verbose' как 0, чтобы отключить график и индикаторы сообщения, соответственно. Кроме того, задайте 'UseParallel' как true запускать Байесовую оптимизацию параллельно, которая требует Parallel Computing Toolbox™. Из-за невоспроизводимости синхронизации параллели, параллельная Байесова оптимизация не обязательно дает к восстанавливаемым результатам.

hypopts = struct('ShowPlots',false,'Verbose',0,'UseParallel',true);

Запустите параллельный пул.

poolobj = gcp;
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Можно соответствовать обучающему набору данных и настройкам параметров легко путем вызывания каждой подходящей функции и установки ее 'OptimizeHyperparameters' аргумент пары "имя-значение" 'auto'. Создайте модели классификации.

% SVMs: SVM with polynomial kernel & SVM with Gaussian kernel
mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','polynomial','Standardize','on', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);
mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','gaussian','Standardize','on', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Decision tree
mdls{3} = fitctree(adultdata,'salary', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Ensemble of Decision trees
mdls{4} = fitcensemble(adultdata,'salary','Learners','tree', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);

% Naive Bayes
mdls{5} = fitcnb(adultdata,'salary', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.

Постройте минимальные объективные кривые

Извлеките Байесовы результаты оптимизации из каждой модели и постройте минимальную наблюдаемую величину целевой функции для каждой модели по каждой итерации гипероптимизации параметров управления. Значение целевой функции соответствует misclassification уровню, измеренному пятикратной перекрестной проверкой с помощью обучающего набора данных. График сравнивает производительность каждой модели.

figure
hold on
N = length(mdls);
for i = 1:N
    mdl = mdls{i};
    results = mdls{i}.HyperparameterOptimizationResults;
    plot(results.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5);
end
names = {'SVM-Polynomial','SVM-Gaussian','Decision Tree','Ensemble-Trees','Naive Bayes'};
legend(names,'Location','northeast')
title('Bayesian Optimization')
xlabel('Number of Iterations')
ylabel('Minimum Objective Value')

Используя Байесовую оптимизацию, чтобы найти лучшие наборы гиперпараметра улучшает производительность моделей по нескольким итерациям. В этом случае график показывает, что у ансамбля деревьев решений есть лучшая точность прогноза для данных. Эта модель выполняет хорошо последовательно по нескольким итерациям и различным наборам Байесовых гиперпараметров оптимизации.

Проверяйте производительность с набором тестов

Проверяйте производительность классификатора с набором тестовых данных при помощи матрицы беспорядка и кривой рабочей характеристики получателя (ROC).

Найдите предсказанные метки и значения счета набора тестовых данных.

label = cell(N,1);
score = cell(N,1);
for i = 1:N
    [label{i},score{i}] = predict(mdls{i},adulttest);
end

Матрица беспорядка

Получите наиболее вероятный класс для каждого тестового наблюдения при помощи predict функция каждой модели. Затем вычислите матрицу беспорядка с предсказанными классами и известными (TRUE) классами набора тестовых данных при помощи confusionchart функция.

figure
c = cell(N,1);
for i = 1:N
    subplot(2,3,i)
    c{i} = confusionchart(adulttest.salary,label{i});
    title(names{i})
end

Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональные элементы являются экземплярами неправильно классифицированных наблюдений.

Кривая ROC

Смотрите производительность классификатора более тесно путем графического вывода кривой ROC для каждого классификатора. Используйте perfcurve функция, чтобы получить X и Y координаты ROC изгибаются и область под кривой (AUC) значение для вычисленного X и Y.

Построить кривые ROC для значений счета, соответствующих метке '<=50K', проверяйте порядка следования столбцов значений счета, возвращенных от predict функция. Порядок следования столбцов совпадает с порядком категории переменной отклика в обучающем наборе данных. Отобразите порядок категории.

c = categories(adultdata.salary)
c = 2×1 cell array
    {'<=50K'}
    {'>50K' }

Постройте кривые ROC.

figure
hold on
AUC = zeros(1,N);
for i = 1:N    
    [X,Y,~,AUC(i)] = perfcurve(adulttest.salary,score{i}(:,1),'<=50K');
    plot(X,Y)
end
title('ROC Curves')
xlabel('False Positive Rate')
ylabel('True Positive Rate')
legend(names,'Location','southeast')

Кривая ROC показывает истинный положительный уровень по сравнению с ложным положительным уровнем (или, чувствительность по сравнению с 1 спецификой) для различных порогов классификатора выход.

Теперь постройте значения AUC с помощью столбчатого графика. Для совершенного классификатора, истинный положительный уровень которого всегда 1 независимо от порогов, AUC = 1. Для классификатора, который случайным образом присваивает наблюдения классам, AUC = 0.5. Большие значения AUC указывают на лучшую производительность классификатора.

figure
bar(AUC)
title('Area Under the Curve')
xlabel('Model')
ylabel('AUC')
xticklabels(names)
xtickangle(30)
ylim([0,1])

На основе матрицы беспорядка и столбчатого графика AUC, ансамбль деревьев решений и моделей SVM достигает лучшей точности, чем дерево решений и наивные модели Bayes.

Возобновите оптимизацию большинства многообещающих моделей

Выполнение Байесовой оптимизации на всех моделях для дальнейших итераций может быть в вычислительном отношении дорогим. Вместо этого выберите подмножество моделей, которые выполнили хорошо до сих пор и продолжают оптимизацию для еще 30 итераций при помощи resume функция. Постройте минимальные наблюдаемые величины целевой функции для каждой итерации Байесовой оптимизации.

figure
hold on
selectedMdls = mdls([1,2,4]);
newresults = cell(1,length(selectedMdls));
for i = 1:length(selectedMdls)
    newresults{i} = resume(selectedMdls{i}.HyperparameterOptimizationResults,'MaxObjectiveEvaluations',30);
    plot(newresults{i}.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5)
end
title('Bayesian Optimization with resume')
xlabel('Number of Iterations')
ylabel('Minimum Objective Value')
legend({'SVM-Polynomial','SVM-Gaussian','Ensemble-Trees'},'Location','northeast')

Первые 30 итераций соответствуют первому раунду Байесовой оптимизации. Следующие 30 итераций соответствуют результатам resume функция. Возобновление оптимизации полезно, потому что потеря продолжает уменьшать далее после первых 30 итераций.

Смотрите также

| | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте