Эта тема представляет функции Statistics and Machine Learning Toolbox™ для интерпретации модели и показывает, как интерпретировать модель машинного обучения (классификация и регрессия).
Модель машинного обучения часто упоминается как модель «черного ящика», потому что может быть трудно понять, как модель делает предсказания. Инструменты интерпретации помогают вам преодолеть этот аспект алгоритмов машинного обучения и раскрыть, как предикторы способствуют (или не способствуют) предсказаниям. Кроме того, можно подтвердить, использует ли модель правильное доказательство для своих предсказаний, и найти смещения модели, которые не сразу очевидны.
Использование lime, shapley, и plotPartialDependence объяснить вклад отдельных предикторов в предсказания обученной классификационной или регрессионой модели.
lime - Локальные интерпретируемые модели-агностические объяснения (LIME [1]) интерпретируют предсказание для точки запроса, подгоняя простую интерпретируемую модель для точки запроса. Простая модель действует как приближение для обученной модели и объясняет предсказания модели вокруг точки запроса. Простая модель может быть либо линейной моделью, либо моделью дерева решений. Можно использовать оцененные коэффициенты линейной модели или предполагаемую предикторную важность модели дерева решений, чтобы объяснить вклад отдельных предикторов в предсказание для точки запроса. Для получения дополнительной информации см. LIME.
shapley - [2][3] значения Шепли предиктора для точки запроса объясняет отклонение предсказания (ответ для регрессии или счетов классов для классификации) для точки запроса от среднего предсказания из-за предиктора. Для точки запроса сумма значений Shapley для всех функций соответствует общему отклонению предсказания от среднего значения. Для получения дополнительной информации смотрите Значения Shapley для модели машинного обучения.
plotPartialDependence и partialDependence - График частичной зависимости (PDP [4]) показывает отношения между предиктором (или парой предикторов) и предсказанием (реакция на регрессию или счета классов для классификации) в обученной модели. Частичная зависимость от выбранного предиктора определяется усредненным предсказанием, полученным путем маргинализации эффекта других переменных. Поэтому частичная зависимость является функцией выбранного предиктора, которая показывает средний эффект выбранного предиктора над набором данных. Можно также создать набор индивидуальных условных графиков ожиданий (ICE [5]) для каждого наблюдения, показывающих эффект выбранного предиктора на одно наблюдение. Для получения дополнительной информации смотрите Подробнее о plotPartialDependence страница с описанием.
Некоторые модели машинного обучения поддерживают выбор признаков встроенного типа, где модель учит предикторную важность как часть процесса обучения модели. Можно использовать предполагаемую важность предиктора, чтобы объяснить предсказания модели. Для примера:
Обучите ансамбль (ClassificationBaggedEnsemble или RegressionBaggedEnsemble) пакетированных деревьев принятия решений (для примера, случайных лесов) и использовать predictorImportance и oobPermutedPredictorImportance функций.
Обучите линейную модель с регуляризацией лассо, которая сокращает коэффициенты наименее важных предикторов. Затем используйте оцененные коэффициенты в качестве мер для предикторной важности. Для примера используйте fitclinear или fitrlinear и задайте 'Regularization' аргумент имя-значение как 'lasso'.
Список моделей машинного обучения, поддерживающих выбор признаков встроенного типа, см. в разделе «Выбор признаков встроенного типа».
Используйте функции Statistics and Machine Learning Toolbox для трех уровней интерпретации модели: локальной, когорты и глобальной.
| Уровень | Цель | Пример использования | Функция набора инструментов Statistics and Machine Learning Toolbox |
|---|---|---|---|
| Интерпретация на местах | Объясните предсказание для одной точки запроса. |
| Использование lime и shapley для заданной точки запроса. |
| Когортная интерпретация | Объясните, как обученная модель делает предсказания для подмножества всего набора данных. | Проверьте предсказания для определенной группы выборок. |
|
| Глобальная интерпретация | Объясните, как обученная модель делает предсказания для всего набора данных. |
|
|
Этот пример обучает ансамбль мешанных деревьев принятия решений с помощью случайного алгоритма леса и интерпретирует обученную модель с помощью функций интерпретации. Используйте функции объекта (oobPermutedPredictorImportance и predictorImportance) обученной модели для поиска важных предикторов в модели. Кроме того, используйте lime и shapley для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными классификационными оценками.
Train модели ансамбля классификации
Загрузите CreditRating_Historical набор данных. Набор данных содержит идентификаторы клиентов и их финансовые коэффициенты, отраслевые метки и кредитные рейтинги.
tbl = readtable('CreditRating_Historical.dat');Отобразите первые три строки таблицы.
head(tbl,3)
ans=3×8 table
ID WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry Rating
_____ _____ _____ _______ ________ _____ ________ ______
62394 0.013 0.104 0.036 0.447 0.142 3 {'BB'}
48608 0.232 0.335 0.062 1.969 0.281 8 {'A' }
42444 0.311 0.367 0.074 1.935 0.366 1 {'A' }
Составьте таблицу переменных предиктора путем удаления столбцов, содержащих идентификаторы клиентов и оценки из tbl.
tblX = removevars(tbl,["ID","Rating"]);
Обучите ансамбль мешанных деревьев принятия решений при помощи fitcensemble функция и определение метода агрегации ансамбля как случайного леса ('Bag'). Для воспроизводимости алгоритма случайного леса задайте 'Reproducible' аргумент имя-значение как true для учащихся дерева. Кроме того, укажите имена классов, чтобы задать порядок классов в обученной модели.
rng('default') % For reproducibility t = templateTree('Reproducible',true); blackbox = fitcensemble(tblX,tbl.Rating, ... 'Method','Bag','Learners',t, ... 'CategoricalPredictors','Industry', ... 'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});
blackbox является ClassificationBaggedEnsemble модель.
Используйте специфичные для модели функции интерпретации
ClassificationBaggedEnsemble поддерживает две функции объекта, oobPermutedPredictorImportance и predictorImportance, которые находят важные предикторы в обученной модели.
Оцените значение предиктора вне мешка при помощи oobPermutedPredictorImportance функция. Функция случайным образом перестраивает данные из мешка через один предиктор за раз и оценивает увеличение ошибки вне мешка из-за этого сочетания. Чем больше увеличение, тем важнее функция.
Imp1 = oobPermutedPredictorImportance(blackbox);
Оцените важность предиктора при помощи predictorImportance функция. Функция оценивает важность предиктора путем суммирования изменений в риске узла из-за расщеплений на каждом предикторе и деления суммы на количество узлов ветви.
Imp2 = predictorImportance(blackbox);
Составьте таблицу, содержащую предикторные оценки важности, и используйте таблицу для создания горизонтальных столбчатых графиков. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.
table_Imp = table(Imp1',Imp2', ... 'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ... 'RowNames',blackbox.PredictorNames); tiledlayout(1,2) ax1 = nexttile; table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance'); barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance')) xlabel('Out-of-Bag Permuted Predictor Importance') ylabel('Predictor') ax2 = nexttile; table_Imp2 = sortrows(table_Imp,'Predictor Importance'); barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance')) xlabel('Predictor Importance') ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none';

Обе функции объекта идентифицируют MVE_BVTD и RE_TA как два наиболее важных предиктора.
Задайте точку запроса
Найдите наблюдения, чьи Rating является 'AAA' и выберите среди них четыре точки запроса.
tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:); queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
WC_TA RE_TA EBIT_TA MVE_BVTD S_TA Industry
_____ _____ _______ ________ _____ ________
0.331 0.531 0.077 7.116 0.522 12
0.26 0.515 0.065 3.394 0.515 1
0.121 0.413 0.057 3.647 0.466 12
0.617 0.766 0.126 4.442 0.483 9
Использование LIME с линейными простыми моделями
Объясните предсказания для точек запроса используя lime с линейными простыми моделями. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.
Создайте lime объект, использующий tblX_AAA так что lime генерирует синтетический набор данных, используя только наблюдения, чьи Rating является 'AAA', а не весь набор данных.
explainer_lime = lime(blackbox,tblX_AAA);
Значение по умолчанию 'DataLocality' для lime является 'global', что подразумевает, что по умолчанию lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Укажите третий вход (количество важных предикторов для использования в простой модели) как 6, чтобы использовать все шесть предикторов.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте график коэффициентов простых моделей с помощью функции объекта plot.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют EBIT_TA, RE_TA, и MVE_BVTD как три наиболее важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению предсказанных счетов в простых моделях.
Для категориального предиктора, plot функция отображает только самую важную фиктивную переменную категориального предиктора. Поэтому на каждой гистограмме отображается другая переменная манекена.
Вычисление значений Shapley
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного счета для точки запроса от среднего счета из-за предиктора. Создайте shapley объект, использующий tblX_AAA так что shapley вычисляет ожидаемый вклад на основе выборок для 'AAA'.
explainer_shapley = shapley(blackbox,tblX_AAA);
Вычислите значения Shapley для точек запроса с помощью функции объекта fit.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте график значений Shapley с помощью функции объекта plot.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_shapley1) ax2 = nexttile; plot(explainer_shapley2) ax3 = nexttile; plot(explainer_shapley3) ax4 = nexttile; plot(explainer_shapley4) ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';

MVE_BVTD и RE_TA являются двумя из трех наиболее важных предикторов для всех четырех точек запроса.
Значения Шепли MVE_BVTD положительны для первой и четвертой точек запроса и отрицательны для второй и третьей точек запроса. The MVE_BVTD значения составляют около 7 и 4 для первой и четвертой точек запроса, соответственно, и значение для второй и третьей точек запроса составляет около 3,5. Согласно значениям Шепли для четырех точек запроса, большое MVE_BVTD значение приводит к увеличению предсказанного счета, и небольшому MVE_BVTD значение приводит к снижению прогнозируемых счетов по сравнению со средним значением. Результаты соответствуют результатам lime.
Создайте график частичной зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанным счетом в обученной модели. Создайте PDP для RE_TA и MVE_BVTD, которые другие инструменты интерпретации идентифицируют как важные предикторы. Передайте tblx_AAA на plotPartialDependence так, что функция вычисляет ожидание предсказанных счетов, используя только выборки для 'AAA'.
figure plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)

plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)

Минорные такты в x-ось представляет уникальные значения предиктора в tbl_AAA. График для MVE_BVTD показывает, что предсказанный счет велик, когда MVE_BVTD значение невелико. Значение баллов уменьшается как MVE_BVTD значение увеличений пока оно не достигает около 5, и затем значение баллов остается неизменным как MVE_BVTD значение увеличивается. Зависимость от MVE_BVTD в подмножестве tbl_AAA идентифицируется по plotPartialDependence не соответствует местным вкладам MVE_BVTD в четырех точках запроса, обозначенных lime и shapley.
Рабочий процесс интерпретации модели для регрессионной задачи аналогичен рабочему процессу для классификационной задачи, как показано в примере Интерпретационная классификационная модель.
Этот пример обучает модель регрессии Гауссова процесса (GPR) и интерпретирует обученную модель, используя функции интерпретации. Используйте параметр ядра модели GPR, чтобы оценить веса предикторов. Кроме того, используйте lime и shapley для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными откликами.
Обучите модель GPR
Загрузите carbig набор данных, содержащий измерения автомобилей 1970-х и начала 1980-х годов.
load carbigСоздайте таблицу, содержащую переменные предиктора Acceleration, Cylinders, и так далее
tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);
Обучите модель GPR переменной отклика MPG при помощи fitrgp функция. Задайте 'KernelFunction' как 'ardsquaredexponential' использовать квадратное экспоненциальное ядро с отдельной шкалой длины на предиктор.
blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ... 'KernelFunction','ardsquaredexponential');
blackbox является RegressionGP модель.
Используйте специфичные для модели функции интерпретации
Можно вычислить веса предикторов (важность предиктора) из выученных шкал длины функции ядра, используемой в модели. Шкалы длины определяют, насколько далеко может быть предиктор, чтобы значения отклика стали некоррелированными. Найдите нормированные веса предиктора путем взятия экспоненциала отрицательной выученной длины.
sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales weights = exp(-sigmaL); % Predictor weights weights = weights/sum(weights); % Normalized predictor weights
Создайте таблицу, содержащую нормированные предикторные веса, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.
tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ... 'RowNames',blackbox.ExpandedPredictorNames); tbl_weight = sortrows(tbl_weight,'Predictor Weight'); b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight')); b.Parent.TickLabelInterpreter = 'none'; xlabel('Predictor Weight') ylabel('Predictor')

Веса предиктора указывают, что несколько фиктивных переменных для категориальных предикторов Model_Year и Cylinders важны.
Задайте точку запроса
Найдите наблюдения, чьи MPG значения меньше, чем 0,25 квантиля MPG. Из подмножества выберите четыре точки запроса, которые не содержат отсутствующих значений.
rng('default') % For reproducibility idx_subset = find(MPG < quantile(MPG,0.25)); tbl_subset = tbl(idx_subset,:); queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
Acceleration Cylinders Displacement Horsepower Model_Year Weight
____________ _________ ____________ __________ __________ ______
13.2 8 318 150 76 3940
14.9 8 302 130 77 4295
14 8 360 215 70 4615
13.7 8 318 145 77 4140
Использование LIME с древовидными простыми моделями
Объясните предсказания для точек запроса используя lime с простыми моделями дерева решений. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.
Создайте lime объект, использующий tbl_subset так что lime генерирует синтетический набор данных, используя подмножество вместо всего набора данных. Задайте 'SimpleModelType' как 'tree' использовать простую модель дерева решений.
explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');
Значение по умолчанию 'DataLocality' для lime является 'global', что подразумевает, что по умолчанию lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.
Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Задайте третий вход (количество важных предикторов для использования в простой модели) как 6. С помощью этой настройки программное обеспечение задает максимальное количество разделений решений (или узлов ветви) как 6, так что подобранное дерево решений использует самое большее все предикторы.
explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6); explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6); explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6); explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);
Постройте график важности предиктора с помощью функции объекта plot.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_lime1); ax2 = nexttile; plot(explainer_lime2); ax3 = nexttile; plot(explainer_lime3); ax4 = nexttile; plot(explainer_lime4); ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют Displacement, Model_Year, и Weight как важные предикторы.
Вычисление значений Shapley
Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного отклика для точки запроса от среднего отклика из-за предиктора. Создайте shapley объект для модели blackbox использование tbl_subset так что shapley вычисляет ожидаемый вклад на основе наблюдений в tbl_subset.
explainer_shapley = shapley(blackbox,tbl_subset);
Вычислите значения Shapley для точек запроса с помощью функции объекта fit.
explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:)); explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:)); explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:)); explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));
Постройте график значений Shapley с помощью функции объекта plot.
tiledlayout(2,2) ax1 = nexttile; plot(explainer_shapley1) ax2 = nexttile; plot(explainer_shapley2) ax3 = nexttile; plot(explainer_shapley3) ax4 = nexttile; plot(explainer_shapley4) ax1.TickLabelInterpreter = 'none'; ax2.TickLabelInterpreter = 'none'; ax3.TickLabelInterpreter = 'none'; ax4.TickLabelInterpreter = 'none';

Model_Year является наиболее важным предиктором для первой, второй и четвертой точек запроса и значений Шепли Model_Year положительны для трех точек запроса. The Model_Year значение является 76 или 77 для этих трех точек, и значение для третьей точки запроса равняется 70. Согласно значениям Шепли для четырех точек запроса, небольшая Model_Year значение приводит к уменьшению предсказанного отклика и большому Model_Year значение приводит к увеличению предсказанного отклика по сравнению со средним значением.
Создайте график частичной зависимости (PDP)
График PDP показывает усредненные отношения между предиктором и предсказанной реакцией в обученной модели. Создайте PDP для Model_Year, который другие инструменты интерпретации идентифицируют как важный предиктор. Передайте tbl_subset на plotPartialDependence так, что функция вычисляет ожидание предсказанных ответов, используя только выборки в tbl_subset.
figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)
График показывает тот же тренд, идентифицированную значениями Шепли для четырех точек запроса. Предсказанный ответ (MPG) значение увеличивается как Model_Year значение увеличивается.
lime | plotPartialDependence | shapley