exponenta event banner

Интерпретировать модели машинного обучения

В этом разделе описываются функции Toolbox™ статистики и машинного обучения для интерпретации моделей и показано, как интерпретировать модель машинного обучения (классификация и регрессия).

Модель машинного обучения часто называют моделью «черного ящика», поскольку может быть трудно понять, как модель делает прогнозы. Инструменты интерпретации помогают преодолеть этот аспект алгоритмов машинного обучения и раскрыть, как предикторы вносят (или не вносят) вклад в предсказания. Кроме того, можно проверить, использует ли модель правильные доказательства для своих прогнозов, и найти отклонения модели, которые не сразу очевидны.

Элементы для интерпретации модели

Использовать lime, shapley, и plotPartialDependence объяснить вклад отдельных предикторов в прогнозы обученной классификации или регрессионной модели.

  • lime - Локальные интерпретируемые модели-агностические объяснения (LIME [1]) интерпретируют прогноз для точки запроса, подгоняя простую интерпретируемую модель для точки запроса. Простая модель выступает в качестве аппроксимации для обученной модели и объясняет прогнозы модели вокруг точки запроса. Простой моделью может быть либо линейная модель, либо модель дерева решений. Можно использовать оценочные коэффициенты линейной модели или оценочную важность предиктора модели дерева решений, чтобы объяснить вклад отдельных предикторов в прогноз для точки запроса. Дополнительные сведения см. в разделе LIME.

  • shapley - [2][3] значений Шапли предсказателя для точки запроса объясняет отклонение предсказания (ответ на регрессию или оценки класса для классификации) для точки запроса от среднего предсказания, обусловленное предсказателем. Для точки запроса сумма значений Шепли для всех признаков соответствует общему отклонению прогноза от среднего значения. Дополнительные сведения см. в разделе Значения Shapley для модели машинного обучения.

  • plotPartialDependence и partialDependence - График частичной зависимости (PDP [4]) показывает взаимосвязи между предиктором (или парой предикторов) и предсказанием (отклик на регрессию или оценки классов для классификации) в обученной модели. Частичная зависимость от выбранного предиктора определяется усредненным предсказанием, полученным путем исключения влияния других переменных. Следовательно, частичная зависимость является функцией выбранного предиктора, которая показывает средний эффект выбранного предиктора по набору данных. Можно также создать набор отдельных графиков условного ожидания (ICE [5]) для каждого наблюдения, показывающих влияние выбранного предиктора на одно наблюдение. Дополнительные сведения см. в разделе Дополнительные сведения о plotPartialDependence справочная страница.

Некоторые модели машинного обучения поддерживают выбор элементов встроенного типа, где модель изучает важность предиктора как часть процесса обучения модели. Можно использовать оценочную важность предиктора для объяснения предсказаний модели. Например:

  • Тренировать ансамбль (ClassificationBaggedEnsemble или RegressionBaggedEnsemble) фасованных деревьев принятия решений (например, случайный лес) и использовать predictorImportance и oobPermutedPredictorImportance функции.

  • Тренируйте линейную модель с регуляризацией лассо, которая сокращает коэффициенты наименее важных предикторов. Затем используйте оцененные коэффициенты в качестве показателей важности предиктора. Например, использовать fitclinear или fitrlinear и укажите 'Regularization' аргумент имя-значение как 'lasso'.

Список моделей машинного обучения, поддерживающих выбор элементов встроенного типа, см. в разделе Выбор элементов встроенного типа.

Используйте функции набора инструментов для статистики и машинного обучения для трех уровней интерпретации моделей: локального, когортного и глобального.

УровеньЦельСценарий использованияНабор инструментов для статистического и машинного обучения
Локальная интерпретацияОбъясните прогноз для одной точки запроса.
  • Определите важные предикторы для индивидуального прогноза.

  • Изучите контринтуитивный прогноз.

Использовать lime и shapley для указанной точки запроса.
Когортная интерпретацияОбъясните, как обученная модель делает прогнозы для подмножества всего набора данных.Проверка прогнозов для определенной группы образцов.
  • Использовать lime и shapley для нескольких точек запроса. После создания lime или shapley объект, можно вызвать функцию объекта fit многократно интерпретировать прогнозы для других точек запроса.

  • Передача подмножества данных при вызове lime, shapley, и plotPartialDependence. Элементы интерпретируют обучаемую модель с использованием указанного подмножества вместо всего набора обучающих данных.

Глобальная интерпретацияОбъясните, как обученная модель делает прогнозы для всего набора данных.
  • Продемонстрируйте, как работает обученная модель.

  • Сравнение различных моделей.

  • Использовать plotPartialDependence для создания PDP и графиков ICE для интересующих предикторов.

  • Найдите важные предикторы из обученной модели, которая поддерживает выбор элементов встроенного типа.

Интерпретировать классификационную модель

В этом примере обучается ансамбль пакетных деревьев принятия решений с использованием алгоритма случайного леса и интерпретируется обученная модель с использованием функций интерпретируемости. Используйте функции объекта (oobPermutedPredictorImportance и predictorImportance) обученной модели, чтобы найти важные предикторы в модели. Кроме того, используйте lime и shapley для интерпретации прогнозов для указанных точек запроса. Затем использовать plotPartialDependence чтобы создать график, который показывает взаимосвязи между важным предиктором и прогнозируемыми показателями классификации.

Модель ансамбля классификации поездов

Загрузить CreditRating_Historical набор данных. Набор данных содержит идентификаторы клиентов и их финансовые коэффициенты, отраслевые наклейки и кредитные рейтинги.

tbl = readtable('CreditRating_Historical.dat');

Отображение первых трех строк таблицы.

head(tbl,3)
ans=3×8 table
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______

    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }

Создание таблицы переменных предиктора путем удаления столбцов, содержащих идентификаторы клиентов и оценки, из tbl.

tblX = removevars(tbl,["ID","Rating"]);

Обучение ансамбля фасованных деревьев принятия решений с помощью fitcensemble функцию и указание метода агрегации ансамбля как случайного леса ('Bag'). Для воспроизводимости алгоритма случайного леса укажите 'Reproducible' аргумент имя-значение как true для обучающихся на деревьях. Также укажите имена классов, чтобы задать порядок классов в обучаемой модели.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
blackbox = fitcensemble(tblX,tbl.Rating, ...
    'Method','Bag','Learners',t, ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});

blackbox является ClassificationBaggedEnsemble модель.

Использование специфичных для модели функций интерпретируемости

ClassificationBaggedEnsemble поддерживает две функции объекта, oobPermutedPredictorImportance и predictorImportance, которые находят важные предикторы в обученной модели.

Оцените важность предиктора из мешка, используя oobPermutedPredictorImportance функция. Функция случайным образом переставляет данные вне пакета по одному предиктору за раз и оценивает увеличение ошибки вне пакета из-за этой перестановки. Чем больше увеличение, тем важнее особенность.

Imp1 = oobPermutedPredictorImportance(blackbox);

Оценить важность предиктора с помощью predictorImportance функция. Функция оценивает важность предиктора путем суммирования изменений в риске узла из-за расщеплений на каждом предикторе и деления суммы на количество узлов ветви.

Imp2 = predictorImportance(blackbox);

Создайте таблицу, содержащую оценки важности предиктора, и используйте таблицу для создания горизонтальных гистограмм. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

table_Imp = table(Imp1',Imp2', ...
    'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ...
    'RowNames',blackbox.PredictorNames);
tiledlayout(1,2)
ax1 = nexttile;
table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance');
barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance'))
xlabel('Out-of-Bag Permuted Predictor Importance')
ylabel('Predictor')
ax2 = nexttile;
table_Imp2 = sortrows(table_Imp,'Predictor Importance');
barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance'))
xlabel('Predictor Importance')
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';

Обе функции объекта идентифицируют MVE_BVTD и RE_TA как два наиболее важных предиктора.

Указать точку запроса

Найдите наблюдения, чьи Rating является 'AAA' и выберите из них четыре точки запроса.

tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:);
queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
    WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry
    _____    _____    _______    ________    _____    ________

    0.331    0.531     0.077      7.116      0.522       12   
     0.26    0.515     0.065      3.394      0.515        1   
    0.121    0.413     0.057      3.647      0.466       12   
    0.617    0.766     0.126      4.442      0.483        9   

Использование LIME с линейными простыми моделями

пояснять прогнозы для точек запроса с помощью lime с линейными простыми моделями. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.

Создать lime объект с использованием tblX_AAA чтобы lime создает синтетический набор данных, используя только наблюдения, Rating является 'AAA', не весь набор данных.

explainer_lime = lime(blackbox,tblX_AAA);

Значение по умолчанию DataLocality для lime является 'global', что означает, что по умолчанию lime создает глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения весов более сфокусированы на наблюдениях вблизи точки запроса. Поэтому можно интерпретировать каждую простую модель как аппроксимацию обученной модели для определенной точки запроса.

Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Укажите третий вход (количество важных предикторов для использования в простой модели) как 6, чтобы использовать все шесть предикторов.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте график коэффициентов простых моделей с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют EBIT_TA, RE_TA, и MVE_BVTD как три наиболее важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению прогнозируемых показателей в простых моделях.

Для категориального предиктора plot функция отображает только самую важную фиктивную переменную категориального предиктора. Поэтому на каждой гистограмме отображается своя фиктивная переменная.

Вычислить значения Shapley

Значение Шапли предсказателя для точки запроса объясняет отклонение прогнозируемой оценки для точки запроса от средней оценки, обусловленной предсказателем. Создать shapley объект с использованием tblX_AAA чтобы shapley вычисляет ожидаемый вклад на основе выборок для 'AAA'.

explainer_shapley = shapley(blackbox,tblX_AAA);

Вычисление значений Shapley для точек запроса с помощью функции объекта fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте график значений Shapley с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

MVE_BVTD и RE_TA являются двумя из трех наиболее важных предикторов для всех четырех точек запроса.

Значения Shapley MVE_BVTD являются положительными для первой и четвертой точек запроса и отрицательными для второй и третьей точек запроса. MVE_BVTD значения равны примерно 7 и 4 для первой и четвертой точек запроса соответственно, а значение для второй и третьей точек запроса равно примерно 3,5. Согласно значениям Шепли для четырех точек запроса, большой MVE_BVTD значение приводит к увеличению прогнозируемого балла, и небольшой MVE_BVTD значение приводит к снижению прогнозируемых баллов по сравнению со средним значением. Результаты согласуются с результатами из lime.

Создание графика частичной зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и прогнозируемой оценкой в обученной модели. Создание PDP для RE_TA и MVE_BVTD, которые другие инструменты интерпретации идентифицируют как важные предикторы. Проход tblx_AAA кому plotPartialDependence таким образом, что функция вычисляет ожидание прогнозируемых баллов, используя только выборки для 'AAA'.

figure
plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)

plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)

Незначительные клещи в x-axis представляют уникальные значения предиктора в tbl_AAA. График для MVE_BVTD показывает, что прогнозируемый балл является большим, когда MVE_BVTD значение невелико. Значение балла уменьшается по мере MVE_BVTD значение увеличивается до тех пор, пока не достигнет приблизительно 5, а затем значение оценки остается неизменным, так как MVE_BVTD значение увеличивается. Зависимость от MVE_BVTD в подмножестве tbl_AAA идентифицировано plotPartialDependence не соответствует местным вкладам MVE_BVTD в четырех точках запроса, определенных lime и shapley.

Интерпретировать регрессионную модель

Рабочий процесс интерпретации модели для задачи регрессии аналогичен рабочему процессу для задачи классификации, как показано в примере Интерпретировать модель классификации.

В этом примере изучается модель регрессии гауссова процесса (GPR) и интерпретируется обученная модель с использованием функций интерпретируемости. Используйте параметр ядра модели GPR для оценки веса предиктора. Кроме того, используйте lime и shapley для интерпретации прогнозов для указанных точек запроса. Затем использовать plotPartialDependence для создания графика, показывающего взаимосвязи между важным предиктором и прогнозируемыми ответами.

Модель линии GPR

Загрузить carbig набор данных, содержащий замеры автомобилей, сделанные в 1970-х и начале 1980-х годов.

load carbig

Создание таблицы, содержащей переменные предиктора Acceleration, Cylindersи так далее

tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);

Обучение модели GPR переменной ответа MPG с помощью fitrgp функция. Определить 'KernelFunction' как 'ardsquaredexponential' для использования квадратного экспоненциального ядра с отдельной шкалой длины на предиктор.

blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ...
    'KernelFunction','ardsquaredexponential');

blackbox является RegressionGP модель.

Использование специфичных для модели функций интерпретируемости

Можно вычислить веса предиктора (важность предиктора) на основе усвоенных шкал длины функции ядра, используемой в модели. Шкалы длины определяют, насколько далеко может быть предиктор, чтобы значения ответа стали некоррелированными. Найдите нормализованные веса предиктора, взяв экспоненту отрицательных усвоенных шкал длины.

sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales
weights = exp(-sigmaL); % Predictor weights
weights = weights/sum(weights); % Normalized predictor weights

Создайте таблицу, содержащую нормализованные веса предиктора, и используйте таблицу для создания горизонтальных гистограмм. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ...
    'RowNames',blackbox.ExpandedPredictorNames);
tbl_weight = sortrows(tbl_weight,'Predictor Weight');
b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight'));
b.Parent.TickLabelInterpreter = 'none'; 
xlabel('Predictor Weight')
ylabel('Predictor')

Веса предикторов указывают, что несколько фиктивных переменных для категориальных предикторов Model_Year и Cylinders важны.

Указать точку запроса

Найдите наблюдения, чьи MPG значения меньше 0,25 квантиля MPG. В подмножестве выберите четыре точки запроса, не содержащие отсутствующих значений.

rng('default') % For reproducibility
idx_subset = find(MPG < quantile(MPG,0.25));
tbl_subset = tbl(idx_subset,:);
queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
    Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight
    ____________    _________    ____________    __________    __________    ______

        13.2            8            318            150            76         3940 
        14.9            8            302            130            77         4295 
          14            8            360            215            70         4615 
        13.7            8            318            145            77         4140 

Использование LIME с простыми моделями дерева

пояснять прогнозы для точек запроса с помощью lime с простыми моделями дерева решений. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.

Создать lime объект с использованием tbl_subset чтобы lime создает синтетический набор данных, используя подмножество вместо всего набора данных. Определить 'SimpleModelType' как 'tree' для использования простой модели дерева решений.

explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');

Значение по умолчанию DataLocality для lime является 'global', что означает, что по умолчанию lime создает глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения весов более сфокусированы на наблюдениях вблизи точки запроса. Поэтому можно интерпретировать каждую простую модель как аппроксимацию обученной модели для определенной точки запроса.

Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Укажите третий вход (количество важных предикторов для использования в простой модели) как 6. С помощью этой настройки программное обеспечение определяет максимальное количество разделений решений (или узлов ветвления) как 6, так что подходящее дерево решений использует максимум все предикторы.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте график важности предиктора с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют Displacement, Model_Year, и Weight как важные предикторы.

Вычислить значения Shapley

Значение Шапли предиктора для точки запроса объясняет отклонение предсказанного ответа для точки запроса от среднего ответа, обусловленного предиктором. Создать shapley объект для модели blackbox использование tbl_subset чтобы shapley вычисляет ожидаемый вклад на основе наблюдений в tbl_subset.

explainer_shapley = shapley(blackbox,tbl_subset);

Вычисление значений Shapley для точек запроса с помощью функции объекта fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте график значений Shapley с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Model_Year является наиболее важным предиктором для первой, второй и четвертой точек запроса и значений Шапли Model_Year являются положительными для трех точек запроса. Model_Year значение равно 76 или 77 для этих трех точек, а значение для третьей точки запроса равно 70. Согласно значениям Шепли для четырех точек запроса, небольшой Model_Year значение приводит к снижению прогнозируемого ответа, а большое Model_Year значение приводит к увеличению прогнозируемого ответа по сравнению со средним.

Создание графика частичной зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и прогнозируемым ответом в обученной модели. Создание PDP для Model_Year, которые другие инструменты интерпретируемости идентифицируют как важный предиктор. Проход tbl_subset кому plotPartialDependence таким образом, что функция вычисляет ожидание прогнозируемых ответов, используя только выборки в tbl_subset.

figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)

График показывает ту же тенденцию, что и значения Shapley для четырех точек запроса. Прогнозируемый отклик (MPG) значение увеличивается по мере Model_Year значение увеличивается.

Ссылки

[1] Рибейро, Марко Тулио, С. Сингх и К. Гестрин. "Почему я должен доверять вам?": Объяснение предсказаний любого классификатора. " В трудах 22-й Международной конференции ACM SIGKDD по открытию знаний и анализу данных, 1135-44. Сан-Франциско, Калифорния: ACM, 2016.

[2] Лундберг, Скотт М. и С. Ли. «Единый подход к интерпретации предсказаний модели». Достижения в системах обработки нейронной информации 30 (2017): 4765-774.

[3] Аас, Кьерсти, Мартин. Джуллум и Андерс Лёланд. «Объяснение индивидуальных предсказаний, когда особенности зависимы: более точные приближения к значениям Шепли». arXiv:1903.10464 (2019).

[4] Фридман, Джером. H. «Жадное приближение функции: градиентная повышающая машина». Анналы статистики 29, № 5 (2001): 1189-1232.

[5] Голдштейн, Алекс, Адам Капельнер, Джастин Блейх и Эмиль Питкин. «Заглядывание внутрь черного ящика: визуализация статистического обучения с помощью графиков индивидуального условного ожидания». Журнал вычислительной и графической статистики 24, № 1 (2 января 2015 г.): 44-65.

См. также

| |

Связанные темы