resetState

Сброс состояния рекуррентной нейронной сети

Описание

пример

updatedNet = resetState(recNet) устанавливает состояние рекуррентной нейронной сети (для примера - сети LSTM) в начальное состояние.

Примеры

свернуть все

Сбросьте сетевое состояние между предсказаниями последовательности.

Загрузка JapaneseVowelsNetпредварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с размером мини-пакета 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите тестовые данные.

[XTest,YTest] = japaneseVowelsTestData;

Классифицируйте последовательность и обновляйте состояние сети. Для воспроизводимости задайте rng на 'shuffle'.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

Классификация другой последовательности с помощью обновленной сети.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Сравните окончательное предсказание с истинной меткой.

trueLabel = YTest(1)
trueLabel = categorical
     1 

Обновленное состояние сети, возможно, негативно повлияло на классификацию. Сбросьте сетевое состояние и снова предсказайте последовательность.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

Входные параметры

свернуть все

Обученная рекуррентная нейронная сеть, заданная как SeriesNetwork или DAGNetwork объект. Вы можете получить обученную сеть, импортировав предварительно обученную сеть или обучив свою собственную сеть с помощью trainNetwork функция.

recNet является рекуррентной нейронной сетью. Он должен иметь по крайней мере один рекуррентный слой (для примера - сеть LSTM). Если вход сеть не является рекуррентной, то функция не имеет эффекта и возвращает вход сеть.

Выходные аргументы

свернуть все

Обновленная сеть. updatedNet - тот же тип сети, что и входная сеть.

Если вход сеть не является рекуррентной, то функция не имеет эффекта и возвращает вход сеть.

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием областей». Распознавание Букв. Том 20, № 11-13, стр. 1103-1111.

[2] UCI Machine Learning Repository: Японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Расширенные возможности

..
Введенный в R2017b