В этом примере показано, как использовать метод локально поддающихся толкованию объяснений модели агностических (LIME), чтобы изучить предсказания глубокой нейронной сети, классифицирующей табличные данные. Можно использовать метод LIME, чтобы понять, какие предикторы являются самыми важными для решения классификации о сети.
В этом примере вы интерпретируете сеть классификации данных о функции использование LIME. Для заданного наблюдения запроса LIME генерирует синтетический набор данных, статистические данные которого для каждой функции совпадают с действительным набором данных. Этот синтетический набор данных передается через глубокую нейронную сеть, чтобы получить классификацию, и подбирается простая, поддающаяся толкованию модель. Эта простая модель может использоваться, чтобы изучить важность главных немногих функций к решению классификации о сети. В обучении эта поддающаяся толкованию модель синтетические наблюдения взвешиваются их расстоянием от наблюдения запроса, таким образом, объяснение "локально" для того наблюдения.
Этот пример использует lime
(Statistics and Machine Learning Toolbox) и fit
(Statistics and Machine Learning Toolbox), чтобы сгенерировать синтетический набор данных и подбирать простую поддающуюся толкованию модель к синтетическому набору данных. Чтобы изучить предсказания обученной нейронной сети классификации изображений, используйте imageLIME
. Для получения дополнительной информации смотрите, Изучают Сетевые Предсказания Используя LIME.
Загрузите ирисовый набор данных Фишера. Эти данные содержат 150 наблюдений с четырьмя входными функциями, представляющими параметры объекта и одного категориального ответа, представляющего виды растений. Каждое наблюдение классифицируется как одна из трех разновидностей: setosa, versicolor, или virginica. Каждое наблюдение имеет четыре измерения: ширина чашелистика, длина чашелистика, лепестковая ширина и лепестковая длина.
filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat'); load(filename)
Преобразуйте числовые данные в таблицу.
features = ["Sepal length","Sepal width","Petal length","Petal width"]; predictors = array2table(meas,"VariableNames",features); trueLabels = array2table(categorical(species),"VariableNames","Response");
Составьте таблицу обучающих данных, последний столбец которых является ответом.
data = [predictors trueLabels];
Вычислите количество наблюдений, функций и классов.
numObservations = size(predictors,1); numFeatures = size(predictors,2); numClasses = length(categories(data{:,5}));
Разделите набор данных в обучение, валидацию и наборы тестов. Отложите 15% данных для валидации и 15% для тестирования.
Определите количество наблюдений для каждого раздела. Установите случайный seed делать разделение данных и обучение центрального процессора восстанавливаемыми.
rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);
Создайте массив случайных индексов, соответствующих наблюдениям, и разделите его с помощью размеров раздела.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation); idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);
Разделите таблицу данных в обучение, валидацию и разделы тестирования с помощью индексов.
dataTrain = data(idxTrain,:); dataVal = data(idxValidation,:); dataTest = data(idxTest,:);
Создайте простой многоуровневый perceptron с одним скрытым слоем с пятью нейронами и активациями ReLU. Входной слой функции принимает данные, содержащие числовые скаляры, представляющие функции, такие как ирисовый набор данных Фишера.
numHiddenUnits = 5; layers = [ featureInputLayer(numFeatures) fullyConnectedLayer(numHiddenUnits) reluLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM). Определите максимальный номер эпох к 30 и используйте мини-пакетный размер 15, когда обучающие данные не содержат много наблюдений.
opts = trainingOptions("sgdm", ... "MaxEpochs",30, ... "MiniBatchSize",15, ... "Shuffle","every-epoch", ... "ValidationData",dataVal, ... "ExecutionEnvironment","cpu");
Обучите сеть.
net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:00 | 40.00% | 31.82% | 1.3060 | 1.2897 | 0.0100 | | 8 | 50 | 00:00:00 | 86.67% | 90.91% | 0.4223 | 0.3656 | 0.0100 | | 15 | 100 | 00:00:00 | 93.33% | 86.36% | 0.2947 | 0.2927 | 0.0100 | | 22 | 150 | 00:00:00 | 86.67% | 81.82% | 0.2804 | 0.3707 | 0.0100 | | 29 | 200 | 00:00:01 | 86.67% | 90.91% | 0.2268 | 0.2129 | 0.0100 | | 30 | 210 | 00:00:01 | 93.33% | 95.45% | 0.2782 | 0.1666 | 0.0100 | |======================================================================================================================|
Классифицируйте наблюдения от набора тестов с помощью обучившего сеть.
predictedLabels = net.classify(dataTest); trueLabels = dataTest{:,end};
Визуализируйте результаты с помощью матрицы беспорядка.
figure confusionchart(trueLabels,predictedLabels)
Сеть успешно использует четыре функции объекта, чтобы предсказать разновидности тестовых наблюдений.
Используйте LIME, чтобы изучить важность каждого предиктора к решениям классификации о сети.
Исследуйте два самых важных предиктора для каждого наблюдения.
numImportantPredictors = 2;
Используйте lime
создать синтетический набор данных, статистические данные которого для каждой функции совпадают с действительным набором данных. Создайте lime
объект с помощью модели blackbox
глубокого обучения и данные о предикторе содержатся в
predictors
. Используйте низкий 'KernelWidth'
значение так lime
веса использования, которые фокусируются на выборках около точки запроса.
blackbox = @(x)classify(net,x); explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);
Можно использовать человека, который объясняет LIME, чтобы изучить самые важные функции к глубокой нейронной сети. Функция оценивает важность функции при помощи простой линейной модели, которая аппроксимирует нейронную сеть около наблюдения запроса.
Найдите индексы первых двух наблюдений в тестовых данных, соответствующих setosa классу.
trueLabelsTest = dataTest{:,end};
label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);
Используйте fit
функция, чтобы подбирать простую линейную модель к первым двум наблюдениям от заданного класса.
explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors); explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);
Постройте график результатов.
figure subplot(2,1,1) plot(explainerObs1); subplot(2,1,2) plot(explainerObs2);
Для setosa класса самые важные предикторы являются низким лепестковым значением длины и высоким значением ширины чашелистика.
Выполните тот же анализ для класса versicolor.
label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
Для versicolor класса высокое лепестковое значение длины важно.
Наконец, рассмотрите virginica класс.
label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
Для virginica класса, высокого лепесткового значения длины и низкого значения ширины чашелистика важно.
Графики LIME предполагают, что высокое лепестковое значение длины сопоставлено с versicolor и virginica классами, и низкое лепестковое значение длины сопоставлено с setosa классом. Можно исследовать результаты далее путем исследования данных.
Постройте лепестковую длину каждого изображения в наборе данных.
setosaIdx = ismember(data{:,end},"setosa"); versicolorIdx = ismember(data{:,end},"versicolor"); virginicaIdx = ismember(data{:,end},"virginica"); figure hold on plot(data{setosaIdx,"Petal length"},'.') plot(data{versicolorIdx,"Petal length"},'.') plot(data{virginicaIdx,"Petal length"},'.') hold off xlabel("Observation number") ylabel("Petal length") legend(["setosa","versicolor","virginica"])
setosa класс имеет намного более низкие лепестковые значения длины, чем другие классы, совпадая с результатами, приведенными от lime
модель.
fit
(Statistics and Machine Learning Toolbox) | lime
(Statistics and Machine Learning Toolbox) | trainNetwork
| classify
| featureInputLayer
| imageLIME