В этом примере показано, как исследовать и визуализировать функции, изученные сетями LSTM путем извлечения активаций.
Предварительно обученная сеть Load. JapaneseVowelsNet
предварительно обученная сеть LSTM, обученная на японском наборе данных Vowels как описано в [1] и [2]. Это было обучено на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.
load JapaneseVowelsNet
Просмотрите сетевую архитектуру.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Загрузите тестовые данные.
[XTest,YTest] = japaneseVowelsTestData;
Визуализируйте первые временные ряды в графике. Каждая линия соответствует функции.
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location','northeastoutside')
Для каждого временного шага последовательностей выведите активации слоем LSTM (слой 2) для того временного шага и обновите сетевое состояние.
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength features(:,i) = activations(net,X(:,i),idxLayer); [net, YPred(i)] = classifyAndUpdateState(net,X(:,i)); end
Визуализируйте первые 10 скрытых модулей с помощью тепловой карты.
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
Тепловая карта показывает, как строго каждый скрытый модуль активирует и подсвечивает, как активации изменяются в зависимости от времени.
[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.
[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
trainNetwork
| trainingOptions
| lstmLayer
| bilstmLayer
| sequenceInputLayer
| activations