Интерпретируйте глубокие сетевые Предсказания на табличных данных с помощью LIME

Этот пример показывает, как использовать локально интерпретируемый метод моделирования-агностических объяснений (LIME), чтобы понять предсказания глубокой нейронной сети, классифицирующей табличные данные. Можно использовать метод LIME, чтобы понять, какие предикторы наиболее важны для решения о классификации сети.

В этом примере вы интерпретируете сеть классификации данных функций с помощью LIME. Для заданного наблюдения запроса LIME генерирует синтетический набор данных, статистика которого для каждой функции соответствует реальному набору данных. Этот синтетический набор данных передается через глубокую нейронную сеть, чтобы получить классификацию, и подгоняется простая, интерпретируемая модель. Эта простая модель может использоваться, чтобы понять важность нескольких верхних функций для решения о классификации сети. При обучении этой интерпретируемой модели синтетические наблюдения взвешиваются по их расстоянию от наблюдения запроса, поэтому объяснение является «локальным» к этому наблюдению.

Этот пример использует lime (Statistics and Machine Learning Toolbox) и fit (Statistics and Machine Learning Toolbox), чтобы сгенерировать синтетический набор данных и соответствовать простой интерпретируемой модели синтетическому набору данных. Чтобы понять предсказания обученной нейронной сети классификации изображений, используйте imageLIME. Для получения дополнительной информации см. Раздел «Изучение предсказаний с использованием LIME».

Загрузка данных

Загрузите набор данных радужной оболочки глаза Фишера. Эти данные содержат 150 наблюдений с четырьмя входными функциями, представляющими параметры объекта, и один категориальный ответ, представляющий вид объекта. Каждое наблюдение классифицируется как один из трёх видов: setosa, versicolor или virginica. Каждое наблюдение имеет четыре измерения: ширина чашелистика, длина чашелистика, ширина лепестка и длина лепестка.

filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)

Преобразуйте числовые данные в таблицу.

features = ["Sepal length","Sepal width","Petal length","Petal width"];

predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");

Составьте таблицу обучающих данных, последний столбец которой является ответом.

data = [predictors trueLabels];

Вычислим количество наблюдений, функций и классов.

numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));

Разделение данных на наборы для обучения, валидации и тестирования

Разделите набор данных на наборы для обучения, валидации и тестирования. Выделите 15% данных для валидации и 15% для проверки.

Определите количество наблюдений для каждого раздела. Установите случайный seed, чтобы сделать разделение данных и обучение центральный процессор воспроизводимыми.

rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);

Создайте массив случайных индексов, соответствующих наблюдениям, и разбейте его с помощью размеров разбиений.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);

Разделите таблицу данных на разделы для обучения, валидации и проверки с помощью индексов.

dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);

Определение сетевой архитектуры

Создайте простой многослойный перцептрон с одним скрытым слоем с пятью нейронами и активацией ReLU. Входной слой функции принимает данные, содержащие числовые скаляры, представляющие функциям, такие как набор данных радужной оболочки глаза Фишера.

numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Определите опции обучения и обучите сеть

Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM). Установите максимальное количество эпох равное 30 и используйте мини-пакет размером 15, так как обучающие данные не содержат много наблюдений.

opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");

Обучите сеть.

net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|

Оценка эффективности сети

Классифицируйте наблюдения из тестового набора с помощью обученной сети.

predictedLabels = net.classify(dataTest);
trueLabels = dataTest{:,end};

Визуализируйте результаты с помощью матрицы неточностей.

figure
confusionchart(trueLabels,predictedLabels)

Сеть успешно использует четыре функции объектов для предсказания видов тестовых наблюдений.

Осмыслите, как различные предикторы важны для разных классов

Используйте LIME, чтобы понять важность каждого предиктора для классификационных решений сети.

Исследуйте два наиболее важных предиктора для каждого наблюдения.

numImportantPredictors = 2;

Использование lime чтобы создать синтетический набор данных, статистика которого для каждой функции соответствует реальному набору данных. Создайте lime объект с использованием модели глубокого обучения blackbox и данные предиктора, содержащиеся в predictors. Используйте низкую 'KernelWidth' значение так lime использует веса, которые фокусируются на выборках около точки запроса.

blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);

Можно использовать LIME explainer, чтобы понять самые важные функции глубокой нейронной сети. Функция оценивает важность признака при помощи простой линейной модели, которая аппроксимирует нейронную сеть в окрестностях наблюдения запроса.

Найдите индексы первых двух наблюдений в тестовых данных, соответствующих классу setosa.

trueLabelsTest = dataTest{:,end};

label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);

Используйте fit функция для подгонки простой линейной модели к первым двум наблюдениям из заданного класса.

explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);

Постройте график результатов.

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса setosa наиболее важными предикторами являются значение низкой длины лепестка и значение высокой ширины сепаля.

Выполните тот же анализ для класса versicolor.

label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса versicolor важно высокое значение длины лепестка.

Наконец, рассмотрим класс virginica.

label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

Для класса virginica важно высокое значение длины лепестка и низкое значение ширины сепаля.

Проверьте гипотезу LIME

Графики LIME предполагают, что высокое значение длины лепестка связано с классами versicolor и virginica, а низкое значение длины лепестка связано с классом setosa. Можно продолжить исследование результатов путем изучения данных.

Постройте график длины лепестка каждого изображения в наборе данных.

setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");

figure
hold on
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off

xlabel("Observation number")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])

Класс setosa имеет гораздо более низкие значения длины лепестка, чем другие классы, совпадающие с результатами, полученными из lime модель.

См. также

| | | | (Statistics and Machine Learning Toolbox) | (Statistics and Machine Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте