exponenta event banner

Прогнозирование и обновление состояния сети в Simulink

В этом примере показано, как предсказать ответы для обученной рецидивирующей нейронной сети в Simulink ® с использованием Stateful Predict блок. В этом примере используется предварительно обученная сеть долговременной памяти (LSTM).

Загрузить предварительно обученную сеть

Груз JapaneseVowelsNet, предварительно подготовленная сеть долговременной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена последовательностям, отсортированным по длине последовательности с размером мини-партии 27.

load JapaneseVowelsNet

Просмотр сетевой архитектуры.

net.Layers
ans = 

  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Данные нагрузочного теста

Загрузите данные теста японских гласных. XTest - массив ячеек, содержащий 370 последовательностей размерности 12 переменной длины. YTest - категориальный вектор меток «1», «2»,... «9», которым соответствуют девять говорящих.

[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);

Модель Simulink для прогнозирования ответов

Модель Simulink для прогнозирования ответов содержит Stateful Predict блок для прогнозирования оценок и MATLAB Function блоки для загрузки последовательности входных данных в течение временных шагов.

open_system('StatefulPredictExample');

Настройка модели для моделирования

Задайте параметры конфигурации модели для входных блоков и Stateful Predict блок.

set_param('StatefulPredictExample/Input','Value','X');
set_param('StatefulPredictExample/Index','uplimit','numTimeSteps-1');
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

Запуск моделирования

Вычисление ответов для JapaneseVowelsNet сеть, запустите моделирование. Оценки прогноза сохраняются в рабочей области MATLAB ®.

out = sim('StatefulPredictExample');

Постройте график оценок прогнозирования. График показывает, как изменяются оценки прогноза между временными шагами.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

Выделите оценки прогнозирования во временных шагах для правильного класса.

trueLabel = YTest(94);
lines(trueLabel).LineWidth = 3;

Отображение последнего прогноза временного шага на гистограмме.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.

[2] Хранилище машинного обучения UCI: набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

См. также

| | |

Связанные темы