Этот пример показов, как исследовать и визуализировать функции, выученные сетями LSTM, путем извлечения активаций.
Загрузка предварительно обученной сети. JapaneseVowelsNet
- предварительно обученная сеть LSTM, обученная на наборе данных японских гласных, как описано в [1] и [2]. Он был обучен на последовательностях, отсортированных по длине последовательности с размером мини-партии 27.
load JapaneseVowelsNet
Просмотрите сетевую архитектуру.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Загрузите тестовые данные.
[XTest,YTest] = japaneseVowelsTestData;
Визуализируйте первые временные ряды на графике. Каждая линия соответствует функции.
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location','northeastoutside')
Для каждого временного шага последовательностей получайте активации, выводимые слоем LSTM (слой 2) для этого временного шага и обновляйте состояние сети.
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength features(:,i) = activations(net,X(:,i),idxLayer); [net, YPred(i)] = classifyAndUpdateState(net,X(:,i)); end
Визуализируйте первые 10 скрытые модули с помощью тепловой карты.
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
Тепловая карта показывает, как сильно активируется каждый скрытый модуль, и подсвечивает, как изменяются активации с течением времени.
[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием областей». Распознавание Букв. Том 20, № 11-13, стр. 1103-1111.
[2] UCI Machine Learning Repository: Японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
activations
| bilstmLayer
| lstmLayer
| sequenceInputLayer
| trainingOptions
| trainNetwork