Интерпретируйте модели машинного обучения

Эта тема вводит функции Statistics and Machine Learning Toolbox™ интерпретации модели и показывает, как интерпретировать модель машинного обучения (классификация и регрессия).

Модель машинного обучения часто упоминается как модель "черного квадрата", потому что она может затруднить, чтобы изучить, как модель делает предсказания. Инструменты Interpretability помогают вам преодолеть этот аспект алгоритмов машинного обучения и показывают, как предикторы способствуют (или не способствуйте) к предсказаниям. Кроме того, можно подтвердить, использует ли модель правильное доказательство для своих предсказаний, и найдите смещения модели, которые не сразу очевидны.

Функции интерпретации модели

Используйте lime, shapley, и plotPartialDependence объяснить вклад отдельных предикторов к предсказаниям обученной классификации или модели регрессии.

  • lime — Локальные поддающиеся толкованию объяснения модели агностические (LIME [1]) интерпретируют предсказание для точки запроса, подбирая простую поддающуюся толкованию модель для точки запроса. Простая модель действует как приближение для обученной модели и объясняет предсказания модели вокруг точки запроса. Простая модель может быть или линейной моделью или моделью дерева принятия решения. Можно использовать предполагаемые коэффициенты линейной модели или предполагаемую важность предиктора модели дерева принятия решения, чтобы объяснить вклад отдельных предикторов к предсказанию для точки запроса. Для получения дополнительной информации смотрите LIME.

  • shapley — Значение Шепли [2][3] предиктора для точки запроса объясняет отклонение предсказания (ответ для регрессии или музыка класса к классификации) для точки запроса из среднего предсказания, из-за предиктора. Для точки запроса сумма значений Шепли для всех функций соответствует общему отклонению предсказания от среднего значения. Для получения дополнительной информации смотрите Значения Шепли для Модели Машинного обучения.

  • plotPartialDependence и partialDependence — Частичный график зависимости (PDP [4]) показывает отношения между предиктором (или пара предикторов) и предсказанием (ответ для регрессии или музыка класса к классификации) в обученной модели. Частичная зависимость от выбранного предиктора задана усредненным предсказанием, полученным путем маргинализации эффекта других переменных. Поэтому частичная зависимость является функцией выбранного предиктора, который показывает средний эффект выбранного предиктора по набору данных. Можно также создать набор отдельного условного ожидания (ICE [5]) графики для каждого наблюдения, показав эффект выбранного предиктора на одном наблюдении. Для получения дополнительной информации смотрите Больше О на plotPartialDependence страница с описанием.

Встроенный выбор признаков типа поддержки моделей некоторого машинного обучения, где модель изучает важность предиктора как часть процесса обучения модели. Можно использовать предполагаемую важность предиктора, чтобы объяснить предсказания модели. Например:

  • Обучите ансамбль (ClassificationBaggedEnsemble или RegressionBaggedEnsemble) из сложенных в мешок деревьев решений (например, случайный лес) и использование predictorImportance и oobPermutedPredictorImportance функции.

  • Обучите линейную модель с регуляризацией лассо, которая уменьшает коэффициенты наименее важных предикторов. Затем используйте предполагаемые коэффициенты в качестве мер для важности предиктора. Например, использовать fitclinear или fitrlinear и задайте 'Regularization' аргумент значения имени как 'lasso'.

Для списка моделей машинного обучения, которые поддерживают встроенный выбор признаков типа, смотрите Встроенный Выбор признаков Типа.

Используйте функции Statistics and Machine Learning Toolbox для трех уровней интерпретации модели: локальный, когорта и глобальная переменная.

УровеньЦельВариант использованияФункция Statistics and Machine Learning Toolbox
Локальная интерпретацияОбъясните предсказание для точки единого запроса.
  • Идентифицируйте важные предикторы для отдельного предсказания.

  • Исследуйте парадоксальное предсказание.

Используйте lime и shapley для заданной точки запроса.
Интерпретация когортыОбъясните, как обученная модель делает предсказания для подмножества целого набора данных.Подтвердите предсказания для конкретной группы выборок.
  • Используйте lime и shapley для нескольких точек запроса. После создания lime или shapley объект, можно вызвать объектную функцию fit многократно интерпретировать предсказания для других точек запроса.

  • Передайте подмножество данных, когда вы вызовете lime, shapley, и plotPartialDependence. Функции интерпретируют обученную модель с помощью заданного подмножества вместо целого обучающего набора данных.

Глобальная интерпретацияОбъясните, как обученная модель делает предсказания для целого набора данных.
  • Продемонстрируйте, как работает обученная модель.

  • Сравните различные модели.

  • Использование plotPartialDependence создать PDPs и ICE строит для предикторов интереса.

  • Найдите важные предикторы из обученной модели, которая поддерживает Встроенный Выбор признаков Типа.

Интерпретируйте модель классификации

Этот пример обучает ансамбль сложенных в мешок деревьев решений с помощью случайного лесного алгоритма и интерпретирует обученную модель с помощью interpretability функции. Используйте объектные функции (oobPermutedPredictorImportance и predictorImportance) из обученной модели, чтобы найти важные предикторы в модели. Кроме того, используйте lime и shapley интерпретировать предсказания для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными классификационными оценками.

Обучите модель ансамбля классификации

Загрузите CreditRating_Historical набор данных. Набор данных содержит идентификаторы клиентов и их финансовые отношения, промышленные метки и кредитные рейтинги.

tbl = readtable('CreditRating_Historical.dat');

Отобразите первые три строки таблицы.

head(tbl,3)
ans=3×8 table
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______

    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }

Составьте таблицу переменных предикторов путем удаления столбцов, содержащих идентификаторы клиентов и оценки от tbl.

tblX = removevars(tbl,["ID","Rating"]);

Обучите ансамбль сложенных в мешок деревьев решений при помощи fitcensemble функция и определение метода агрегации ансамбля как случайный лес ('Bag'). Для воспроизводимости случайного лесного алгоритма задайте 'Reproducible' аргумент значения имени как true для древовидных учеников. Кроме того, задайте имена классов, чтобы установить порядок классов в обученной модели.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
blackbox = fitcensemble(tblX,tbl.Rating, ...
    'Method','Bag','Learners',t, ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});

blackbox ClassificationBaggedEnsemble модель.

Используйте функции Interpretability модели специфичные

ClassificationBaggedEnsemble поддерживает две объектных функции, oobPermutedPredictorImportance и predictorImportance, которые находят важные предикторы в обученной модели.

Оцените важность предиктора из сумки при помощи oobPermutedPredictorImportance функция. Функция случайным образом переставляет данные из сумки через один предиктор за один раз и оценивает увеличение ошибки из сумки из-за этого сочетания. Чем больше увеличение, тем более важный функция.

Imp1 = oobPermutedPredictorImportance(blackbox);

Оцените важность предиктора при помощи predictorImportance функция. Функциональная оценочная важность предиктора путем подведения итогов изменений в узле рискует из-за разделений на каждом предикторе и делении суммы количеством узлов ветви.

Imp2 = predictorImportance(blackbox);

Составьте таблицу, содержащую оценки важности предиктора, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание на любое имя предиктора, измените TickLabelInterpreter значение осей к 'none'.

table_Imp = table(Imp1',Imp2', ...
    'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ...
    'RowNames',blackbox.PredictorNames);
tiledlayout(1,2)
ax1 = nexttile;
table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance');
barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance'))
xlabel('Out-of-Bag Permuted Predictor Importance')
ylabel('Predictor')
ax2 = nexttile;
table_Imp2 = sortrows(table_Imp,'Predictor Importance');
barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance'))
xlabel('Predictor Importance')
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';

Обе объектных функции идентифицируют MVE_BVTD и RE_TA как два самых важных предиктора.

Задайте точку запроса

Найдите наблюдения чей Rating 'AAA' и выберите четыре точки запроса среди них.

tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:);
queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
    WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry
    _____    _____    _______    ________    _____    ________

    0.331    0.531     0.077      7.116      0.522       12   
     0.26    0.515     0.065      3.394      0.515        1   
    0.121    0.413     0.057      3.647      0.466       12   
    0.617    0.766     0.126      4.442      0.483        9   

Используйте LIME с линейными простыми моделями

Объясните предсказания для точек запроса с помощью lime с линейными простыми моделями. lime генерирует синтетический набор данных и подбирает простую модель к синтетическому набору данных.

Создайте lime объект с помощью tblX_AAA так, чтобы lime генерирует синтетический набор данных с помощью только наблюдения чей Rating 'AAA', не целый набор данных.

explainer_lime = lime(blackbox,tblX_AAA);

Значение по умолчанию 'DataLocality' для lime 'global', который подразумевает что, по умолчанию, lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдения так, чтобы значения веса более фокусировались на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.

Подбирайте простые модели для четырех точек запроса при помощи объектного функционального fit. Задайте третий вход (количество важных предикторов, чтобы использовать в простой модели) как 6, чтобы использовать все шесть предикторов.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте коэффициенты простых моделей при помощи объектного функционального plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют EBIT_TA, RE_TA, и MVE_BVTD как три самых важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению предсказанных баллов в простых моделях.

Для категориального предиктора, plot функционируйте отображает только самую важную фиктивную переменную категориального предиктора. Поэтому каждый столбчатый график отображает различную фиктивную переменную.

Вычислите значения Шепли

Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного счета к точке запроса от средней оценки, из-за предиктора. Создайте shapley объект с помощью tblX_AAA так, чтобы shapley вычисляет ожидаемый вклад на основе выборок для 'AAA'.

explainer_shapley = shapley(blackbox,tblX_AAA);

Вычислите значения Шепли для точек запроса при помощи объектного функционального fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте значения Шепли при помощи объектного функционального plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

MVE_BVTD и RE_TA два из трех самых важных предикторов для всех четырех точек запроса.

Значения Шепли MVE_BVTD положительны для первых и четвертых точек запроса и отрицательны для вторых и третьих точек запроса. MVE_BVTD значения - приблизительно 7 и 4 для первых и четвертых точек запроса, соответственно, и значение и для вторых и для третьих точек запроса является приблизительно 3,5. Согласно значениям Шепли для четырех точек запроса, большому MVE_BVTD значение приводит к увеличению предсказанного счета и маленькому MVE_BVTD значение приводит к уменьшению в предсказанных баллах по сравнению со средним значением. Результаты сопоставимы с результатами lime.

Создайте Частичный график зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и предсказанным счетом в обученной модели. Создайте PDPs для RE_TA и MVE_BVTD, который другие interpretability инструменты идентифицируют как важные предикторы. Передайте tblx_AAA к plotPartialDependence так, чтобы функция вычислила ожидание предсказанных баллов с помощью только выборки для 'AAA'.

figure
plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)

plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)

Незначительные метки деления в x- ось представляет уникальные значения предиктора в tbl_AAA. График для MVE_BVTD показывает, что предсказанный счет является большим когда MVE_BVTD значение мало. Значение баллов уменьшается как MVE_BVTD повышения стоимости, пока это не достигает приблизительно 5, и затем значения баллов, остаются неизменными как MVE_BVTD повышения стоимости. Зависимость от MVE_BVTD в подмножестве tbl_AAA идентифицированный plotPartialDependence не сопоставимо с локальными вкладами MVE_BVTD в четырех точках запроса, идентифицированных lime и shapley.

Интерпретируйте модель регрессии

Рабочий процесс интерпретации модели для проблемы регрессии похож на рабочий процесс для проблемы классификации, как продемонстрировано в примере Интерпретируют Модель Классификации.

Этот пример обучает модель Gaussian process regression (GPR) и интерпретирует обученную модель с помощью interpretability функции. Используйте параметр ядра модели GPR, чтобы оценить веса предиктора. Кроме того, используйте lime и shapley интерпретировать предсказания для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными ответами.

Обучите модель GPR

Загрузите carbig набор данных, который содержит измерения автомобилей, сделанных в 1970-х и в начале 1980-х.

load carbig

Составьте таблицу, содержащую переменные предикторы Acceleration, Cylinders, и так далее

tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);

Обучите модель GPR переменной отклика MPG при помощи fitrgp функция. Задайте 'KernelFunction' как 'ardsquaredexponential' использовать экспоненциальное ядро в квадрате с отдельной шкалой расстояний на предиктор.

blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ...
    'KernelFunction','ardsquaredexponential');

blackbox RegressionGP модель.

Используйте функции Interpretability модели специфичные

Можно вычислить веса предиктора (важность предиктора) от изученных шкал расстояний функции ядра, используемой в модели. Шкалы расстояний задают, как далеко независимо предиктор может быть для значений отклика, чтобы стать некоррелированым. Найдите нормированные веса предиктора путем взятия экспоненциала отрицательных изученных шкал расстояний.

sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales
weights = exp(-sigmaL); % Predictor weights
weights = weights/sum(weights); % Normalized predictor weights

Составьте таблицу, содержащую нормированные веса предиктора, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание на любое имя предиктора, измените TickLabelInterpreter значение осей к 'none'.

tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ...
    'RowNames',blackbox.ExpandedPredictorNames);
tbl_weight = sortrows(tbl_weight,'Predictor Weight');
b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight'));
b.Parent.TickLabelInterpreter = 'none'; 
xlabel('Predictor Weight')
ylabel('Predictor')

Веса предиктора указывают что несколько фиктивных переменных для категориальных предикторов Model_Year и Cylinders важны.

Задайте точку запроса

Найдите наблюдения чей MPG значения меньше, чем 0,25 квантиля MPG. От подмножества выберите четыре точки запроса, которые не включают отсутствующие значения.

rng('default') % For reproducibility
idx_subset = find(MPG < quantile(MPG,0.25));
tbl_subset = tbl(idx_subset,:);
queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
    Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight
    ____________    _________    ____________    __________    __________    ______

        13.2            8            318            150            76         3940 
        14.9            8            302            130            77         4295 
          14            8            360            215            70         4615 
        13.7            8            318            145            77         4140 

Используйте LIME с древовидными простыми моделями

Объясните предсказания для точек запроса с помощью lime с деревом решений простые модели. lime генерирует синтетический набор данных и подбирает простую модель к синтетическому набору данных.

Создайте lime объект с помощью tbl_subset так, чтобы lime генерирует синтетический набор данных с помощью подмножества вместо целого набора данных. Задайте 'SimpleModelType' как 'tree' использовать дерево решений простая модель.

explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');

Значение по умолчанию 'DataLocality' для lime 'global', который подразумевает что, по умолчанию, lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдения так, чтобы значения веса более фокусировались на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.

Подбирайте простые модели для четырех точек запроса при помощи объектного функционального fit. Задайте третий вход (количество важных предикторов, чтобы использовать в простой модели) как 6. С этой установкой программное обеспечение задает максимальное количество разделений решения (или узлы ветви) как 6 так, чтобы подходящее дерево решений использовало самое большее все предикторы.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте важность предиктора при помощи объектного функционального plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют Displacement, Model_Year, и Weight как важные предикторы.

Вычислите значения Шепли

Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного ответа для точки запроса от среднего ответа, из-за предиктора. Создайте shapley объект для модели blackbox использование tbl_subset так, чтобы shapley вычисляет ожидаемый вклад на основе наблюдений в tbl_subset.

explainer_shapley = shapley(blackbox,tbl_subset);

Вычислите значения Шепли для точек запроса при помощи объектного функционального fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте значения Шепли при помощи объектного функционального plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Model_Year самый важный предиктор для первых, вторых, и четвертых точек запроса и значения Шепли Model_Year положительны для трех точек запроса. Model_Year значение равняется 76 или 77 для этих трех точек, и значение для третьей точки запроса равняется 70. Согласно значениям Шепли для четырех точек запроса, маленькому Model_Year значение приводит к уменьшению в предсказанном ответе и большому Model_Year значение приводит к увеличению предсказанного ответа по сравнению со средним значением.

Создайте Частичный график зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и предсказанным ответом в обученной модели. Создайте PDP для Model_Year, который другие interpretability инструменты идентифицируют как важный предиктор. Передайте tbl_subset к plotPartialDependence так, чтобы функция вычислила ожидание предсказанных ответов с помощью только выборки в tbl_subset.

figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)

График показывает тот же тренд, идентифицированный значениями Шепли для четырех точек запроса. Предсказанный ответ (MPG) повышения стоимости как Model_Year повышения стоимости.

Ссылки

[1] Рибейру, Марко Тулио, С. Сингх и К. Гуестрин. "'Почему я должен Доверять Вам?': Объяснение Предсказаний Любого Классификатора". В Продолжениях 22-й Международной конференции ACM SIGKDD по вопросам Открытия Знаний и Анализа данных, 1135–44. Сан-Франциско, Калифорния: ACM, 2016.

[2] Лундберг, Скотт М. и С. Ли. "Объединенный подход к интерпретации предсказаний модели". Усовершенствования в нейронных системах обработки информации 30 (2017): 4765–774.

[3] Научный работник, Керсти, Мартин. Джаллум и Андерс Лылэнд. "Объясняя Отдельные Предсказания, Когда Функции Зависят: Более точные Приближения к Значениям Шепли". arXiv:1903.10464 (2019).

[4] Фридман, Джером. H. “Жадное Приближение функций: Машина Повышения Градиента”. Летопись Статистики 29, № 5 (2001): 1189-1232.

[5] Голдстайн, Алекс, Адам Кэпелнер, Джастин Блейч и Эмиль Питкин. “Посмотрев В Черном квадрате: Визуализация Статистического Изучения с Графиками Отдельного Условного Ожидания”. Журнал Вычислительной и Графической Статистики 24, № 1 (2 января 2015): 44–65.

Смотрите также

| |

Похожие темы