Предскажите и обновите сетевое состояние в Simulink

В этом примере показано, как предсказать ответы для обученной рекуррентной нейронной сети в Simulink® при помощи Stateful Predict блок. Этот пример использует предварительно обученную сеть долгой краткосрочной памяти (LSTM).

Загрузите предварительно обученную сеть

Загрузите JapaneseVowelsNet, предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на японском наборе данных Гласных как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 

  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Данные о нагрузочном тесте

Загрузите японские тестовые данные Гласных. XTest массив ячеек, содержащий 370 последовательностей размерности 12 из различной длины. YTest категориальный вектор из меток "1", "2"... "9", которые соответствуют этим девяти динамикам.

[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);

Модель Simulink для предсказания ответов

Модель Simulink для предсказания ответов содержит Stateful Predict блокируйтесь, чтобы предсказать баллы и MATLAB Function блоки, чтобы загрузить последовательность входных данных по временным шагам.

open_system('StatefulPredictExample');

Сконфигурируйте модель для симуляции

Установите параметры конфигурации модели для входных блоков и Stateful Predict блок.

set_param('StatefulPredictExample/Input','Value','X');
set_param('StatefulPredictExample/Index','uplimit','numTimeSteps-1');
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

Запустите симуляцию

Вычислить ответы для JapaneseVowelsNet сеть, запустите симуляцию. Баллы предсказания сохранены в рабочей области MATLAB®.

out = sim('StatefulPredictExample');

Постройте баллы предсказания. График показывает, как баллы предсказания изменяются между временными шагами.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

Подсветите, что баллы предсказания в зависимости от времени продвигаются для правильного класса.

trueLabel = YTest(94);
lines(trueLabel).LineWidth = 3;

Отобразите итоговое предсказание временного шага в столбчатой диаграмме.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

Ссылки

[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Смотрите также

| | |

Похожие темы