Этот пример показывает, как использовать локально интерпретируемый метод моделирования-агностических объяснений (LIME), чтобы понять предсказания глубокой нейронной сети, классифицирующей табличные данные. Можно использовать метод LIME, чтобы понять, какие предикторы наиболее важны для решения о классификации сети.
В этом примере вы интерпретируете сеть классификации данных функций с помощью LIME. Для заданного наблюдения запроса LIME генерирует синтетический набор данных, статистика которого для каждой функции соответствует реальному набору данных. Этот синтетический набор данных передается через глубокую нейронную сеть, чтобы получить классификацию, и подгоняется простая, интерпретируемая модель. Эта простая модель может использоваться, чтобы понять важность нескольких верхних функций для решения о классификации сети. При обучении этой интерпретируемой модели синтетические наблюдения взвешиваются по их расстоянию от наблюдения запроса, поэтому объяснение является «локальным» к этому наблюдению.
Этот пример использует lime
(Statistics and Machine Learning Toolbox) и fit
(Statistics and Machine Learning Toolbox), чтобы сгенерировать синтетический набор данных и соответствовать простой интерпретируемой модели синтетическому набору данных. Чтобы понять предсказания обученной нейронной сети классификации изображений, используйте imageLIME
. Для получения дополнительной информации см. Раздел «Изучение предсказаний с использованием LIME».
Загрузите набор данных радужной оболочки глаза Фишера. Эти данные содержат 150 наблюдений с четырьмя входными функциями, представляющими параметры объекта, и один категориальный ответ, представляющий вид объекта. Каждое наблюдение классифицируется как один из трёх видов: setosa, versicolor или virginica. Каждое наблюдение имеет четыре измерения: ширина чашелистика, длина чашелистика, ширина лепестка и длина лепестка.
filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat'); load(filename)
Преобразуйте числовые данные в таблицу.
features = ["Sepal length","Sepal width","Petal length","Petal width"]; predictors = array2table(meas,"VariableNames",features); trueLabels = array2table(categorical(species),"VariableNames","Response");
Составьте таблицу обучающих данных, последний столбец которой является ответом.
data = [predictors trueLabels];
Вычислим количество наблюдений, функций и классов.
numObservations = size(predictors,1); numFeatures = size(predictors,2); numClasses = length(categories(data{:,5}));
Разделите набор данных на наборы для обучения, валидации и тестирования. Выделите 15% данных для валидации и 15% для проверки.
Определите количество наблюдений для каждого раздела. Установите случайный seed, чтобы сделать разделение данных и обучение центральный процессор воспроизводимыми.
rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);
Создайте массив случайных индексов, соответствующих наблюдениям, и разбейте его с помощью размеров разбиений.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation); idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);
Разделите таблицу данных на разделы для обучения, валидации и проверки с помощью индексов.
dataTrain = data(idxTrain,:); dataVal = data(idxValidation,:); dataTest = data(idxTest,:);
Создайте простой многослойный перцептрон с одним скрытым слоем с пятью нейронами и активацией ReLU. Входной слой функции принимает данные, содержащие числовые скаляры, представляющие функциям, такие как набор данных радужной оболочки глаза Фишера.
numHiddenUnits = 5; layers = [ featureInputLayer(numFeatures) fullyConnectedLayer(numHiddenUnits) reluLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Обучите сеть с помощью стохастического градиентного спуска с импульсом (SGDM). Установите максимальное количество эпох равное 30 и используйте мини-пакет размером 15, так как обучающие данные не содержат много наблюдений.
opts = trainingOptions("sgdm", ... "MaxEpochs",30, ... "MiniBatchSize",15, ... "Shuffle","every-epoch", ... "ValidationData",dataVal, ... "ExecutionEnvironment","cpu");
Обучите сеть.
net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:00 | 40.00% | 31.82% | 1.3060 | 1.2897 | 0.0100 | | 8 | 50 | 00:00:00 | 86.67% | 90.91% | 0.4223 | 0.3656 | 0.0100 | | 15 | 100 | 00:00:00 | 93.33% | 86.36% | 0.2947 | 0.2927 | 0.0100 | | 22 | 150 | 00:00:00 | 86.67% | 81.82% | 0.2804 | 0.3707 | 0.0100 | | 29 | 200 | 00:00:01 | 86.67% | 90.91% | 0.2268 | 0.2129 | 0.0100 | | 30 | 210 | 00:00:01 | 93.33% | 95.45% | 0.2782 | 0.1666 | 0.0100 | |======================================================================================================================|
Классифицируйте наблюдения из тестового набора с помощью обученной сети.
predictedLabels = net.classify(dataTest); trueLabels = dataTest{:,end};
Визуализируйте результаты с помощью матрицы неточностей.
figure confusionchart(trueLabels,predictedLabels)
Сеть успешно использует четыре функции объектов для предсказания видов тестовых наблюдений.
Используйте LIME, чтобы понять важность каждого предиктора для классификационных решений сети.
Исследуйте два наиболее важных предиктора для каждого наблюдения.
numImportantPredictors = 2;
Использование lime
чтобы создать синтетический набор данных, статистика которого для каждой функции соответствует реальному набору данных. Создайте lime
объект с использованием модели глубокого обучения blackbox
и данные предиктора, содержащиеся в predictors
. Используйте низкую 'KernelWidth'
значение так lime
использует веса, которые фокусируются на выборках около точки запроса.
blackbox = @(x)classify(net,x); explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);
Можно использовать LIME explainer, чтобы понять самые важные функции глубокой нейронной сети. Функция оценивает важность признака при помощи простой линейной модели, которая аппроксимирует нейронную сеть в окрестностях наблюдения запроса.
Найдите индексы первых двух наблюдений в тестовых данных, соответствующих классу setosa.
trueLabelsTest = dataTest{:,end};
label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);
Используйте fit
функция для подгонки простой линейной модели к первым двум наблюдениям из заданного класса.
explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors); explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);
Постройте график результатов.
figure subplot(2,1,1) plot(explainerObs1); subplot(2,1,2) plot(explainerObs2);
Для класса setosa наиболее важными предикторами являются значение низкой длины лепестка и значение высокой ширины сепаля.
Выполните тот же анализ для класса versicolor.
label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
Для класса versicolor важно высокое значение длины лепестка.
Наконец, рассмотрим класс virginica.
label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);
explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);
figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);
Для класса virginica важно высокое значение длины лепестка и низкое значение ширины сепаля.
Графики LIME предполагают, что высокое значение длины лепестка связано с классами versicolor и virginica, а низкое значение длины лепестка связано с классом setosa. Можно продолжить исследование результатов путем изучения данных.
Постройте график длины лепестка каждого изображения в наборе данных.
setosaIdx = ismember(data{:,end},"setosa"); versicolorIdx = ismember(data{:,end},"versicolor"); virginicaIdx = ismember(data{:,end},"virginica"); figure hold on plot(data{setosaIdx,"Petal length"},'.') plot(data{versicolorIdx,"Petal length"},'.') plot(data{virginicaIdx,"Petal length"},'.') hold off xlabel("Observation number") ylabel("Petal length") legend(["setosa","versicolor","virginica"])
Класс setosa имеет гораздо более низкие значения длины лепестка, чем другие классы, совпадающие с результатами, полученными из lime
модель.
classify
| featureInputLayer
| imageLIME
| trainNetwork
| fit
(Statistics and Machine Learning Toolbox) | lime
(Statistics and Machine Learning Toolbox)