Предсказание и обновление состояния сети в Simulink

Этот пример показывает, как предсказать ответы для обученной рекуррентной нейронной сети в Simulink ® с помощью Stateful Predict блок. Этот пример использует предварительно обученную сеть долгой краткосрочной памяти (LSTM).

Загрузка предварительно обученной сети

Загрузка JapaneseVowelsNetпредварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с размером мини-пакета 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 

  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузка тестовых данных

Загрузите тестовые данные японских гласных. XTest - массив ячеек, содержащий 370 последовательностей размерности 12 различной длины. YTest является категориальным вектором меток «1», «2»... «9», которые соответствуют этим девяти дикторам.

[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);

Simulink Модели для предсказания ответов

Модель Simulink для предсказания откликов содержит Stateful Predict блок для предсказания счетов и MATLAB Function блоки для загрузки последовательности входных данных в течение временных шагов.

open_system('StatefulPredictExample');

Сконфигурируйте модель для симуляции

Установите параметры конфигурации модели для входных блоков и Stateful Predict блок.

set_param('StatefulPredictExample/Input','Value','X');
set_param('StatefulPredictExample/Index','uplimit','numTimeSteps-1');
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

Запуск симуляции

Вычисление откликов для JapaneseVowelsNet network, запустите симуляцию. Счета предсказания сохраняются в рабочей области MATLAB ®.

out = sim('StatefulPredictExample');

Постройте график счетов предсказания. График показывает, как счета предсказания изменяются между временными шагами.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

Выделите предсказание счетов с течением времени шаги для правильного класса.

trueLabel = YTest(94);
lines(trueLabel).LineWidth = 3;

Отображение последнего временного шага, предсказания в столбчатую диаграмму.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием областей». Распознавание Букв. Том 20, № 11-13, стр. 1103-1111.

[2] UCI Machine Learning Repository: Японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

| | |

Похожие темы