exponenta event banner

Интерпретация глубоких сетевых прогнозов по табличным данным с помощью LIME

Этот пример показывает, как использовать метод локально интерпретируемых моделей-агностических объяснений (LIME), чтобы понять прогнозы глубокой нейронной сети, классифицирующей табличные данные. Можно использовать метод LIME, чтобы понять, какие предикторы наиболее важны для решения о классификации сети.

В этом примере с помощью LIME интерпретируется сеть классификации данных элементов. Для указанного наблюдения запроса LIME генерирует синтетический набор данных, статистика по каждому элементу которого соответствует реальному набору данных. Этот синтетический набор данных пропускают через глубокую нейронную сеть для получения классификации и подгоняют простую, интерпретируемую модель. Эта простая модель может использоваться для понимания важности нескольких основных функций для решения о классификации сети. При обучении этой интерпретируемой модели синтетические наблюдения взвешиваются по их расстоянию от наблюдения запроса, поэтому объяснение является «локальным» к этому наблюдению.

В этом примере используется lime (Статистика и инструментарий машинного обучения) и fit (Statistics and Machine Learning Toolbox) для создания набора синтетических данных и подгонки простой интерпретируемой модели к набору синтетических данных. Чтобы понять прогнозы обученной нейронной сети классификации изображений, используйте imageLIME. Дополнительные сведения см. в разделе Понимание сетевых прогнозов с помощью LIME.

Загрузить данные

Загрузите набор данных радужки Фишера. Эти данные содержат 150 наблюдений с четырьмя входными признаками, представляющими параметры растения, и одним категориальным ответом, представляющим вид растения. Каждое наблюдение классифицируется как один из трёх видов: сетоза, версиколор или virginica. Каждое наблюдение имеет четыре измерения: ширина чашелистика, длина чашелистика, ширина лепестка и длина лепестка.

filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)

Преобразование числовых данных в таблицу.

features = ["Sepal length","Sepal width","Petal length","Petal width"];

predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");

Создайте таблицу данных обучения, последний столбец которой является ответом.

data = [predictors trueLabels];

Вычислите количество наблюдений, элементов и классов.

numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));

Разделение данных на обучающие, проверочные и тестовые наборы

Разбейте набор данных на обучающие, проверочные и тестовые наборы. Отложите 15% данных для проверки и 15% для тестирования.

Определите количество наблюдений для каждого раздела. Задайте случайное начальное значение для обеспечения воспроизводимости разделения данных и обучения ЦП.

rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);

Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с использованием размеров разделов.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);

Разбиение таблицы данных на разделы обучения, проверки и тестирования с использованием индексов.

dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);

Определение сетевой архитектуры

Создайте простой многослойный перцептрон с одним скрытым слоем с пятью нейронами и активациями ReLU. Входной уровень элемента принимает данные, содержащие числовые скаляры, представляющие элементы, такие как набор данных радужки Фишера.

numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Определение вариантов обучения и сети обучения

Обучение сети с помощью стохастического градиентного спуска с импульсом (SGDM). Установите максимальное количество эпох равным 30 и используйте размер мини-партии 15, так как обучающие данные не содержат большого количества наблюдений.

opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");

Обучение сети.

net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|

Оценка производительности сети

Классификация наблюдений из тестового набора с использованием обученной сети.

predictedLabels = net.classify(dataTest);
trueLabels = dataTest{:,end};

Визуализация результатов с помощью матрицы путаницы.

figure
confusionchart(trueLabels,predictedLabels)

Сеть успешно использует четыре функции растений для прогнозирования видов тестовых наблюдений.

Понять, как разные предикторы важны для разных классов

Используйте LIME, чтобы понять важность каждого предиктора для классификационных решений сети.

Исследуйте два наиболее важных предиктора для каждого наблюдения.

numImportantPredictors = 2;

Использовать lime для создания синтетического набора данных, статистика по каждому элементу которого соответствует реальному набору данных. Создать lime объект с использованием модели глубокого обучения blackbox и данные предиктора, содержащиеся в predictors. Использовать низкий 'KernelWidth' значение так lime использует веса, сфокусированные на выборках вблизи точки запроса.

blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);

Вы можете использовать объяснение LIME, чтобы понять наиболее важные особенности глубокой нейронной сети. Функция оценивает важность признака, используя простую линейную модель, которая аппроксимирует нейронную сеть в окрестности наблюдения запроса.

Найдите индексы первых двух наблюдений в тестовых данных, соответствующих классу setosa.

trueLabelsTest = dataTest{:,end};

label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);

Используйте fit для подгонки простой линейной модели к первым двум наблюдениям из указанного класса.

explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);

Постройте график результатов.

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса setosa наиболее важными предикторами являются низкое значение длины лепестка и высокое значение ширины чашелистика.

Выполните такой же анализ для класса versicolor.

label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса versicolor важно высокое значение длины лепестка.

Наконец, рассмотрим класс virginica.

label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса virginica важно высокое значение длины лепестка и низкое значение ширины чашелистика.

Проверка гипотезы LIME

Графики LIME предполагают, что высокое значение длины лепестка связано с классами versicolor и virginica, а низкое значение длины лепестка связано с классом setosa. Результаты можно исследовать дополнительно, изучив данные.

Постройте график длины лепестка каждого изображения в наборе данных.

setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");

figure
hold on
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off

xlabel("Observation number")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])

Класс setosa имеет гораздо более низкие значения длины лепестков, чем другие классы, соответствующие результатам, полученным из lime модель.

См. также

| | | | (инструментарий статистики и машинного обучения) | (Набор инструментов для статистики и машинного обучения)

Связанные темы