Эта тема представляет функции Statistics and Machine Learning Toolbox™ для интерпретации модели и показывает, как интерпретировать модель машинного обучения (классификация и регрессия).
Модель машинного обучения часто упоминается как модель «черного ящика», потому что может быть трудно понять, как модель делает предсказания. Инструменты интерпретации помогают вам преодолеть этот аспект алгоритмов машинного обучения и раскрыть, как предикторы способствуют (или не способствуют) предсказаниям. Кроме того, можно подтвердить, использует ли модель правильное доказательство для своих предсказаний, и найти смещения модели, которые не сразу очевидны.
Использование lime
, shapley
, и plotPartialDependence
объяснить вклад отдельных предикторов в предсказания обученной классификационной или регрессионой модели.
lime
- Локальные интерпретируемые модели-агностические объяснения (LIME [1]) интерпретируют предсказание для точки запроса, подгоняя простую интерпретируемую модель для точки запроса. Простая модель действует как приближение для обученной модели и объясняет предсказания модели вокруг точки запроса. Простая модель может быть либо линейной моделью, либо моделью дерева решений. Можно использовать оцененные коэффициенты линейной модели или предполагаемую предикторную важность модели дерева решений, чтобы объяснить вклад отдельных предикторов в предсказание для точки запроса. Для получения дополнительной информации см. LIME.
shapley
- [2][3] значения Шепли предиктора для точки запроса объясняет отклонение предсказания (ответ для регрессии или счетов классов для классификации) для точки запроса от среднего предсказания из-за предиктора. Для точки запроса сумма значений Shapley для всех функций соответствует общему отклонению предсказания от среднего значения. Для получения дополнительной информации смотрите Значения Shapley для модели машинного обучения.
plotPartialDependence
и partialDependence
- График частичной зависимости (PDP [4]) показывает отношения между предиктором (или парой предикторов) и предсказанием (реакция на регрессию или счета классов для классификации) в обученной модели. Частичная зависимость от выбранного предиктора определяется усредненным предсказанием, полученным путем маргинализации эффекта других переменных. Поэтому частичная зависимость является функцией выбранного предиктора, которая показывает средний эффект выбранного предиктора над набором данных. Можно также создать набор индивидуальных условных графиков ожиданий (ICE [5]) для каждого наблюдения, показывающих эффект выбранного предиктора на одно наблюдение. Для получения дополнительной информации смотрите Подробнее о plotPartialDependence
страница с описанием.
Некоторые модели машинного обучения поддерживают выбор признаков встроенного типа, где модель учит предикторную важность как часть процесса обучения модели. Можно использовать предполагаемую важность предиктора, чтобы объяснить предсказания модели. Для примера:
Обучите ансамбль (ClassificationBaggedEnsemble
или RegressionBaggedEnsemble
) пакетированных деревьев принятия решений (для примера, случайных лесов) и использовать predictorImportance
и oobPermutedPredictorImportance
функций.
Обучите линейную модель с регуляризацией лассо, которая сокращает коэффициенты наименее важных предикторов. Затем используйте оцененные коэффициенты в качестве мер для предикторной важности. Для примера используйте fitclinear
или fitrlinear
и задайте 'Regularization'
аргумент имя-значение как 'lasso'
.
Список моделей машинного обучения, поддерживающих выбор признаков встроенного типа, см. в разделе «Выбор признаков встроенного типа».
Используйте функции Statistics and Machine Learning Toolbox для трех уровней интерпретации модели: локальной, когорты и глобальной.
Уровень | Цель | Пример использования | Функция набора инструментов Statistics and Machine Learning Toolbox |
---|---|---|---|
Интерпретация на местах | Объясните предсказание для одной точки запроса. |
| Использование lime и shapley для заданной точки запроса. |
Когортная интерпретация | Объясните, как обученная модель делает предсказания для подмножества всего набора данных. | Проверьте предсказания для определенной группы выборок. |
|
Глобальная интерпретация | Объясните, как обученная модель делает предсказания для всего набора данных. |
|
|
Этот пример обучает ансамбль мешанных деревьев принятия решений с помощью случайного алгоритма леса и интерпретирует обученную модель с помощью функций интерпретации. Используйте функции объекта (oobPermutedPredictorImportance
и predictorImportance
) обученной модели для поиска важных предикторов в модели. Кроме того, используйте lime
и shapley
для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence
создать график, который показывает отношения между важным предиктором и предсказанными классификационными оценками.
Train модели ансамбля классификации
Загрузите CreditRating_Historical
набор данных. Набор данных содержит идентификаторы клиентов и их финансовые коэффициенты, отраслевые метки и кредитные рейтинги.
tbl = readtable('CreditRating_Historical.dat');
Отобразите первые три строки таблицы.
head(tbl,3)
ans=3×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ _____ _____ _______ ________ _____ ________ ______
62394 0.013 0.104 0.036 0.447 0.142 3 {'BB'}
48608 0.232 0.335 0.062 1.969 0.281 8 {'A' }
42444 0.311 0.367 0.074 1.935 0.366 1 {'A' }
Составьте таблицу переменных предиктора путем удаления столбцов, содержащих идентификаторы клиентов и оценки из tbl
.
tblX = removevars(tbl,["ID","Rating"]);
Обучите ансамбль мешанных деревьев принятия решений при помощи fitcensemble
функция и определение метода агрегации ансамбля как случайного леса ('Bag'
). Для воспроизводимости алгоритма случайного леса задайте 'Reproducible'
аргумент имя-значение как true
для учащихся дерева. Кроме того, укажите имена классов, чтобы задать порядок классов в обученной модели.
rng('default') % For reproducibility t = templateTree('Reproducible',true); blackbox = fitcensemble(tblX,tbl.Rating, ... 'Method','Bag','Learners',t, ... 'CategoricalPredictors','Industry', ... 'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});
blackbox
является ClassificationBaggedEnsemble
модель.
Используйте специфичные для модели функции интерпретации
ClassificationBaggedEnsemble
поддерживает две функции объекта, oobPermutedPredictorImportance
и predictorImportance
, которые находят важные предикторы в обученной модели.
Оцените значение предиктора вне мешка при помощи oobPermutedPredictorImportance
функция. Функция случайным образом перестраивает данные из мешка через один предиктор за раз и оценивает увеличение ошибки вне мешка из-за этого сочетания. Чем больше увеличение, тем важнее функция.
Imp1 = oobPermutedPredictorImportance(blackbox);
Оцените важность предиктора при помощи predictorImportance
функция. Функция оценивает важность предиктора путем суммирования изменений в риске узла из-за расщеплений на каждом предикторе и деления суммы на количество узлов ветви.
Imp2 = predictorImportance(blackbox);
Составьте таблицу, содержащую предикторные оценки важности, и используйте таблицу для создания горизонтальных столбчатых графиков. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter
значение осей для 'none'
.
table_Imp = table(Imp1',Imp2', ... 'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ... 'RowNames',blackbox.PredictorNames); tiledlayout(1,2) ax1 = nexttile; table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance'); barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance')) xlabel('Out-of-Bag Permuted Predictor Importance') ylabel('Predictor') ax2 = nexttile; table_Imp2 = sortrows(table_Imp,'Predictor Importance'); barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance')) xlabel('Predictor Importance') ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none';
Обе функции объекта идентифицируют MVE_BVTD
и RE_TA
как два наиболее важных предиктора.
Задайте точку запроса
Найдите наблюдения, чьи Rating
является 'AAA'
и выберите среди них четыре точки запроса.
tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:); queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry
_____ _____ _______ ________ _____ ________
0.331 0.531 0.077 7.116 0.522 12
0.26 0.515 0.065 3.394 0.515 1
0.121 0.413 0.057 3.647 0.466 12
0.617 0.766 0.126 4.442 0.483 9
Использование LIME с линейными простыми моделями
Объясните предсказания для точек запроса используя lime
с линейными простыми моделями. lime
генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.
Создайте lime
объект, использующий tblX_AAA
так что lime
генерирует синтетический набор данных, используя только наблюдения, чьи Rating
является 'AAA'
, а не весь набор данных.
explainer_lime = lime(blackbox,tblX_AAA);
Значение по умолчанию 'DataLocality' для lime
является 'global'
, что подразумевает, что по умолчанию lime
генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime
использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit
. Укажите третий вход (количество важных предикторов для использования в простой модели) как 6, чтобы использовать все шесть предикторов.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте график коэффициентов простых моделей с помощью функции объекта plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
Все простые модели идентифицируют EBIT_TA
, RE_TA
, и MVE_BVTD
как три наиболее важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению предсказанных счетов в простых моделях.
Для категориального предиктора, plot
функция отображает только самую важную фиктивную переменную категориального предиктора. Поэтому на каждой гистограмме отображается другая переменная манекена.
Вычисление значений Shapley
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного счета для точки запроса от среднего счета из-за предиктора. Создайте shapley
объект, использующий tblX_AAA
так что shapley
вычисляет ожидаемый вклад на основе выборок для 'AAA'
.
explainer_shapley = shapley(blackbox,tblX_AAA);
Вычислите значения Shapley для точек запроса с помощью функции объекта fit
.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте график значений Shapley с помощью функции объекта plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_shapley1) ax2 = nexttile; plot(explainer_shapley2) ax3 = nexttile; plot(explainer_shapley3) ax4 = nexttile; plot(explainer_shapley4) ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
MVE_BVTD
и RE_TA
являются двумя из трех наиболее важных предикторов для всех четырех точек запроса.
Значения Шепли MVE_BVTD
положительны для первой и четвертой точек запроса и отрицательны для второй и третьей точек запроса. The MVE_BVTD
значения составляют около 7 и 4 для первой и четвертой точек запроса, соответственно, и значение для второй и третьей точек запроса составляет около 3,5. Согласно значениям Шепли для четырех точек запроса, большое MVE_BVTD
значение приводит к увеличению предсказанного счета, и небольшому MVE_BVTD
значение приводит к снижению прогнозируемых счетов по сравнению со средним значением. Результаты соответствуют результатам lime
.
Создайте график частичной зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанным счетом в обученной модели. Создайте PDP для RE_TA
и MVE_BVTD
, которые другие инструменты интерпретации идентифицируют как важные предикторы. Передайте tblx_AAA
на plotPartialDependence
так, что функция вычисляет ожидание предсказанных счетов, используя только выборки для 'AAA'
.
figure plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)
plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)
Минорные такты в x
-ось представляет уникальные значения предиктора в tbl_AAA
. График для MVE_BVTD
показывает, что предсказанный счет велик, когда MVE_BVTD
значение невелико. Значение баллов уменьшается как MVE_BVTD
значение увеличений пока оно не достигает около 5, и затем значение баллов остается неизменным как MVE_BVTD
значение увеличивается. Зависимость от MVE_BVTD
в подмножестве tbl_AAA
идентифицируется по plotPartialDependence
не соответствует местным вкладам MVE_BVTD
в четырех точках запроса, обозначенных lime
и shapley
.
Рабочий процесс интерпретации модели для регрессионной задачи аналогичен рабочему процессу для классификационной задачи, как показано в примере Интерпретационная классификационная модель.
Этот пример обучает модель регрессии Гауссова процесса (GPR) и интерпретирует обученную модель, используя функции интерпретации. Используйте параметр ядра модели GPR, чтобы оценить веса предикторов. Кроме того, используйте lime
и shapley
для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence
создать график, который показывает отношения между важным предиктором и предсказанными откликами.
Обучите модель GPR
Загрузите carbig
набор данных, содержащий измерения автомобилей 1970-х и начала 1980-х годов.
load carbig
Создайте таблицу, содержащую переменные предиктора Acceleration
, Cylinders
, и так далее
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);
Обучите модель GPR переменной отклика MPG
при помощи fitrgp
функция. Задайте 'KernelFunction'
как 'ardsquaredexponential'
использовать квадратное экспоненциальное ядро с отдельной шкалой длины на предиктор.
blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ... 'KernelFunction','ardsquaredexponential');
blackbox
является RegressionGP
модель.
Используйте специфичные для модели функции интерпретации
Можно вычислить веса предикторов (важность предиктора) из выученных шкал длины функции ядра, используемой в модели. Шкалы длины определяют, насколько далеко может быть предиктор, чтобы значения отклика стали некоррелированными. Найдите нормированные веса предиктора путем взятия экспоненциала отрицательной выученной длины.
sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales weights = exp(-sigmaL); % Predictor weights weights = weights/sum(weights); % Normalized predictor weights
Создайте таблицу, содержащую нормированные предикторные веса, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter
значение осей для 'none'
.
tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ... 'RowNames',blackbox.ExpandedPredictorNames); tbl_weight = sortrows(tbl_weight,'Predictor Weight'); b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight')); b.Parent.TickLabelInterpreter = 'none'; xlabel('Predictor Weight') ylabel('Predictor')
Веса предиктора указывают, что несколько фиктивных переменных для категориальных предикторов Model_Year
и Cylinders
важны.
Задайте точку запроса
Найдите наблюдения, чьи MPG
значения меньше, чем 0,25 квантиля MPG
. Из подмножества выберите четыре точки запроса, которые не содержат отсутствующих значений.
rng('default') % For reproducibility idx_subset = find(MPG < quantile(MPG,0.25)); tbl_subset = tbl(idx_subset,:); queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight
____________ _________ ____________ __________ __________ ______
13.2 8 318 150 76 3940
14.9 8 302 130 77 4295
14 8 360 215 70 4615
13.7 8 318 145 77 4140
Использование LIME с древовидными простыми моделями
Объясните предсказания для точек запроса используя lime
с простыми моделями дерева решений. lime
генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.
Создайте lime
объект, использующий tbl_subset
так что lime
генерирует синтетический набор данных, используя подмножество вместо всего набора данных. Задайте 'SimpleModelType'
как 'tree'
использовать простую модель дерева решений.
explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');
Значение по умолчанию 'DataLocality' для lime
является 'global'
, что подразумевает, что по умолчанию lime
генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime
использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit
. Задайте третий вход (количество важных предикторов для использования в простой модели) как 6. С помощью этой настройки программное обеспечение задает максимальное количество разделений решений (или узлов ветви) как 6, так что подобранное дерево решений использует самое большее все предикторы.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте график важности предиктора с помощью функции объекта plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
Все простые модели идентифицируют Displacement
, Model_Year
, и Weight
как важные предикторы.
Вычисление значений Shapley
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного отклика для точки запроса от среднего отклика из-за предиктора. Создайте shapley
объект для модели blackbox
использование tbl_subset
так что shapley
вычисляет ожидаемый вклад на основе наблюдений в tbl_subset
.
explainer_shapley = shapley(blackbox,tbl_subset);
Вычислите значения Shapley для точек запроса с помощью функции объекта fit
.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте график значений Shapley с помощью функции объекта plot
.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_shapley1) ax2 = nexttile; plot(explainer_shapley2) ax3 = nexttile; plot(explainer_shapley3) ax4 = nexttile; plot(explainer_shapley4) ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';
Model_Year
является наиболее важным предиктором для первой, второй и четвертой точек запроса и значений Шепли Model_Year
положительны для трех точек запроса. The Model_Year
значение является 76 или 77 для этих трех точек, и значение для третьей точки запроса равняется 70. Согласно значениям Шепли для четырех точек запроса, небольшая Model_Year
значение приводит к уменьшению предсказанного отклика и большому Model_Year
значение приводит к увеличению предсказанного отклика по сравнению со средним значением.
Создайте график частичной зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанной реакцией в обученной модели. Создайте PDP для Model_Year
, который другие инструменты интерпретации идентифицируют как важный предиктор. Передайте tbl_subset
на plotPartialDependence
так, что функция вычисляет ожидание предсказанных ответов, используя только выборки в tbl_subset
.
figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)
График показывает тот же тренд, идентифицированную значениями Шепли для четырех точек запроса. Предсказанный ответ (MPG
) значение увеличивается как Model_Year
значение увеличивается.
lime
| plotPartialDependence
| shapley