Задайте пользовательские учебные циклы, функции потерь и сети

Для большинства задач глубокого обучения можно использовать предварительно обученную сеть и адаптировать ее к собственным данным. Для примера, показывающего, как использовать передачу обучения, чтобы переобучить сверточную нейронную сеть, чтобы классифицировать новый набор изображений, смотрите, Обучают Нейронную сеть для глубокого обучения Классифицировать Новые Изображения. В качестве альтернативы можно создать и обучить нейронные сети с нуля с помощью layerGraph объекты с trainNetwork и trainingOptions функции.

Если trainingOptions функция не обеспечивает опции обучения, в которых вы нуждаетесь для своей задачи, затем можно создать пользовательский учебный цикл с помощью автоматического дифференцирования. Чтобы узнать больше, смотрите, Задают Нейронную сеть для глубокого обучения для Пользовательских Учебных Циклов.

Если Deep Learning Toolbox™ не обеспечивает слои, вам нужно для вашей задачи (включая выходные слои, которые задают функции потерь), то можно создать пользовательский слой. Чтобы узнать больше, смотрите, Задают Пользовательские Слои Глубокого обучения. Для функций потерь, которые не могут быть заданы с помощью выходного слоя, можно задать потерю в пользовательском учебном цикле. Чтобы узнать больше, смотрите, Задают Функции потерь. Для сетей, которые не могут быть созданы с помощью графиков слоев, можно задать пользовательские сети как функцию. Чтобы узнать больше, смотрите Сеть Define как Функцию Модели.

Для получения дополнительной информации, о который метод обучения использовать, для который задача, смотрите, Обучают Модель Глубокого обучения в MATLAB.

Задайте нейронную сеть для глубокого обучения для пользовательских учебных циклов

Сеть Define как dlnetwork Объект

Для большинства задач можно управлять деталями алгоритма настройки с помощью trainingOptions и trainNetwork функции. Если trainingOptions функция не предоставляет возможности, в которых вы нуждаетесь для своей задачи (например, пользовательское расписание скорости обучения), затем можно задать собственный учебный цикл с помощью dlnetwork объект. dlnetwork объект позволяет вам обучать сеть, заданную как график слоев с помощью автоматического дифференцирования.

Для сетей, заданных как график слоев, можно создать dlnetwork объект от графика слоев при помощи dlnetwork функционируйте непосредственно.

dlnet = dlnetwork(lgraph);

Для списка слоев, поддержанных dlnetwork объекты, смотрите раздел Supported Layers dlnetwork страница. Для примера, показывающего, как обучить сеть с пользовательским расписанием скорости обучения, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл.

Сеть Define как функция модели

Для архитектур, которые не могут быть созданы с помощью графиков слоев (например, сиамская сеть, которая требует разделяемых весов), можно задать модель в зависимости от формы [dlY1,...,dlYM] = model(parameters,dlX1,...,dlXN), где parameters содержит сетевые параметры, dlX1,...,dlXN соответствует входным данным для N входные параметры модели и dlY1,...,dlYM соответствует M выходные параметры модели. Чтобы обучить модель глубокого обучения, заданную как функцию, используйте пользовательский учебный цикл. Для примера смотрите, Обучат сеть Используя Функцию Модели.

Когда вы задаете модель глубокого обучения как функцию, необходимо вручную инициализировать веса слоя. Для получения дополнительной информации смотрите, Инициализируют Настраиваемые параметры для Функции Модели.

Если вы задаете пользовательскую сеть как функцию, то функция модели должна поддержать автоматическое дифференцирование. Можно использовать следующие операции глубокого обучения. Функции, перечисленные здесь, являются только подмножеством. Для полного списка функций та поддержка dlarray введите, см. Список Функций с Поддержкой dlarray.

ФункцияОписание
avgpoolСредняя операция объединения выполняет субдискретизацию путем деления входа на объединение областей и вычисление среднего значения каждой области.
batchnormОперация нормализации партии. нормирует входные данные через все наблюдения для каждого канала независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшать чувствительность к сетевой инициализации, используйте нормализацию партии. между сверткой и нелинейными операциями такой как relu.
crossentropyПерекрестная энтропийная операция вычисляет потерю перекрестной энтропии между сетевыми предсказаниями и целевыми значениями для задач классификации одно меток и мультиметок.
crosschannelnormМежканальная операция нормализации использует локальные ответы в различных каналах, чтобы нормировать каждую активацию. Межканальная нормализация обычно следует за a relu операция. Межканальная нормализация также известна как локальную нормализацию ответа.
ctcОперация CTC вычисляет потерю ассоциативной временной классификации (CTC) между невыровненными последовательностями.
dlconvОперация свертки применяет скользящие фильтры к входным данным. Используйте dlconv функция для свертки глубокого обучения, сгруппированной свертки и мудрой каналом отделимой свертки.
dltranspconvТранспонированная операция свертки сверхдискретизировала карты функции.
embedВстроить операция преобразует числовые индексы в числовые векторы, где индексы соответствуют дискретным данным. Используйте вложения, чтобы сопоставить дискретные данные, такие как категориальные значения или слова к числовым векторам.
fullyconnectПолностью операция connect умножает вход на матрицу веса и затем добавляет вектор смещения.
groupnormОперация нормализации группы нормирует входные данные через сгруппированные подмножества каналов для каждого наблюдения независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшать чувствительность к сетевой инициализации, используйте нормализацию группы между сверткой и нелинейными операциями такой как relu.
gruОперация закрытого текущего модуля (GRU) позволяет сети изучать зависимости между временными шагами в данных о последовательности и временных рядах.
huberОперация Хубера вычисляет утрату Хубера между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равняется 1, это также известно как сглаженную потерю L1.
instancenormОперация нормализации экземпляра нормирует входные данные через каждый канал для каждого наблюдения независимо. Чтобы улучшить сходимость обучения сверточная нейронная сеть и уменьшать чувствительность к сетевым гиперпараметрам, используйте нормализацию экземпляра между сверткой и нелинейными операциями такой как relu.
layernormОперация нормализации слоя нормирует входные данные через все каналы для каждого наблюдения независимо. Чтобы ускорить обучение текущих и многоуровневых perceptron нейронных сетей и уменьшать чувствительность к сетевой инициализации, используйте нормализацию слоя после learnable операций, таких как LSTM и полностью соедините операции.
leakyreluТекучий исправленный линейный модуль (ReLU), операция активации выполняет нелинейную пороговую операцию, где любое входное значение меньше, чем нуль умножается на фиксированный масштабный коэффициент.
lstmОперация долгой краткосрочной памяти (LSTM) позволяет сети изучать долгосрочные зависимости между временными шагами в данных о последовательности и временных рядах.
maxpoolМаксимальная операция объединения выполняет субдискретизацию путем деления входа на объединение областей и вычисление максимального значения каждой области.
maxunpoolМаксимальная операция необъединения не объединяет выход максимальной операции объединения путем повышающей дискретизации и дополнения нулями.
mseПоловина операции среднеквадратической ошибки вычисляет половину потери среднеквадратической ошибки между сетевыми предсказаниями и целевыми значениями для задач регрессии.
onehotdecode

Одногорячая операция декодирования декодирует векторы вероятности, такие как выход сети классификации, в метки классификации.

Вход A может быть dlarray. Если A отформатирован, функция игнорирует формат данных.

reluИсправленный линейный модуль (ReLU), операция активации выполняет нелинейную пороговую операцию, где любое входное значение меньше, чем нуль обнуляется.
sigmoidСигмоидальная операция активации применяет сигмоидальную функцию к входным данным.
softmaxsoftmax операция активации применяет функцию softmax к размерности канала входных данных.

Задайте функции потерь

Когда вы используете пользовательский учебный цикл, необходимо вычислить потерю в функции градиентов модели. Используйте значение потерь когда вычислительные градиенты для обновления сетевых весов. Чтобы вычислить потерю, можно использовать следующие функции.

ФункцияОписание
softmaxsoftmax операция активации применяет функцию softmax к размерности канала входных данных.
sigmoidСигмоидальная операция активации применяет сигмоидальную функцию к входным данным.
crossentropyПерекрестная энтропийная операция вычисляет потерю перекрестной энтропии между сетевыми предсказаниями и целевыми значениями для задач классификации одно меток и мультиметок.
huberОперация Хубера вычисляет утрату Хубера между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равняется 1, это также известно как сглаженную потерю L1.
mseПоловина операции среднеквадратической ошибки вычисляет половину потери среднеквадратической ошибки между сетевыми предсказаниями и целевыми значениями для задач регрессии.
ctcОперация CTC вычисляет потерю ассоциативной временной классификации (CTC) между невыровненными последовательностями.

В качестве альтернативы можно использовать пользовательскую функцию потерь путем создания функции формы loss = myLoss(Y,T), где Y и T соответствуйте сетевым предсказаниям и целям, соответственно, и loss возвращенная потеря.

Для примера, показывающего, как обучить порождающую соперничающую сеть (GAN), которая генерирует изображения с помощью пользовательской функции потерь, смотрите, Обучают Порождающую соперничающую сеть (GAN).

Обновите настраиваемые параметры Используя автоматическое дифференцирование

Когда вы обучаете модель глубокого обучения с пользовательским учебным циклом, программное обеспечение минимизирует потерю относительно настраиваемых параметров. Чтобы минимизировать потерю, программное обеспечение использует градиенты потери относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматического дифференцирования, необходимо задать функцию градиентов модели.

Функция градиентов модели Define

Для модели, заданной как dlnetwork возразите, создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet сеть, dlX сетевой вход, T содержит цели и gradients содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).

Для определенной функцией модели создайте функцию формы gradients = modelGradients(parameters,dlX,T), где parameters содержит настраиваемые параметры, dlX вход модели, T содержит цели и gradients содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).

Чтобы узнать больше об определении функций градиентов модели для пользовательских учебных циклов, смотрите Функцию Градиентов Модели Define для Пользовательского Учебного Цикла.

Обновите настраиваемые параметры

Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval функция, которая выполняет функцию с автоматическим включенным дифференцированием. Для первого входа dlfeval, передайте определенный функцией указатель функции градиентов модели. Для следующих входных параметров передайте необходимые переменные для функции градиентов модели. Для выходных параметров dlfeval функционируйте, задайте те же выходные параметры как функция градиентов модели.

Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции.

ФункцияОписание
adamupdateОбновите параметры с помощью адаптивной оценки момента (Адам)
rmspropupdateОбновите параметры с помощью корневого среднеквадратического распространения (RMSProp)
sgdmupdateОбновите параметры с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновите параметры с помощью пользовательской функции

Смотрите также

| | |

Похожие темы