resetState

Сбросьте состояние рекуррентной нейронной сети

Синтаксис

updatedNet = resetState(recNet)

Описание

пример

updatedNet = resetState(recNet) сбрасывает состояние рекуррентной нейронной сети (например, сеть LSTM) к начальному состоянию.

Примеры

свернуть все

Сбросьте сетевое состояние между прогнозами последовательности.

Чтобы воспроизвести результаты в этом примере, установите rng на 'default'.

rng('default')

Загрузите JapaneseVowelsNet, предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на японском наборе данных Гласных, как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите тестовые данные.

load JapaneseVowelsTest

Классифицируйте последовательность и обновите сетевое состояние. Для воспроизводимости, набор rng к 'shuffle'.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,XTest{94});

Классифицируйте другую последовательность с помощью обновленной сети.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Сравните итоговый прогноз с истинной меткой.

trueLabel = YTest(1)
trueLabel = categorical
     1 

Обновленное состояние сети, возможно, негативно влияло на классификацию. Сбросьте сетевое состояние и предскажите на последовательности снова.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

Входные параметры

свернуть все

Обученная рекуррентная нейронная сеть, заданная как объект SeriesNetwork. Можно получить обучивший сеть путем импорта предварительно обученной сети или по образованию собственная сеть с помощью функции trainNetwork.

recNet является рекуррентной нейронной сетью. Это должно иметь по крайней мере один текущий слой (например, сеть LSTM).

Выходные аргументы

свернуть все

Обновленная сеть, возвращенная как объект SeriesNetwork.

Ссылки

[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

[2] Репозиторий машинного обучения UCI: японский набор данных гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Введенный в R2017b

Для просмотра документации необходимо авторизоваться на сайте