В этом примере показано, как создать несколько моделей классификации для данного обучающего набора данных, оптимизируйте их гиперпараметры с помощью Байесовой оптимизации и выберите модель, которая выполняет лучшее на наборе тестовых данных.
Обучение несколько моделей и настройки их гиперпараметров может часто занимать дни или недели. Создание скрипта, чтобы разработать и сравнить многоуровневые модели автоматически может быть намного быстрее. Можно также использовать Байесовую оптимизацию, чтобы ускорить процесс. Вместо обучения каждая модель с различными наборами гиперпараметров вы выбираете несколько различных моделей и настраиваете их гиперпараметры по умолчанию с помощью Байесовой оптимизации. Байесова оптимизация находит оптимальный набор гиперпараметров для данной модели путем минимизации целевой функции модели. Этот алгоритм оптимизации стратегически выбирает новые гиперпараметры в каждой итерации и обычно прибывает в оптимальный набор гиперпараметров более быстро, чем простой поиск сетки. Можно использовать скрипт в этом примере, чтобы обучить несколько моделей классификации с помощью Байесовой оптимизации в данном обучающем наборе данных и идентифицировать модель, которая выполняет лучше всего на наборе тестовых данных.
Этот пример использует 1 994 данных о переписи, хранимых в census1994.mat
. Набор данных состоит из демографических данных Бюро переписи США, чтобы предсказать, передает ли индивидуум 50 000$ в год. Задача классификации состоит в том, чтобы подобрать модель, которая предсказывает категорию зарплаты людей, учитывая их возраст, рабочий класс, образовательный уровень, семейное положение, гонку, и так далее.
Загрузите выборочные данные census1994
и отобразите переменные в наборе данных.
load census1994
whos
Name Size Bytes Class Attributes Description 20x74 2960 char adultdata 32561x15 1873655 table adulttest 16281x15 945793 table
census1994
содержит обучающий набор данных adultdata
и тестовые данные устанавливают adulttest
. В данном примере уменьшать время выполнения, поддемонстрационные 5 000 обучения и тестовых наблюдений каждый, из исходных таблиц adultdata
и adulttest
, при помощи datasample
функция. (Можно пропустить этот шаг, если вы хотите использовать наборы полных данных.)
NumSamples = 5000; s = RandStream('mlfg6331_64'); % For reproducibility adultdata = datasample(s,adultdata,NumSamples,'Replace',false); adulttest = datasample(s,adulttest,NumSamples,'Replace',false);
Предварительно просмотрите первые несколько строк обучающего набора данных.
head(adultdata)
ans=8×15 table
age workClass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country salary
___ ___________ __________ ____________ _____________ __________________ _________________ ______________ _____ ______ ____________ ____________ ______________ ______________ ______
39 Private 4.91e+05 Bachelors 13 Never-married Exec-managerial Other-relative Black Male 0 0 45 United-States <=50K
25 Private 2.2022e+05 11th 7 Never-married Handlers-cleaners Own-child White Male 0 0 45 United-States <=50K
24 Private 2.2761e+05 10th 6 Divorced Handlers-cleaners Unmarried White Female 0 0 58 United-States <=50K
51 Private 1.7329e+05 HS-grad 9 Divorced Other-service Not-in-family White Female 0 0 40 United-States <=50K
54 Private 2.8029e+05 Some-college 10 Married-civ-spouse Sales Husband White Male 0 0 32 United-States <=50K
53 Federal-gov 39643 HS-grad 9 Widowed Exec-managerial Not-in-family White Female 0 0 58 United-States <=50K
52 Private 81859 HS-grad 9 Married-civ-spouse Machine-op-inspct Husband White Male 0 0 48 United-States >50K
37 Private 1.2429e+05 Some-college 10 Married-civ-spouse Adm-clerical Husband White Male 0 0 50 United-States <=50K
Каждая строка представляет атрибуты одного взрослого, такие как возраст, образование и размещение. Последний столбец salary
показывает, есть ли у человека зарплата, меньше чем или равная 50 000$ в год или больше, чем 50 000$ в год.
Тестовые данные устанавливают adulttest
содержит два ненужных пустых класса. Удалите их при помощи removecats
функция.
adulttest.salary = removecats(adulttest.salary);
Statistics and Machine Learning Toolbox™ предоставляет несколько возможностей для классификации, включая деревья классификации, дискриминантный анализ, наивного Бейеса, самых близких соседей, машины опорных векторов (SVMs) и ансамбли классификации. Для полного списка алгоритмов смотрите Классификацию.
Прежде, чем выбрать алгоритмы, чтобы использовать в вашей проблеме, смотрите свой набор данных. Данные о переписи имеют несколько примечательных характеристик:
Данные являются табличными и содержат и числовые и категориальные переменные.
Данные содержат отсутствующие значения.
Переменная отклика (salary
) имеет два класса (бинарная классификация).
Не делая предположений или с помощью предварительных знаний алгоритмов, что вы ожидаете работать хорошо над своими данными, вы просто обучаете все алгоритмы, которые поддерживают табличные данные и бинарную классификацию. Модели выходных кодов с коррекцией ошибок (ECOC) используются в данных больше чем с двумя классами. Дискриминантный анализ и самые близкие соседние алгоритмы не анализируют данные, которые содержат и числовые и категориальные переменные. Поэтому алгоритмы, подходящие для этого примера, являются SVMs, деревом решений, ансамблем деревьев решений и наивной моделью Bayes.
Чтобы ускорить процесс, настройте опции гипероптимизации параметров управления. Задайте 'ShowPlots'
как false
и 'Verbose'
как 0, чтобы отключить график и индикаторы сообщения, соответственно. Кроме того, задайте 'UseParallel'
как true
запускать Байесовую оптимизацию параллельно, которая требует Parallel Computing Toolbox™. Из-за невоспроизводимости синхронизации параллели, параллельная Байесова оптимизация не обязательно дает к восстанавливаемым результатам.
hypopts = struct('ShowPlots',false,'Verbose',0,'UseParallel',true);
Запустите параллельный пул.
poolobj = gcp;
Starting parallel pool (parpool) using the 'local' profile ... Connected to the parallel pool (number of workers: 6).
Можно соответствовать обучающему набору данных и настройкам параметров легко путем вызывания каждой подходящей функции и установки ее 'OptimizeHyperparameters'
аргумент пары "имя-значение" 'auto'
. Создайте модели классификации.
% SVMs: SVM with polynomial kernel & SVM with Gaussian kernel mdls{1} = fitcsvm(adultdata,'salary','KernelFunction','polynomial','Standardize','on', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts); mdls{2} = fitcsvm(adultdata,'salary','KernelFunction','gaussian','Standardize','on', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts); % Decision tree mdls{3} = fitctree(adultdata,'salary', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts); % Ensemble of Decision trees mdls{4} = fitcensemble(adultdata,'salary','Learners','tree', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts); % Naive Bayes mdls{5} = fitcnb(adultdata,'salary', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions', hypopts);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Извлеките Байесовы результаты оптимизации из каждой модели и постройте минимальную наблюдаемую величину целевой функции для каждой модели по каждой итерации гипероптимизации параметров управления. Значение целевой функции соответствует misclassification уровню, измеренному пятикратной перекрестной проверкой с помощью обучающего набора данных. График сравнивает производительность каждой модели.
figure hold on N = length(mdls); for i = 1:N mdl = mdls{i}; results = mdls{i}.HyperparameterOptimizationResults; plot(results.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5); end names = {'SVM-Polynomial','SVM-Gaussian','Decision Tree','Ensemble-Trees','Naive Bayes'}; legend(names,'Location','northeast') title('Bayesian Optimization') xlabel('Number of Iterations') ylabel('Minimum Objective Value')
Используя Байесовую оптимизацию, чтобы найти лучшие наборы гиперпараметра улучшает производительность моделей по нескольким итерациям. В этом случае график показывает, что у ансамбля деревьев решений есть лучшая точность прогноза для данных. Эта модель выполняет хорошо последовательно по нескольким итерациям и различным наборам Байесовых гиперпараметров оптимизации.
Проверяйте производительность классификатора с набором тестовых данных при помощи матрицы беспорядка и кривой рабочей характеристики получателя (ROC).
Найдите предсказанные метки и значения счета набора тестовых данных.
label = cell(N,1); score = cell(N,1); for i = 1:N [label{i},score{i}] = predict(mdls{i},adulttest); end
Получите наиболее вероятный класс для каждого тестового наблюдения при помощи predict
функция каждой модели. Затем вычислите матрицу беспорядка с предсказанными классами и известными (TRUE) классами набора тестовых данных при помощи confusionchart
функция.
figure c = cell(N,1); for i = 1:N subplot(2,3,i) c{i} = confusionchart(adulttest.salary,label{i}); title(names{i}) end
Диагональные элементы указывают на количество правильно классифицированных экземпляров данного класса. Недиагональные элементы являются экземплярами неправильно классифицированных наблюдений.
Смотрите производительность классификатора более тесно путем графического вывода кривой ROC для каждого классификатора. Используйте perfcurve
функция, чтобы получить X
и Y
координаты ROC изгибаются и область под кривой (AUC) значение для вычисленного X
и Y
.
Построить кривые ROC для значений счета, соответствующих метке '<=50K'
, проверяйте порядка следования столбцов значений счета, возвращенных от predict
функция. Порядок следования столбцов совпадает с порядком категории переменной отклика в обучающем наборе данных. Отобразите порядок категории.
c = categories(adultdata.salary)
c = 2×1 cell array
{'<=50K'}
{'>50K' }
Постройте кривые ROC.
figure hold on AUC = zeros(1,N); for i = 1:N [X,Y,~,AUC(i)] = perfcurve(adulttest.salary,score{i}(:,1),'<=50K'); plot(X,Y) end title('ROC Curves') xlabel('False Positive Rate') ylabel('True Positive Rate') legend(names,'Location','southeast')
Кривая ROC показывает истинный положительный уровень по сравнению с ложным положительным уровнем (или, чувствительность по сравнению с 1 спецификой) для различных порогов классификатора выход.
Теперь постройте значения AUC с помощью столбчатого графика. Для совершенного классификатора, истинный положительный уровень которого всегда 1 независимо от порогов, AUC = 1. Для классификатора, который случайным образом присваивает наблюдения классам, AUC = 0.5. Большие значения AUC указывают на лучшую производительность классификатора.
figure bar(AUC) title('Area Under the Curve') xlabel('Model') ylabel('AUC') xticklabels(names) xtickangle(30) ylim([0,1])
На основе матрицы беспорядка и столбчатого графика AUC, ансамбль деревьев решений и моделей SVM достигает лучшей точности, чем дерево решений и наивные модели Bayes.
Выполнение Байесовой оптимизации на всех моделях для дальнейших итераций может быть в вычислительном отношении дорогим. Вместо этого выберите подмножество моделей, которые выполнили хорошо до сих пор и продолжают оптимизацию для еще 30 итераций при помощи resume
функция. Постройте минимальные наблюдаемые величины целевой функции для каждой итерации Байесовой оптимизации.
figure hold on selectedMdls = mdls([1,2,4]); newresults = cell(1,length(selectedMdls)); for i = 1:length(selectedMdls) newresults{i} = resume(selectedMdls{i}.HyperparameterOptimizationResults,'MaxObjectiveEvaluations',30); plot(newresults{i}.ObjectiveMinimumTrace,'Marker','o','MarkerSize',5) end title('Bayesian Optimization with resume') xlabel('Number of Iterations') ylabel('Minimum Objective Value') legend({'SVM-Polynomial','SVM-Gaussian','Ensemble-Trees'},'Location','northeast')
Первые 30 итераций соответствуют первому раунду Байесовой оптимизации. Следующие 30 итераций соответствуют результатам resume
функция. Возобновление оптимизации полезно, потому что потеря продолжает уменьшать далее после первых 30 итераций.
BayesianOptimization
| confusionchart
| perfcurve
| resume