resetState

Сбросьте состояние рекуррентной нейронной сети

Описание

пример

updatedNet = resetState(recNet) сбрасывает состояние рекуррентной нейронной сети (например, сеть LSTM) к начальному состоянию.

Примеры

свернуть все

Сбросьте сетевое состояние между прогнозами последовательности.

Загрузите JapaneseVowelsNet, предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на японском наборе данных Гласных как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите тестовые данные.

[XTest,YTest] = japaneseVowelsTestData;

Классифицируйте последовательность и обновите сетевое состояние. Для воспроизводимости, набор rng к 'shuffle'.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

Классифицируйте другую последовательность с помощью обновленной сети.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Сравните итоговый прогноз с истинной меткой.

trueLabel = YTest(1)
trueLabel = categorical
     1 

Обновленное состояние сети, возможно, негативно влияло на классификацию. Сбросьте сетевое состояние и предскажите на последовательности снова.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

Входные параметры

свернуть все

Обученная рекуррентная нейронная сеть, заданная как SeriesNetwork или DAGNetwork объект. Можно получить обучивший сеть путем импорта предварительно обученной сети или по образованию собственная сеть с помощью trainNetwork функция.

recNet рекуррентная нейронная сеть. Это должно иметь по крайней мере один текущий слой (например, сеть LSTM).

Выходные аргументы

свернуть все

Сеть Updated. updatedNet тот же тип сети как входная сеть.

Ссылки

[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Введенный в R2017b