В этом примере показано, как классифицировать данные для обученной рекуррентной нейронной сети в Simulink ® с помощью Stateful Classify
блок. Этот пример использует предварительно обученную сеть долгой краткосрочной памяти (LSTM).
Загрузка JapaneseVowelsNet
предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с размером мини-пакета 27.
load JapaneseVowelsNet
Просмотрите сетевую архитектуру.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Загрузите тестовые данные японских гласных. XTest
- массив ячеек, содержащий 370 последовательностей размерности 12 различной длины. YTest
является категориальным вектором меток «1», «2»... «9», которые соответствуют этим девяти дикторам.
[XTest,YTest] = japaneseVowelsTestData; X = XTest{94}; numTimeSteps = size(X,2);
Модель Simulink для классификации данных содержит Stateful Classify
блок для предсказания меток и MATLAB Function
блоки для загрузки последовательности входных данных в течение временных шагов.
open_system('StatefulClassifyExample');
Установите параметры конфигурации модели для входных блоков и Stateful Classify
блок.
set_param('StatefulClassifyExample/Input','Value','X'); set_param('StatefulClassifyExample/Index','uplimit','numTimeSteps-1'); set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
Вычисление откликов для JapaneseVowelsNet
network, запустите симуляцию. Метки прогноза сохраняются в рабочей области MATLAB ®.
out = sim('StatefulClassifyExample');
Отобразите предсказанные метки на графике лестницы. График показывает, как предсказания изменяются между временными шагами.
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")
Сравните предсказания с истинной меткой. Постройте график горизонтальной линии, показывающий истинную метку наблюдения.
trueLabel = double(YTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);
[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием областей». Распознавание Букв. Том 20, № 11-13, стр. 1103-1111.
[2] UCI Machine Learning Repository: Японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Image Classifier | Predict | Stateful Classify | Stateful Predict