exponenta event banner

resetState

Сброс состояния рекуррентной нейронной сети

Описание

пример

updatedNet = resetState(recNet) сбрасывает состояние рекуррентной нейронной сети (например, сети LSTM) в исходное состояние.

Примеры

свернуть все

Сброс состояния сети между прогнозами последовательности.

Груз JapaneseVowelsNet, предварительно подготовленная сеть долговременной памяти (LSTM), обученная на наборе данных японских гласных, как описано в [1] и [2]. Эта сеть была обучена последовательностям, отсортированным по длине последовательности с размером мини-партии 27.

load JapaneseVowelsNet

Просмотр сетевой архитектуры.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите данные теста.

[XTest,YTest] = japaneseVowelsTestData;

Классифицируйте последовательность и обновите состояние сети. Для воспроизводимости, установка rng кому 'shuffle'.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

Классифицируйте другую последовательность с помощью обновленной сети.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Сравните окончательный прогноз с истинной меткой.

trueLabel = YTest(1)
trueLabel = categorical
     1 

Обновленное состояние сети, возможно, негативно повлияло на классификацию. Сбросьте состояние сети и повторите прогнозирование последовательности.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

Входные аргументы

свернуть все

Обученная рецидивирующая нейронная сеть, указанная как SeriesNetwork или DAGNetwork объект. Обученную сеть можно получить, импортируя предварительно обученную сеть или обучая собственную сеть с помощью trainNetwork функция.

recNet является рецидивирующей нейронной сетью. Он должен иметь по крайней мере один повторяющийся уровень (например, сеть LSTM). Если входная сеть не является повторяющейся сетью, то функция не оказывает никакого влияния и возвращает входную сеть.

Выходные аргументы

свернуть все

Обновленная сеть. updatedNet является сетью того же типа, что и входная сеть.

Если входная сеть не является повторяющейся сетью, то функция не оказывает никакого влияния и возвращает входную сеть.

Ссылки

[1] М. Кудо, Дж. Тояма и М. Симбо. «Многомерная классификация кривых с использованием сквозных областей». Буквы распознавания образов. т. 20, № 11-13, стр. 1103-1111.

[2] Хранилище машинного обучения UCI: набор данных гласных на японском языке. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Расширенные возможности

..
Представлен в R2017b