Визуализация активации сети LSTM

Этот пример показов, как исследовать и визуализировать функции, выученные сетями LSTM, путем извлечения активаций.

Загрузка предварительно обученной сети. JapaneseVowelsNet - предварительно обученная сеть LSTM, обученная на наборе данных японских гласных, как описано в [1] и [2]. Он был обучен на последовательностях, отсортированных по длине последовательности с размером мини-партии 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите тестовые данные.

[XTest,YTest] = japaneseVowelsTestData;

Визуализируйте первые временные ряды на графике. Каждая линия соответствует функции.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes. The axes with title Test Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Для каждого временного шага последовательностей получайте активации, выводимые слоем LSTM (слой 2) для этого временного шага и обновляйте состояние сети.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    features(:,i) = activations(net,X(:,i),idxLayer);
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end

Визуализируйте первые 10 скрытые модули с помощью тепловой карты.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

Тепловая карта показывает, как сильно активируется каждый скрытый модуль, и подсвечивает, как изменяются активации с течением времени.

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием областей». Распознавание Букв. Том 20, № 11-13, стр. 1103-1111.

[2] UCI Machine Learning Repository: Японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

| | | | |

Похожие темы