В этом примере показано, как классифицировать данные для обученной рекуррентной нейронной сети в Simulink® при помощи Stateful Classify
блок. Этот пример использует предварительно обученную сеть долгой краткосрочной памяти (LSTM).
Загрузите JapaneseVowelsNet
, предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на японском наборе данных Гласных как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.
load JapaneseVowelsNet
Просмотрите сетевую архитектуру.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Загрузите японские тестовые данные Гласных. XTest
массив ячеек, содержащий 370 последовательностей размерности 12 из различной длины. YTest
категориальный вектор из меток "1", "2"... "9", которые соответствуют этим девяти динамикам.
[XTest,YTest] = japaneseVowelsTestData; X = XTest{94}; numTimeSteps = size(X,2);
Модель Simulink для классификации данных содержит Stateful Classify
блокируйтесь, чтобы предсказать метки и MATLAB Function
блоки, чтобы загрузить последовательность входных данных по временным шагам.
open_system('StatefulClassifyExample');
Установите параметры конфигурации модели для входных блоков и Stateful Classify
блок.
set_param('StatefulClassifyExample/Input','Value','X'); set_param('StatefulClassifyExample/Index','uplimit','numTimeSteps-1'); set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
Вычислить ответы для JapaneseVowelsNet
сеть, запустите симуляцию. Метки предсказания сохранены в рабочей области MATLAB®.
out = sim('StatefulClassifyExample');
Постройте предсказанные метки в графике ступеньки. График показывает, как предсказания изменяются между временными шагами.
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")
Сравните предсказания с истинной меткой. Постройте горизонтальный график, показывающий истинную метку наблюдения.
trueLabel = double(YTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);
[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.
[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Image Classifier | Predict | Stateful Classify | Stateful Predict