Интерпретируйте глубокие сетевые предсказания на табличных данных Используя LIME

В этом примере показано, как использовать метод локально поддающихся толкованию объяснений модели агностических (LIME), чтобы изучить предсказания глубокой нейронной сети, классифицирующей табличные данные. Можно использовать метод LIME, чтобы понять, какие предикторы являются самыми важными для решения классификации о сети.

В этом примере вы интерпретируете сеть классификации данных о функции использование LIME. Для заданного наблюдения запроса LIME генерирует синтетический набор данных, статистические данные которого для каждой функции совпадают с действительным набором данных. Этот синтетический набор данных передается через глубокую нейронную сеть, чтобы получить классификацию, и подбирается простая, поддающаяся толкованию модель. Эта простая модель может использоваться, чтобы изучить важность главных немногих функций к решению классификации о сети. В обучении эта поддающаяся толкованию модель синтетические наблюдения взвешиваются их расстоянием от наблюдения запроса, таким образом, объяснение "локально" для того наблюдения.

Этот пример использует lime (Statistics and Machine Learning Toolbox) и fit (Statistics and Machine Learning Toolbox), чтобы сгенерировать синтетический набор данных и подбирать простую поддающуюся толкованию модель к синтетическому набору данных. Чтобы изучить предсказания обученной нейронной сети классификации изображений, используйте imageLIME. Для получения дополнительной информации смотрите, Изучают Сетевые Предсказания Используя LIME.

Загрузка данных

Загрузите ирисовый набор данных Фишера. Эти данные содержат 150 наблюдений с четырьмя входными функциями, представляющими параметры объекта и одного категориального ответа, представляющего виды растений. Каждое наблюдение классифицируется как одна из трех разновидностей: setosa, versicolor, или virginica. Каждое наблюдение имеет четыре измерения: ширина чашелистика, длина чашелистика, лепестковая ширина и лепестковая длина.

filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)

Преобразуйте числовые данные в таблицу.

features = ["Sepal length","Sepal width","Petal length","Petal width"];

predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");

Составьте таблицу обучающих данных, последний столбец которых является ответом.

data = [predictors trueLabels];

Вычислите количество наблюдений, функций и классов.

numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));

Разделение данных в обучение, валидацию и наборы тестов

Разделите набор данных в обучение, валидацию и наборы тестов. Отложите 15% данных для валидации и 15% для тестирования.

Определите количество наблюдений для каждого раздела. Установите случайный seed делать разделение данных и обучение центрального процессора восстанавливаемыми.

rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);

Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с помощью размеров раздела.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);

Разделите таблицу данных в обучение, валидацию и разделы тестирования с помощью индексов.

dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);

Архитектура сети Define

Создайте простой многоуровневый perceptron с одним скрытым слоем с пятью нейронами и активациями ReLU. Входной слой функции принимает данные, содержащие числовые скаляры, представляющие функции, такие как ирисовый набор данных Фишера.

numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения и обучите сеть

Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM). Определите максимальный номер эпох к 30 и используйте мини-пакетный размер 15, когда обучающие данные не содержат много наблюдений.

opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");

Обучите сеть.

net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|

Оцените производительность сети

Классифицируйте наблюдения от набора тестов с помощью обучившего сеть.

predictedLabels = net.classify(dataTest);
trueLabels = dataTest{:,end};

Визуализируйте результаты с помощью матрицы беспорядка.

figure
confusionchart(trueLabels,predictedLabels)

Сеть успешно использует четыре функции объекта, чтобы предсказать разновидности тестовых наблюдений.

Изучите, как различные предикторы важны для различных классов

Используйте LIME, чтобы изучить важность каждого предиктора к решениям классификации о сети.

Исследуйте два самых важных предиктора для каждого наблюдения.

numImportantPredictors = 2;

Используйте lime создать синтетический набор данных, статистические данные которого для каждой функции совпадают с действительным набором данных. Создайте lime объект с помощью модели blackbox глубокого обучения и данные о предикторе содержатся в predictors. Используйте низкий 'KernelWidth' значение так lime веса использования, которые фокусируются на выборках около точки запроса.

blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);

Можно использовать человека, который объясняет LIME, чтобы изучить самые важные функции к глубокой нейронной сети. Функция оценивает важность функции при помощи простой линейной модели, которая аппроксимирует нейронную сеть около наблюдения запроса.

Найдите индексы первых двух наблюдений в тестовых данных, соответствующих setosa классу.

trueLabelsTest = dataTest{:,end};

label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);

Используйте fit функция, чтобы подбирать простую линейную модель к первым двум наблюдениям от заданного класса.

explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);

Постройте график результатов.

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для setosa класса самые важные предикторы являются низким лепестковым значением длины и высоким значением ширины чашелистика.

Выполните тот же анализ для класса versicolor.

label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для versicolor класса высокое лепестковое значение длины важно.

Наконец, рассмотрите virginica класс.

label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для virginica класса, высокого лепесткового значения длины и низкого значения ширины чашелистика важно.

Подтвердите гипотезу LIME

Графики LIME предполагают, что высокое лепестковое значение длины сопоставлено с versicolor и virginica классами, и низкое лепестковое значение длины сопоставлено с setosa классом. Можно исследовать результаты далее путем исследования данных.

Постройте лепестковую длину каждого изображения в наборе данных.

setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");

figure
hold on
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off

xlabel("Observation number")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])

setosa класс имеет намного более низкие лепестковые значения длины, чем другие классы, совпадая с результатами, приведенными от lime модель.

Смотрите также

| | | | (Statistics and Machine Learning Toolbox) | (Statistics and Machine Learning Toolbox)

Похожие темы