exponenta event banner

Визуализация активации сети LSTM

В этом примере показано, как исследовать и визуализировать функции, полученные сетями LSTM, извлекая активации.

Загрузить предварительно обученную сеть. JapaneseVowelsNet является предварительно обученной сетью LSTM, обученной на наборе данных Vowels на японском языке, как описано в [1] и [2]. Его обучали последовательностям, отсортированным по длине последовательности с размером мини-партии 27.

load JapaneseVowelsNet

Просмотр сетевой архитектуры.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите данные теста.

[XTest,YTest] = japaneseVowelsTestData;

Визуализация первого временного ряда на графике. Каждая строка соответствует элементу.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes. The axes with title Test Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Для каждого временного шага последовательностей получайте активизации, выводимые уровнем LSTM (уровень 2) для этого временного шага, и обновляйте состояние сети.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    features(:,i) = activations(net,X(:,i),idxLayer);
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end

Визуализируйте первые 10 скрытых единиц измерения с помощью тепловой карты.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

Тепловая карта показывает, насколько сильно активизируется каждая скрытая единица, и показывает, как активизируются с течением времени.

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.

[2] Хранилище машинного обучения UCI: набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

| | | | |

Связанные темы