В этом примере показано, как обучить наивную модель классификации мультиклассов Бейеса пошаговому обучению только, когда производительность модели является неудовлетворительной.
Гибкий рабочий процесс пошагового обучения позволяет вам обучить инкрементную модель на входящем пакете данных только, когда это необходимо (см. то, Что Пошаговое обучение?). Например, если показатели производительности модели являются удовлетворительными, то, чтобы увеличить КПД, можно пропустить обучение на входящих пакетах, пока метрики не становятся неудовлетворительными.
Загрузите набор данных деятельности человека. Случайным образом переставьте данные.
load humanactivity n = numel(actid); rng(1) % For reproducibility idx = randsample(n,n); X = feat(idx,:); Y = actid(idx);
Для получения дополнительной информации на наборе данных, введите Description
в командной строке.
Сконфигурируйте наивную модель классификации Бейеса для пошагового обучения путем установки всего следующего:
Максимальное количество ожидаемых классов к 5
Отслеженный показатель производительности к misclassification коэффициенту ошибок, который также включает минимальную стоимость
Метрический размер окна к 1 000
Метрический период прогрева к 50
initobs = 50; Mdl = incrementalClassificationNaiveBayes('MaxNumClasses',5,'MetricsWindowSize',1000,... 'Metrics','classiferror','MetricsWarmupPeriod',initobs);
Подбирайте сконфигурированную модель к первым 50 наблюдениям.
Mdl = fit(Mdl,X(1:initobs,:),Y(1:initobs))
Mdl = incrementalClassificationNaiveBayes IsWarm: 1 Metrics: [2x2 table] ClassNames: [1 2 3 4 5] ScoreTransform: 'none' DistributionNames: {1x60 cell} DistributionParameters: {5x60 cell} Properties, Methods
haveTrainedAllClasses = numel(unique(Y(1:initobs))) == 5
haveTrainedAllClasses = logical
1
Mdl
incrementalClassificationNaiveBayes
объект модели. Модель является теплой (IsWarm
1
) потому что все следующие условия применяются:
Данные о начальной подготовке содержат все ожидаемые классы (haveTrainedAllClasses
true
).
Mdl
было подходящим к Mdl.MetricsWarmupPeriod
наблюдения.
Поэтому модель готова сгенерировать предсказания, и функции пошагового обучения измеряют показатели производительности в модели.
Предположим, что вы хотите обучить модель только, когда новые 1 000 наблюдений имеют misclassification ошибку, больше, чем 5%.
Выполните пошаговое обучение, с условным обучением, путем выполнения этой процедуры для каждой итерации:
Симулируйте поток данных путем обработки фрагмента 100 наблюдений за один раз.
Обновите производительность модели путем передачи и текущего фрагмента модели данных к updateMetrics
. Перезапишите входную модель с выходной моделью.
Подбирайте модель к фрагменту данных только, когда misclassification коэффициент ошибок будет больше 0.05. Перезапишите входную модель с выходной моделью, когда обучение произойдет.
Сохраните misclassification коэффициент ошибок и среднее значение первого предиктора во втором классе чтобы видеть, как они развиваются во время обучения.
Отследите когда fit
обучает модель.
% Preallocation numObsPerChunk = 100; nchunk = floor((n - initobs)/numObsPerChunk); mu21 = zeros(nchunk,1); ce = array2table(nan(nchunk,2),'VariableNames',["Cumulative" "Window"]); trained = false(nchunk,1); % Incremental fitting for j = 1:nchunk ibegin = min(n,numObsPerChunk*(j-1) + 1 + initobs); iend = min(n,numObsPerChunk*j + initobs); idx = ibegin:iend; Mdl = updateMetrics(Mdl,X(idx,:),Y(idx)); ce{j,:} = Mdl.Metrics{"ClassificationError",:}; if ce{j,"Window"} > 0.05 Mdl = fit(Mdl,X(idx,:),Y(idx)); trained(j) = true; end mu21(j) = Mdl.DistributionParameters{2,1}(1); end
Mdl
incrementalClassificationNaiveBayes
объект модели, обученный на всех данных в потоке.
Чтобы видеть, как производительность модели и развитый во время обучения, постройте их на отдельных подграфиках. Идентифицируйте периоды, в которые была обучена модель.
subplot(2,1,1) plot(mu21) hold on plot(find(trained),mu21(trained),'r.') ylabel('\mu_{21}') legend('\mu_{21}','Training occurs','Location','best') hold off subplot(2,1,2) plot(ce.Variables) ylabel('Misclassification Error Rate') xlabel('Iteration') legend(ce.Properties.VariableNames,'Location','best')
График трассировки показывает периоды постоянных значений, во время которых производительность модели в предыдущих 1 000 окон наблюдения самое большее 0.05.
predict
| fit
| updateMetrics