Интерпретируйте модели машинного обучения

Эта тема представляет функции Statistics and Machine Learning Toolbox™ для интерпретации модели и показывает, как интерпретировать модель машинного обучения (классификация и регрессия).

Модель машинного обучения часто упоминается как модель «черного ящика», потому что может быть трудно понять, как модель делает предсказания. Инструменты интерпретации помогают вам преодолеть этот аспект алгоритмов машинного обучения и раскрыть, как предикторы способствуют (или не способствуют) предсказаниям. Кроме того, можно подтвердить, использует ли модель правильное доказательство для своих предсказаний, и найти смещения модели, которые не сразу очевидны.

Функции для интерпретации модели

Использование lime, shapley, и plotPartialDependence объяснить вклад отдельных предикторов в предсказания обученной классификационной или регрессионой модели.

  • lime - Локальные интерпретируемые модели-агностические объяснения (LIME [1]) интерпретируют предсказание для точки запроса, подгоняя простую интерпретируемую модель для точки запроса. Простая модель действует как приближение для обученной модели и объясняет предсказания модели вокруг точки запроса. Простая модель может быть либо линейной моделью, либо моделью дерева решений. Можно использовать оцененные коэффициенты линейной модели или предполагаемую предикторную важность модели дерева решений, чтобы объяснить вклад отдельных предикторов в предсказание для точки запроса. Для получения дополнительной информации см. LIME.

  • shapley - [2][3] значения Шепли предиктора для точки запроса объясняет отклонение предсказания (ответ для регрессии или счетов классов для классификации) для точки запроса от среднего предсказания из-за предиктора. Для точки запроса сумма значений Shapley для всех функций соответствует общему отклонению предсказания от среднего значения. Для получения дополнительной информации смотрите Значения Shapley для модели машинного обучения.

  • plotPartialDependence и partialDependence - График частичной зависимости (PDP [4]) показывает отношения между предиктором (или парой предикторов) и предсказанием (реакция на регрессию или счета классов для классификации) в обученной модели. Частичная зависимость от выбранного предиктора определяется усредненным предсказанием, полученным путем маргинализации эффекта других переменных. Поэтому частичная зависимость является функцией выбранного предиктора, которая показывает средний эффект выбранного предиктора над набором данных. Можно также создать набор индивидуальных условных графиков ожиданий (ICE [5]) для каждого наблюдения, показывающих эффект выбранного предиктора на одно наблюдение. Для получения дополнительной информации смотрите Подробнее о plotPartialDependence страница с описанием.

Некоторые модели машинного обучения поддерживают выбор признаков встроенного типа, где модель учит предикторную важность как часть процесса обучения модели. Можно использовать предполагаемую важность предиктора, чтобы объяснить предсказания модели. Для примера:

  • Обучите ансамбль (ClassificationBaggedEnsemble или RegressionBaggedEnsemble) пакетированных деревьев принятия решений (для примера, случайных лесов) и использовать predictorImportance и oobPermutedPredictorImportance функций.

  • Обучите линейную модель с регуляризацией лассо, которая сокращает коэффициенты наименее важных предикторов. Затем используйте оцененные коэффициенты в качестве мер для предикторной важности. Для примера используйте fitclinear или fitrlinear и задайте 'Regularization' аргумент имя-значение как 'lasso'.

Список моделей машинного обучения, поддерживающих выбор признаков встроенного типа, см. в разделе «Выбор признаков встроенного типа».

Используйте функции Statistics and Machine Learning Toolbox для трех уровней интерпретации модели: локальной, когорты и глобальной.

УровеньЦельПример использованияФункция набора инструментов Statistics and Machine Learning Toolbox
Интерпретация на местахОбъясните предсказание для одной точки запроса.
  • Идентифицируйте важные предикторы для индивидуального предсказания.

  • Рассмотрим контринтуитивное предсказание.

Использование lime и shapley для заданной точки запроса.
Когортная интерпретацияОбъясните, как обученная модель делает предсказания для подмножества всего набора данных.Проверьте предсказания для определенной группы выборок.
  • Использование lime и shapley для нескольких точек запроса. После создания lime или shapley объект, можно вызвать функцию объекта fit несколько раз, чтобы интерпретировать предсказания для других точек запроса.

  • Передайте подмножество данных при вызове lime, shapley, и plotPartialDependence. Функции интерпретируют обученную модель с помощью заданного подмножества вместо всего набора обучающих данных.

Глобальная интерпретацияОбъясните, как обученная модель делает предсказания для всего набора данных.
  • Продемонстрировать, как работает обученная модель.

  • Сравните различные модели.

  • Использовать plotPartialDependence создать PDP и графики ICE для предикторов интереса.

  • Найдите важные предикторы из обученной модели, которая поддерживает выбор признаков встроенного типа.

Интерпретируйте классификационную модель

Этот пример обучает ансамбль мешанных деревьев принятия решений с помощью случайного алгоритма леса и интерпретирует обученную модель с помощью функций интерпретации. Используйте функции объекта (oobPermutedPredictorImportance и predictorImportance) обученной модели для поиска важных предикторов в модели. Кроме того, используйте lime и shapley для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными классификационными оценками.

Train модели ансамбля классификации

Загрузите CreditRating_Historical набор данных. Набор данных содержит идентификаторы клиентов и их финансовые коэффициенты, отраслевые метки и кредитные рейтинги.

tbl = readtable('CreditRating_Historical.dat');

Отобразите первые три строки таблицы.

head(tbl,3)
ans=3×8 table
     ID      WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry    Rating
    _____    _____    _____    _______    ________    _____    ________    ______

    62394    0.013    0.104     0.036      0.447      0.142       3        {'BB'}
    48608    0.232    0.335     0.062      1.969      0.281       8        {'A' }
    42444    0.311    0.367     0.074      1.935      0.366       1        {'A' }

Составьте таблицу переменных предиктора путем удаления столбцов, содержащих идентификаторы клиентов и оценки из tbl.

tblX = removevars(tbl,["ID","Rating"]);

Обучите ансамбль мешанных деревьев принятия решений при помощи fitcensemble функция и определение метода агрегации ансамбля как случайного леса ('Bag'). Для воспроизводимости алгоритма случайного леса задайте 'Reproducible' аргумент имя-значение как true для учащихся дерева. Кроме того, укажите имена классов, чтобы задать порядок классов в обученной модели.

rng('default') % For reproducibility
t = templateTree('Reproducible',true);
blackbox = fitcensemble(tblX,tbl.Rating, ...
    'Method','Bag','Learners',t, ...
    'CategoricalPredictors','Industry', ...
    'ClassNames',{'AAA' 'AA' 'A' 'BBB' 'BB' 'B' 'CCC'});

blackbox является ClassificationBaggedEnsemble модель.

Используйте специфичные для модели функции интерпретации

ClassificationBaggedEnsemble поддерживает две функции объекта, oobPermutedPredictorImportance и predictorImportance, которые находят важные предикторы в обученной модели.

Оцените значение предиктора вне мешка при помощи oobPermutedPredictorImportance функция. Функция случайным образом перестраивает данные из мешка через один предиктор за раз и оценивает увеличение ошибки вне мешка из-за этого сочетания. Чем больше увеличение, тем важнее функция.

Imp1 = oobPermutedPredictorImportance(blackbox);

Оцените важность предиктора при помощи predictorImportance функция. Функция оценивает важность предиктора путем суммирования изменений в риске узла из-за расщеплений на каждом предикторе и деления суммы на количество узлов ветви.

Imp2 = predictorImportance(blackbox);

Составьте таблицу, содержащую предикторные оценки важности, и используйте таблицу для создания горизонтальных столбчатых графиков. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

table_Imp = table(Imp1',Imp2', ...
    'VariableNames',{'Out-of-Bag Permuted Predictor Importance','Predictor Importance'}, ...
    'RowNames',blackbox.PredictorNames);
tiledlayout(1,2)
ax1 = nexttile;
table_Imp1 = sortrows(table_Imp,'Out-of-Bag Permuted Predictor Importance');
barh(categorical(table_Imp1.Row,table_Imp1.Row),table_Imp1.('Out-of-Bag Permuted Predictor Importance'))
xlabel('Out-of-Bag Permuted Predictor Importance')
ylabel('Predictor')
ax2 = nexttile;
table_Imp2 = sortrows(table_Imp,'Predictor Importance');
barh(categorical(table_Imp2.Row,table_Imp2.Row),table_Imp2.('Predictor Importance'))
xlabel('Predictor Importance')
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';

Обе функции объекта идентифицируют MVE_BVTD и RE_TA как два наиболее важных предиктора.

Задайте точку запроса

Найдите наблюдения, чьи Rating является 'AAA' и выберите среди них четыре точки запроса.

tblX_AAA = tblX(strcmp(tbl.Rating,'AAA'),:);
queryPoint = datasample(tblX_AAA,4,'Replace',false)
queryPoint=4×6 table
    WC_TA    RE_TA    EBIT_TA    MVE_BVTD    S_TA     Industry
    _____    _____    _______    ________    _____    ________

    0.331    0.531     0.077      7.116      0.522       12   
     0.26    0.515     0.065      3.394      0.515        1   
    0.121    0.413     0.057      3.647      0.466       12   
    0.617    0.766     0.126      4.442      0.483        9   

Использование LIME с линейными простыми моделями

Объясните предсказания для точек запроса используя lime с линейными простыми моделями. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.

Создайте lime объект, использующий tblX_AAA так что lime генерирует синтетический набор данных, используя только наблюдения, чьи Rating является 'AAA', а не весь набор данных.

explainer_lime = lime(blackbox,tblX_AAA);

Значение по умолчанию 'DataLocality' для lime является 'global', что подразумевает, что по умолчанию lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.

Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Укажите третий вход (количество важных предикторов для использования в простой модели) как 6, чтобы использовать все шесть предикторов.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте график коэффициентов простых моделей с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют EBIT_TA, RE_TA, и MVE_BVTD как три наиболее важных предиктора. Положительные коэффициенты для предикторов предполагают, что увеличение значений предиктора приводит к увеличению предсказанных счетов в простых моделях.

Для категориального предиктора, plot функция отображает только самую важную фиктивную переменную категориального предиктора. Поэтому на каждой гистограмме отображается другая переменная манекена.

Вычисление значений Shapley

Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного счета для точки запроса от среднего счета из-за предиктора. Создайте shapley объект, использующий tblX_AAA так что shapley вычисляет ожидаемый вклад на основе выборок для 'AAA'.

explainer_shapley = shapley(blackbox,tblX_AAA);

Вычислите значения Shapley для точек запроса с помощью функции объекта fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте график значений Shapley с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

MVE_BVTD и RE_TA являются двумя из трех наиболее важных предикторов для всех четырех точек запроса.

Значения Шепли MVE_BVTD положительны для первой и четвертой точек запроса и отрицательны для второй и третьей точек запроса. The MVE_BVTD значения составляют около 7 и 4 для первой и четвертой точек запроса, соответственно, и значение для второй и третьей точек запроса составляет около 3,5. Согласно значениям Шепли для четырех точек запроса, большое MVE_BVTD значение приводит к увеличению предсказанного счета, и небольшому MVE_BVTD значение приводит к снижению прогнозируемых счетов по сравнению со средним значением. Результаты соответствуют результатам lime.

Создайте график частичной зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и предсказанным счетом в обученной модели. Создайте PDP для RE_TA и MVE_BVTD, которые другие инструменты интерпретации идентифицируют как важные предикторы. Передайте tblx_AAA на plotPartialDependence так, что функция вычисляет ожидание предсказанных счетов, используя только выборки для 'AAA'.

figure
plotPartialDependence(blackbox,'RE_TA','AAA',tblX_AAA)

plotPartialDependence(blackbox,'MVE_BVTD','AAA',tblX_AAA)

Минорные такты в x-ось представляет уникальные значения предиктора в tbl_AAA. График для MVE_BVTD показывает, что предсказанный счет велик, когда MVE_BVTD значение невелико. Значение баллов уменьшается как MVE_BVTD значение увеличений пока оно не достигает около 5, и затем значение баллов остается неизменным как MVE_BVTD значение увеличивается. Зависимость от MVE_BVTD в подмножестве tbl_AAA идентифицируется по plotPartialDependence не соответствует местным вкладам MVE_BVTD в четырех точках запроса, обозначенных lime и shapley.

Интерпретируйте регрессионую модель

Рабочий процесс интерпретации модели для регрессионной задачи аналогичен рабочему процессу для классификационной задачи, как показано в примере Интерпретационная классификационная модель.

Этот пример обучает модель регрессии Гауссова процесса (GPR) и интерпретирует обученную модель, используя функции интерпретации. Используйте параметр ядра модели GPR, чтобы оценить веса предикторов. Кроме того, используйте lime и shapley для интерпретации предсказаний для заданных точек запроса. Затем используйте plotPartialDependence создать график, который показывает отношения между важным предиктором и предсказанными откликами.

Обучите модель GPR

Загрузите carbig набор данных, содержащий измерения автомобилей 1970-х и начала 1980-х годов.

load carbig

Создайте таблицу, содержащую переменные предиктора Acceleration, Cylinders, и так далее

tbl = table(Acceleration,Cylinders,Displacement,Horsepower,Model_Year,Weight);

Обучите модель GPR переменной отклика MPG при помощи fitrgp функция. Задайте 'KernelFunction' как 'ardsquaredexponential' использовать квадратное экспоненциальное ядро с отдельной шкалой длины на предиктор.

blackbox = fitrgp(tbl,MPG,'ResponseName','MPG','CategoricalPredictors',[2 5], ...
    'KernelFunction','ardsquaredexponential');

blackbox является RegressionGP модель.

Используйте специфичные для модели функции интерпретации

Можно вычислить веса предикторов (важность предиктора) из выученных шкал длины функции ядра, используемой в модели. Шкалы длины определяют, насколько далеко может быть предиктор, чтобы значения отклика стали некоррелированными. Найдите нормированные веса предиктора путем взятия экспоненциала отрицательной выученной длины.

sigmaL = blackbox.KernelInformation.KernelParameters(1:end-1); % Learned length scales
weights = exp(-sigmaL); % Predictor weights
weights = weights/sum(weights); % Normalized predictor weights

Создайте таблицу, содержащую нормированные предикторные веса, и используйте таблицу, чтобы создать горизонтальные столбчатые графики. Чтобы отобразить существующее подчеркивание в любом имени предиктора, измените TickLabelInterpreter значение осей для 'none'.

tbl_weight = table(weights,'VariableNames',{'Predictor Weight'}, ...
    'RowNames',blackbox.ExpandedPredictorNames);
tbl_weight = sortrows(tbl_weight,'Predictor Weight');
b = barh(categorical(tbl_weight.Row,tbl_weight.Row),tbl_weight.('Predictor Weight'));
b.Parent.TickLabelInterpreter = 'none'; 
xlabel('Predictor Weight')
ylabel('Predictor')

Веса предиктора указывают, что несколько фиктивных переменных для категориальных предикторов Model_Year и Cylinders важны.

Задайте точку запроса

Найдите наблюдения, чьи MPG значения меньше, чем 0,25 квантиля MPG. Из подмножества выберите четыре точки запроса, которые не содержат отсутствующих значений.

rng('default') % For reproducibility
idx_subset = find(MPG < quantile(MPG,0.25));
tbl_subset = tbl(idx_subset,:);
queryPoint = datasample(rmmissing(tbl_subset),4,'Replace',false)
queryPoint=4×6 table
    Acceleration    Cylinders    Displacement    Horsepower    Model_Year    Weight
    ____________    _________    ____________    __________    __________    ______

        13.2            8            318            150            76         3940 
        14.9            8            302            130            77         4295 
          14            8            360            215            70         4615 
        13.7            8            318            145            77         4140 

Использование LIME с древовидными простыми моделями

Объясните предсказания для точек запроса используя lime с простыми моделями дерева решений. lime генерирует синтетический набор данных и подгоняет простую модель к синтетическому набору данных.

Создайте lime объект, использующий tbl_subset так что lime генерирует синтетический набор данных, используя подмножество вместо всего набора данных. Задайте 'SimpleModelType' как 'tree' использовать простую модель дерева решений.

explainer_lime = lime(blackbox,tbl_subset,'SimpleModelType','tree');

Значение по умолчанию 'DataLocality' для lime является 'global', что подразумевает, что по умолчанию lime генерирует глобальный синтетический набор данных и использует его для любых точек запроса. lime использует различные веса наблюдений, так что значения веса больше фокусируются на наблюдениях около точки запроса. Поэтому можно интерпретировать каждую простую модель как приближение обученной модели для определенной точки запроса.

Подгонка простых моделей для четырех точек запроса с помощью функции объекта fit. Задайте третий вход (количество важных предикторов для использования в простой модели) как 6. С помощью этой настройки программное обеспечение задает максимальное количество разделений решений (или узлов ветви) как 6, так что подобранное дерево решений использует самое большее все предикторы.

explainer_lime1 = fit(explainer_lime,queryPoint(1,:),6);
explainer_lime2 = fit(explainer_lime,queryPoint(2,:),6);
explainer_lime3 = fit(explainer_lime,queryPoint(3,:),6);
explainer_lime4 = fit(explainer_lime,queryPoint(4,:),6);

Постройте график важности предиктора с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_lime1);
ax2 = nexttile; plot(explainer_lime2);
ax3 = nexttile; plot(explainer_lime3);
ax4 = nexttile; plot(explainer_lime4);
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Все простые модели идентифицируют Displacement, Model_Year, и Weight как важные предикторы.

Вычисление значений Shapley

Значение Шепли предиктора для точки запроса объясняет отклонение предсказанного отклика для точки запроса от среднего отклика из-за предиктора. Создайте shapley объект для модели blackbox использование tbl_subset так что shapley вычисляет ожидаемый вклад на основе наблюдений в tbl_subset.

explainer_shapley = shapley(blackbox,tbl_subset);

Вычислите значения Shapley для точек запроса с помощью функции объекта fit.

explainer_shapley1 = fit(explainer_shapley,queryPoint(1,:));
explainer_shapley2 = fit(explainer_shapley,queryPoint(2,:));
explainer_shapley3 = fit(explainer_shapley,queryPoint(3,:));
explainer_shapley4 = fit(explainer_shapley,queryPoint(4,:));

Постройте график значений Shapley с помощью функции объекта plot.

tiledlayout(2,2)
ax1 = nexttile; plot(explainer_shapley1)
ax2 = nexttile; plot(explainer_shapley2)
ax3 = nexttile; plot(explainer_shapley3)
ax4 = nexttile; plot(explainer_shapley4)
ax1.TickLabelInterpreter = 'none';
ax2.TickLabelInterpreter = 'none';
ax3.TickLabelInterpreter = 'none';
ax4.TickLabelInterpreter = 'none';

Model_Year является наиболее важным предиктором для первой, второй и четвертой точек запроса и значений Шепли Model_Year положительны для трех точек запроса. The Model_Year значение является 76 или 77 для этих трех точек, и значение для третьей точки запроса равняется 70. Согласно значениям Шепли для четырех точек запроса, небольшая Model_Year значение приводит к уменьшению предсказанного отклика и большому Model_Year значение приводит к увеличению предсказанного отклика по сравнению со средним значением.

Создайте график частичной зависимости (PDP)

График PDP показывает усредненные отношения между предиктором и предсказанной реакцией в обученной модели. Создайте PDP для Model_Year, который другие инструменты интерпретации идентифицируют как важный предиктор. Передайте tbl_subset на plotPartialDependence так, что функция вычисляет ожидание предсказанных ответов, используя только выборки в tbl_subset.

figure
plotPartialDependence(blackbox,'Model_Year',tbl_subset)

График показывает тот же тренд, идентифицированную значениями Шепли для четырех точек запроса. Предсказанный ответ (MPG) значение увеличивается как Model_Year значение увеличивается.

Ссылки

[1] Рибейро, Марко Тулио, С. Сингх и К. Гестрин. "Почему я должен доверять вам?": Объяснение предсказаний любого классификатора ". В Трудах 22-й Международной конференции ACM SIGKDD по открытию знаний и майнингу данных, 1135-44. Сан-Франциско, Калифорния: ACM, 2016.

[2] Лундберг, Скотт М. и С. Ли. «Единый подход к интерпретации модельных Предсказаний». Усовершенствования в нейронных системах обработки информации 30 (2017): 4765-774.

[3] Аас, Кьерсти, Мартин. Джуллум и Андерс Лёланд. «Объяснение индивидуальных предсказаний, когда функции зависимы: более точные приближения к значениям Шепли». arXiv:1903.10464 (2019).

[4] Фридман, Джером. H. «Жадное приближение функций: градиентная бустинговая машина». Анналы статистики 29, № 5 (2001): 1189-1232.

[5] Голдштейн, Алекс, Адам Капельнер, Джастин Блейх и Эмиль Питкин. Peeking Inside the Black Box: Visualizing Statistical Learning with Plots of Individual Conditional Development (неопр.) (недоступные графики). Журнал вычислительно-графической статистики 24, № 1 (2 января 2015): 44-65.

См. также

| |

Похожие темы