resetState

Сбросьте состояние рекуррентной нейронной сети

Описание

пример

updatedNet = resetState(recNet) сбрасывает состояние рекуррентной нейронной сети (например, сеть LSTM) к начальному состоянию.

Примеры

свернуть все

Сбросьте сетевое состояние между предсказаниями последовательности.

Загрузите JapaneseVowelsNet, предварительно обученная сеть долгой краткосрочной памяти (LSTM), обученная на японском наборе данных Гласных как описано в [1] и [2]. Эта сеть была обучена на последовательностях, отсортированных по длине последовательности с мини-пакетным размером 27.

load JapaneseVowelsNet

Просмотрите сетевую архитектуру.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Загрузите тестовые данные.

[XTest,YTest] = japaneseVowelsTestData;

Классифицируйте последовательность и обновите сетевое состояние. Для воспроизводимости установите rng к 'shuffle'.

rng('shuffle')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);
label
label = categorical
     3 

Классифицируйте другую последовательность с помощью обновленной сети.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Сравните итоговое предсказание с истинной меткой.

trueLabel = YTest(1)
trueLabel = categorical
     1 

Обновленное состояние сети, возможно, негативно влияло на классификацию. Сбросьте сетевое состояние и предскажите на последовательности снова.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 

Входные параметры

свернуть все

Обученная рекуррентная нейронная сеть в виде SeriesNetwork или DAGNetwork объект. Можно получить обучивший сеть путем импорта предварительно обученной сети или по образованию собственная сеть с помощью trainNetwork функция.

recNet рекуррентная нейронная сеть. Это должно иметь по крайней мере один текущий слой (например, сеть LSTM). Если входная сеть не является текущей сетью, то функция не оказывает влияния и возвращает входную сеть.

Выходные аргументы

свернуть все

Сеть Updated. updatedNet тот же тип сети как входная сеть.

Если входная сеть не является текущей сетью, то функция не оказывает влияния и возвращает входную сеть.

Ссылки

[1] М. Кудо, J. Тояма, и М. Шимбо. "Многомерная Классификация Кривых Используя Прохождение через области". Буквы Распознавания образов. Издание 20, № 11-13, страницы 1103-1111.

[2] Репозиторий Машинного обучения UCI: японский Набор данных Гласных. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Расширенные возможности

Введенный в R2017b