Эта тема вводит функции Statistics and Machine Learning Toolbox™ интерпретации модели и показывает, как интерпретировать модель машинного обучения (классификация и регрессия).
Модель машинного обучения часто упоминается как модель "черного квадрата", потому что она может затруднить, чтобы изучить, как модель делает предсказания. Инструменты Interpretability помогают вам преодолеть этот аспект алгоритмов машинного обучения и показывают, как предикторы способствуют (или не способствуйте) к предсказаниям. Кроме того, можно подтвердить, использует ли модель правильное доказательство для своих предсказаний, и найдите смещения модели, которые не сразу очевидны.
Используйте lime
, shapley
, и plotPartialDependence
объяснить вклад отдельных предикторов к предсказаниям обученной классификации или модели регрессии.
lime
— Локальные поддающиеся толкованию объяснения модели агностические (LIME [1]) интерпретируют предсказание для точки запроса, подбирая простую поддающуюся толкованию модель для точки запроса. Простая модель действует как приближение для обученной модели и объясняет предсказания модели вокруг точки запроса. Простая модель может быть или линейной моделью или моделью дерева принятия решения. Можно использовать предполагаемые коэффициенты линейной модели или предполагаемую важность предиктора модели дерева принятия решения, чтобы объяснить вклад отдельных предикторов к предсказанию для точки запроса. Для получения дополнительной информации смотрите LIME.
shapley
— Значение Шепли [2][3] предиктора для точки запроса объясняет отклонение предсказания (ответ для регрессии или музыка класса к классификации) для точки запроса из среднего предсказания, из-за предиктора. Для точки запроса сумма значений Шепли для всех функций соответствует общему отклонению предсказания от среднего значения. Для получения дополнительной информации смотрите Значения Шепли для Модели Машинного обучения.
plotPartialDependence
и partialDependence
— Частичный график зависимости (PDP [4]) показывает отношения между предиктором (или пара предикторов) и предсказанием (ответ для регрессии или музыка класса к классификации) в обученной модели. Частичная зависимость от выбранного предиктора задана усредненным предсказанием, полученным путем маргинализации эффекта других переменных. Поэтому частичная зависимость является функцией выбранного предиктора, который показывает средний эффект выбранного предиктора по набору данных. Можно также создать набор отдельного условного ожидания (ICE [5]) графики для каждого наблюдения, показав эффект выбранного предиктора на одном наблюдении. Для получения дополнительной информации смотрите Больше О на plotPartialDependence
страница с описанием.
Встроенный выбор признаков типа поддержки моделей некоторого машинного обучения, где модель изучает важность предиктора как часть процесса обучения модели. Можно использовать предполагаемую важность предиктора, чтобы объяснить предсказания модели. Например:
Обучите ансамбль (ClassificationBaggedEnsemble
или RegressionBaggedEnsemble
) из сложенных в мешок деревьев решений (например, случайный лес) и использование predictorImportance
и oobPermutedPredictorImportance
функции.
Обучите линейную модель с регуляризацией лассо, которая уменьшает коэффициенты наименее важных предикторов. Затем используйте предполагаемые коэффициенты в качестве мер для важности предиктора. Например, использовать fitclinear
или fitrlinear
и задайте 'Regularization'
аргумент значения имени как 'lasso'
.
Для списка моделей машинного обучения, которые поддерживают встроенный выбор признаков типа, смотрите Встроенный Выбор признаков Типа.
Используйте функции Statistics and Machine Learning Toolbox для трех уровней интерпретации модели: локальный, когорта и глобальная переменная.
Уровень | Цель | Вариант использования | Функция Statistics and Machine Learning Toolbox |
---|---|---|---|
Локальная интерпретация | Объясните предсказание для точки единого запроса. |
| Используйте lime и shapley для заданной точки запроса. |
Интерпретация когорты | Объясните, как обученная модель делает предсказания для подмножества целого набора данных. | Подтвердите предсказания для конкретной группы выборок. |
|
Глобальная интерпретация | Объясните, как обученная модель делает предсказания для целого набора данных. |
|
|
Этот пример обучает ансамбль сложенных в мешок деревьев решений с помощью случайного лесного алгоритма и интерпретирует обученную модель с помощью interpretability функции. Используйте объектные функции (oobPermutedPredictorImportance
и predictorImportance
) из обученной модели, чтобы найти важные предикторы в модели. Кроме того, используйте lime
и shapley
интерпретировать предсказания для заданных точек запроса. Затем используйте plotPartialDependence
создать график, который показывает отношения между важным предиктором и предсказанными классификационными оценками.
Обучите модель ансамбля классификации
Загрузите CreditRating_Historical
набор данных. Набор данных содержит идентификаторы клиентов и их финансовые отношения, промышленные метки и кредитные рейтинги.
tbl = readtable('CreditRating_Historical.dat');
Отобразите первые три строки таблицы.
head(tbl,3)
ans=3×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ _____ _____ _______ ________ _____ ________ ______
62394 0.013 0.104 0.036 0.447 0.142 3 {'BB'}
48608 0.232 0.335 0.062 1.969 0.281 8 {'A' }
42444 0.311 0.367 0.074 1.935 0.366 1 {'A' }
Составьте таблицу переменных предикторов путем удаления столбцов, содержащих идентификаторы клиентов и оценки от tbl
.
tblX = removevars(tbl,["ID","Rating"]);
Обучите ансамбль сложенных в мешок деревьев решений при помощи fitcensemble
функция и определение метода агрегации ансамбля как случайный лес ('Bag'
). Для воспроизводимости случайного лесного алгоритма задайте 'Reproducible'
аргумент значения имени как true
для древовидных учеников. Кроме того, задайте имена классов, чтобы установить порядок классов в обученной модели.
rng('default') % For reproducibility t = templateTree('Reproducible',true); blackbox = fitcensemble(tblX,tbl.Rating, ... 'Method','Bag','Learners',t, ... 'CategoricalPredictors','Industry', ... 'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});
blackbox
ClassificationBaggedEnsemble
модель.
Используйте функции Interpretability модели специфичные
ClassificationBaggedEnsemble
поддерживает две объектных функции, oobPermutedPredictorImportance
и predictorImportance
, которые находят важные предикторы в обученной модели.
Оцените важность предиктора из сумки при помощи oobPermutedPredictorImportance
функция. Функция случайным образом переставляет данные из сумки через один предиктор за один раз и оценивает увеличение ошибки из сумки из-за этого сочетания. Чем больше увеличение, тем более важный функция.
Imp1 = oobPermutedPredictorImportance(blackbox);
Оцените важность предиктора при помощи predictorImportance
функция. Функциональная оценочная важность предиктора путем подведения итогов изменений в узле рискует из-за разделений на каждом предикторе и делении суммы количеством узлов ветви.
Imp2 = predictorImportance(blackbox);
Составьте таблицу, содержащую оценки важности предиктора, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание на любое имя предиктора, измените TickLabelInterpreter
значение осей к 'none'
.
table_Imp = table(Imp1',Imp2', ... 'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ... 'RowNames',blackbox.PredictorNames); tiledlayout(1,2) ax1 = nexttile; table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance'); barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance')) xlabel('Out-of-Bag Permuted Predictor Importance') ylabel('Predictor') ax2 = nexttile; table_Imp2 = sortrows(table_Imp,'Predictor Importance'); barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance')) xlabel('Predictor Importance') ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none';
Обе объектных функции идентифицируют MVE_BVTD
и RE_TA
как два самых важных предиктора.
Задайте точку запроса
Найдите наблюдения чей Rating
'AAA'
и выберите четыре точки запроса среди них.
tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:); queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry
_____ _____ _______ ________ _____ ________
0.331 0.531 0.077 7.116 0.522 12
0.26 0.515 0.065 3.394 0.515 1
0.121 0.413 0.057 3.647 0.466 12
0.617 0.766 0.126 4.442 0.483 9
Используйте LIME с линейными простыми моделями
Объясните предсказания для точек запроса с помощью lime
с линейными простыми моделями. lime
генерирует синтетический набор данных и подбирает простую модель к синтетическому набору данных.
Создайте lime
объект с помощью tblX_AAA
так, чтобы lime
генерирует синтетический набор данных с помощью только наблюдения чей Rating
'AAA'
, не целый набор данных.
explainer_lime = lime(blackbox,tblX_AAA);
Значение по умолчанию DataLocality для lime
'global'
, который подразумевает что, по умолчанию, lime
генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime
использует различные веса наблюдения так, чтобы значения веса более фокусировались на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подбирайте простые модели для четырех точек запроса при помощи объектного функционального fit
. Задайте третий вход (количество важных предикторов, чтобы использовать в простой модели) как 6, чтобы использовать все шесть предикторов.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте коэффициенты простых моделей при помощи объектного функционального plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
Все простые модели идентифицируют EBIT_TA
, RE_TA
, и MVE_BVTD
как три самых важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению предсказанных баллов в простых моделях.
Для категориального предиктора, plot
функционируйте отображает только самую важную фиктивную переменную категориального предиктора. Поэтому каждый столбчатый график отображает различную фиктивную переменную.
Вычислите значения Шепли
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного счета к точке запроса от средней оценки, из-за предиктора. Создайте shapley
объект с помощью tblX_AAA
так, чтобы shapley
вычисляет ожидаемый вклад на основе выборок для 'AAA'
.
explainer_shapley = shapley(blackbox,tblX_AAA);
Вычислите значения Шепли для точек запроса при помощи объектного функционального fit
.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте значения Шепли при помощи объектного функционального plot
.
tiledlayout(2,2) nexttile plot(explainer_shapley1) nexttile plot(explainer_shapley2) nexttile plot(explainer_shapley3) nexttile plot(explainer_shapley4)
MVE_BVTD
и RE_TA
два из трех самых важных предикторов для всех четырех точек запроса.
Значения Шепли MVE_BVTD
положительны для первых и четвертых точек запроса и отрицательны для вторых и третьих точек запроса. MVE_BVTD
значения - приблизительно 7 и 4 для первых и четвертых точек запроса, соответственно, и значение и для вторых и для третьих точек запроса является приблизительно 3,5. Согласно значениям Шепли для четырех точек запроса, большому MVE_BVTD
значение приводит к увеличению предсказанного счета и маленькому MVE_BVTD
значение приводит к уменьшению в предсказанных баллах по сравнению со средним значением. Результаты сопоставимы с результатами lime
.
Создайте Частичный график зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанным счетом в обученной модели. Создайте PDPs для RE_TA
и MVE_BVTD
, который другие interpretability инструменты идентифицируют как важные предикторы. Передайте tblx_AAA
к plotPartialDependence
так, чтобы функция вычислила ожидание предсказанных баллов с помощью только выборки для 'AAA'
.
figure plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)
plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)
Незначительные метки деления в x
- ось представляет уникальные значения предиктора в tbl_AAA
. График для MVE_BVTD
показывает, что предсказанный счет является большим когда MVE_BVTD
значение мало. Значение баллов уменьшается как MVE_BVTD
повышения стоимости, пока это не достигает приблизительно 5, и затем значения баллов, остаются неизменными как MVE_BVTD
повышения стоимости. Зависимость от MVE_BVTD
в подмножестве tbl_AAA
идентифицированный plotPartialDependence
не сопоставимо с локальными вкладами MVE_BVTD
в четырех точках запроса, идентифицированных lime
и shapley
.
Рабочий процесс интерпретации модели для проблемы регрессии похож на рабочий процесс для проблемы классификации, как продемонстрировано в примере Интерпретируют Модель Классификации.
Этот пример обучает модель Gaussian process regression (GPR) и интерпретирует обученную модель с помощью interpretability функции. Используйте параметр ядра модели GPR, чтобы оценить веса предиктора. Кроме того, используйте lime
и shapley
интерпретировать предсказания для заданных точек запроса. Затем используйте plotPartialDependence
создать график, который показывает отношения между важным предиктором и предсказанными ответами.
Обучите модель GPR
Загрузите carbig
набор данных, который содержит измерения автомобилей, сделанных в 1970-х и в начале 1980-х.
load carbig
Составьте таблицу, содержащую переменные предикторы Acceleration
, Cylinders
, и так далее
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);
Обучите модель GPR переменной отклика MPG
при помощи fitrgp
функция. Задайте 'KernelFunction'
как 'ardsquaredexponential'
использовать экспоненциальное ядро в квадрате с отдельной шкалой расстояний на предиктор.
blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ... 'KernelFunction','ardsquaredexponential');
blackbox
RegressionGP
модель.
Используйте функции Interpretability модели специфичные
Можно вычислить веса предиктора (важность предиктора) от изученных шкал расстояний функции ядра, используемой в модели. Шкалы расстояний задают, как далеко независимо предиктор может быть для значений отклика, чтобы стать некоррелированым. Найдите нормированные веса предиктора путем взятия экспоненциала отрицательных изученных шкал расстояний.
sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales weights = exp(-sigmaL); % Predictor weights weights = weights/sum(weights); % Normalized predictor weights
Составьте таблицу, содержащую нормированные веса предиктора, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание на любое имя предиктора, измените TickLabelInterpreter
значение осей к 'none'
.
tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ... 'RowNames',blackbox.ExpandedPredictorNames); tbl_weight = sortrows(tbl_weight,'Predictor Weight'); b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight')); b.Parent.TickLabelInterpreter = 'none'; xlabel('Predictor Weight') ylabel('Predictor')
Веса предиктора указывают что несколько фиктивных переменных для категориальных предикторов Model_Year
и Cylinders
важны.
Задайте точку запроса
Найдите наблюдения чей MPG
значения меньше, чем 0,25 квантиля MPG
. От подмножества выберите четыре точки запроса, которые не включают отсутствующие значения.
rng('default') % For reproducibility idx_subset = find(MPG < quantile(MPG,0.25)); tbl_subset = tbl(idx_subset,:); queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight
____________ _________ ____________ __________ __________ ______
13.2 8 318 150 76 3940
14.9 8 302 130 77 4295
14 8 360 215 70 4615
13.7 8 318 145 77 4140
Используйте LIME с древовидными простыми моделями
Объясните предсказания для точек запроса с помощью lime
с деревом решений простые модели. lime
генерирует синтетический набор данных и подбирает простую модель к синтетическому набору данных.
Создайте lime
объект с помощью tbl_subset
так, чтобы lime
генерирует синтетический набор данных с помощью подмножества вместо целого набора данных. Задайте 'SimpleModelType'
как 'tree'
использовать дерево решений простая модель.
explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');
Значение по умолчанию DataLocality для lime
'global'
, который подразумевает что, по умолчанию, lime
генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime
использует различные веса наблюдения так, чтобы значения веса более фокусировались на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подбирайте простые модели для четырех точек запроса при помощи объектного функционального fit
. Задайте третий вход (количество важных предикторов, чтобы использовать в простой модели) как 6. С этой установкой программное обеспечение задает максимальное количество разделений решения (или узлы ветви) как 6 так, чтобы подходящее дерево решений использовало самое большее все предикторы.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте важность предиктора при помощи объектного функционального plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
Все простые модели идентифицируют Displacement
, Model_Year
, и Weight
как важные предикторы.
Вычислите значения Шепли
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного ответа для точки запроса от среднего ответа, из-за предиктора. Создайте shapley
объект для модели blackbox
использование tbl_subset
так, чтобы shapley
вычисляет ожидаемый вклад на основе наблюдений в tbl_subset
.
explainer_shapley = shapley(blackbox,tbl_subset);
Вычислите значения Шепли для точек запроса при помощи объектного функционального fit
.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте значения Шепли при помощи объектного функционального plot
.
tiledlayout(2,2) nexttile plot(explainer_shapley1) nexttile plot(explainer_shapley2) nexttile plot(explainer_shapley3) nexttile plot(explainer_shapley4)
Model_Year
самый важный предиктор для первых, вторых, и четвертых точек запроса и значения Шепли Model_Year
положительны для трех точек запроса. Model_Year
значение равняется 76 или 77 для этих трех точек, и значение для третьей точки запроса равняется 70. Согласно значениям Шепли для четырех точек запроса, маленькому Model_Year
значение приводит к уменьшению в предсказанном ответе и большому Model_Year
значение приводит к увеличению предсказанного ответа по сравнению со средним значением.
Создайте Частичный график зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанным ответом в обученной модели. Создайте PDP для Model_Year
, который другие interpretability инструменты идентифицируют как важный предиктор. Передайте tbl_subset
к plotPartialDependence
так, чтобы функция вычислила ожидание предсказанных ответов с помощью только выборки в tbl_subset
.
figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)
График показывает тот же тренд, идентифицированный значениями Шепли для четырех точек запроса. Предсказанный ответ (MPG
) повышения стоимости как Model_Year
повышения стоимости.
lime
| shapley
| plotPartialDependence