В этом примере показано, как классифицировать данные для обученной повторяющейся нейронной сети в Simulink ® с помощью Stateful Classify блок. В этом примере используется предварительно обученная сеть долговременной памяти (LSTM).
Груз JapaneseVowelsNet, предварительно подготовленная сеть долговременной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена последовательностям, отсортированным по длине последовательности с размером мини-партии 27.
load JapaneseVowelsNet
Просмотр сетевой архитектуры.
net.Layers
ans =
5x1 Layer array with layers:
1 'sequenceinput' Sequence Input Sequence input with 12 dimensions
2 'lstm' LSTM LSTM with 100 hidden units
3 'fc' Fully Connected 9 fully connected layer
4 'softmax' Softmax softmax
5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Загрузите данные теста японских гласных. XTest - массив ячеек, содержащий 370 последовательностей размерности 12 переменной длины. YTest - категориальный вектор меток «1», «2»,... «9», которым соответствуют девять говорящих.
[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);
Модель Simulink для классификации данных содержит Stateful Classify блок для прогнозирования меток и MATLAB Function блоки для загрузки последовательности входных данных в течение временных шагов.
open_system('StatefulClassifyExample');

Задайте параметры конфигурации модели для входных блоков и Stateful Classify блок.
set_param('StatefulClassifyExample/Input','Value','X'); set_param('StatefulClassifyExample/Index','uplimit','numTimeSteps-1'); set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
Вычисление ответов для JapaneseVowelsNet сеть, запустите моделирование. Метки прогнозирования сохраняются в рабочей области MATLAB ®.
out = sim('StatefulClassifyExample');
Постройте график прогнозируемых меток на графике лестницы. На графике показано, как изменяются прогнозы между временными шагами.
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")

Сравните прогнозы с истинной меткой. Постройте график горизонтальной линии, показывающей истинную метку наблюдения.
trueLabel = double(YTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.
[2] Хранилище машинного обучения UCI: набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Классификатор изображений | Предсказать | Классификация с учетом состояния | Прогнозирование с учетом состояния