Для большинства задач глубокого обучения можно использовать предварительно обученную сеть и адаптировать ее к собственным данным. Для примера, показывающего, как использовать передачу обучения, чтобы переобучить сверточную нейронную сеть, чтобы классифицировать новый набор изображений, смотрите, Обучают Нейронную сеть для глубокого обучения Классифицировать Новые Изображения. В качестве альтернативы можно создать и обучить нейронные сети с нуля с помощью layerGraph
объекты с trainNetwork
и trainingOptions
функции.
Если trainingOptions
функция не обеспечивает опции обучения, в которых вы нуждаетесь для своей задачи, затем можно создать пользовательский учебный цикл с помощью автоматического дифференцирования. Чтобы узнать больше, смотрите, Задают Нейронную сеть для глубокого обучения для Пользовательских Учебных Циклов.
Если Deep Learning Toolbox™ не обеспечивает слои, вам нужно для вашей задачи (включая выходные слои, которые задают функции потерь), то можно создать пользовательский слой. Чтобы узнать больше, смотрите, Задают Пользовательские Слои Глубокого обучения. Для функций потерь, которые не могут быть заданы с помощью выходного слоя, можно задать потерю в пользовательском учебном цикле. Чтобы узнать больше, смотрите, Задают Функции потерь. Для сетей, которые не могут быть созданы с помощью графиков слоев, можно задать пользовательские сети как функцию. Чтобы узнать больше, смотрите Сеть Define как Функцию Модели.
Для получения дополнительной информации, о который метод обучения использовать, для который задача, смотрите, Обучают Модель Глубокого обучения в MATLAB.
dlnetwork
ОбъектДля большинства задач можно управлять деталями алгоритма настройки с помощью trainingOptions
и trainNetwork
функции. Если trainingOptions
функция не предоставляет возможности, в которых вы нуждаетесь для своей задачи (например, пользовательское расписание скорости обучения), затем можно задать собственный учебный цикл с помощью dlnetwork
объект. dlnetwork
объект позволяет вам обучать сеть, заданную как график слоев с помощью автоматического дифференцирования.
Для сетей, заданных как график слоев, можно создать dlnetwork
объект от графика слоев при помощи dlnetwork
функционируйте непосредственно.
dlnet = dlnetwork(lgraph);
Для списка слоев, поддержанных dlnetwork
объекты, смотрите раздел Supported Layers dlnetwork
страница. Для примера, показывающего, как обучить сеть с пользовательским расписанием скорости обучения, смотрите, Обучат сеть Используя Пользовательский Учебный Цикл.
Для архитектур, которые не могут быть созданы с помощью графиков слоев (например, сиамская сеть, которая требует разделяемых весов), можно задать модель в зависимости от формы [dlY1,...,dlYM] = model(parameters,dlX1,...,dlXN)
, где parameters
содержит сетевые параметры, dlX1,...,dlXN
соответствует входным данным для N
входные параметры модели и dlY1,...,dlYM
соответствует M
выходные параметры модели. Чтобы обучить модель глубокого обучения, заданную как функцию, используйте пользовательский учебный цикл. Для примера смотрите, Обучат сеть Используя Функцию Модели.
Когда вы задаете модель глубокого обучения как функцию, необходимо вручную инициализировать веса слоя. Для получения дополнительной информации смотрите, Инициализируют Настраиваемые параметры для Функции Модели.
Если вы задаете пользовательскую сеть как функцию, то функция модели должна поддержать автоматическое дифференцирование. Можно использовать следующие операции глубокого обучения. Функции, перечисленные здесь, являются только подмножеством. Для полного списка функций та поддержка dlarray
введите, см. Список Функций с Поддержкой dlarray.
Функция | Описание |
---|---|
avgpool | Средняя операция объединения выполняет субдискретизацию путем деления входа на объединение областей и вычисление среднего значения каждой области. |
batchnorm | Операция нормализации партии. нормирует входные данные через все наблюдения для каждого канала независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшать чувствительность к сетевой инициализации, используйте нормализацию партии. между сверткой и нелинейными операциями такой как relu . |
crossentropy | Перекрестная энтропийная операция вычисляет потерю перекрестной энтропии между сетевыми предсказаниями и целевыми значениями для задач классификации одно меток и мультиметок. |
crosschannelnorm | Межканальная операция нормализации использует локальные ответы в различных каналах, чтобы нормировать каждую активацию. Межканальная нормализация обычно следует за a relu операция. Межканальная нормализация также известна как локальную нормализацию ответа. |
ctc | Операция CTC вычисляет потерю ассоциативной временной классификации (CTC) между невыровненными последовательностями. |
dlconv | Операция свертки применяет скользящие фильтры к входным данным. Используйте dlconv функция для свертки глубокого обучения, сгруппированной свертки и мудрой каналом отделимой свертки. |
dltranspconv | Транспонированная операция свертки сверхдискретизировала карты функции. |
embed | Встроить операция преобразует числовые индексы в числовые векторы, где индексы соответствуют дискретным данным. Используйте вложения, чтобы сопоставить дискретные данные, такие как категориальные значения или слова к числовым векторам. |
fullyconnect | Полностью операция connect умножает вход на матрицу веса и затем добавляет вектор смещения. |
groupnorm | Операция нормализации группы нормирует входные данные через сгруппированные подмножества каналов для каждого наблюдения независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшать чувствительность к сетевой инициализации, используйте нормализацию группы между сверткой и нелинейными операциями такой как relu . |
gru | Операция закрытого текущего модуля (GRU) позволяет сети изучать зависимости между временными шагами в данных о последовательности и временных рядах. |
huber | Операция Хубера вычисляет утрату Хубера между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равняется 1, это также известно как сглаженную потерю L1. |
instancenorm | Операция нормализации экземпляра нормирует входные данные через каждый канал для каждого наблюдения независимо. Чтобы улучшить сходимость обучения сверточная нейронная сеть и уменьшать чувствительность к сетевым гиперпараметрам, используйте нормализацию экземпляра между сверткой и нелинейными операциями такой как relu . |
layernorm | Операция нормализации слоя нормирует входные данные через все каналы для каждого наблюдения независимо. Чтобы ускорить обучение текущих и многоуровневых perceptron нейронных сетей и уменьшать чувствительность к сетевой инициализации, используйте нормализацию слоя после learnable операций, таких как LSTM и полностью соедините операции. |
leakyrelu | Текучий исправленный линейный модуль (ReLU), операция активации выполняет нелинейную пороговую операцию, где любое входное значение меньше, чем нуль умножается на фиксированный масштабный коэффициент. |
lstm | Операция долгой краткосрочной памяти (LSTM) позволяет сети изучать долгосрочные зависимости между временными шагами в данных о последовательности и временных рядах. |
maxpool | Максимальная операция объединения выполняет субдискретизацию путем деления входа на объединение областей и вычисление максимального значения каждой области. |
maxunpool | Максимальная операция необъединения не объединяет выход максимальной операции объединения путем повышающей дискретизации и дополнения нулями. |
mse | Половина операции среднеквадратической ошибки вычисляет половину потери среднеквадратической ошибки между сетевыми предсказаниями и целевыми значениями для задач регрессии. |
onehotdecode | Одногорячая операция декодирования декодирует векторы вероятности, такие как выход сети классификации, в метки классификации. Вход |
relu | Исправленный линейный модуль (ReLU), операция активации выполняет нелинейную пороговую операцию, где любое входное значение меньше, чем нуль обнуляется. |
sigmoid | Сигмоидальная операция активации применяет сигмоидальную функцию к входным данным. |
softmax | softmax операция активации применяет функцию softmax к размерности канала входных данных. |
Когда вы используете пользовательский учебный цикл, необходимо вычислить потерю в функции градиентов модели. Используйте значение потерь когда вычислительные градиенты для обновления сетевых весов. Чтобы вычислить потерю, можно использовать следующие функции.
Функция | Описание |
---|---|
softmax | softmax операция активации применяет функцию softmax к размерности канала входных данных. |
sigmoid | Сигмоидальная операция активации применяет сигмоидальную функцию к входным данным. |
crossentropy | Перекрестная энтропийная операция вычисляет потерю перекрестной энтропии между сетевыми предсказаниями и целевыми значениями для задач классификации одно меток и мультиметок. |
huber | Операция Хубера вычисляет утрату Хубера между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равняется 1, это также известно как сглаженную потерю L1. |
mse | Половина операции среднеквадратической ошибки вычисляет половину потери среднеквадратической ошибки между сетевыми предсказаниями и целевыми значениями для задач регрессии. |
ctc | Операция CTC вычисляет потерю ассоциативной временной классификации (CTC) между невыровненными последовательностями. |
В качестве альтернативы можно использовать пользовательскую функцию потерь путем создания функции формы loss = myLoss(Y,T)
, где Y
и T
соответствуйте сетевым предсказаниям и целям, соответственно, и loss
возвращенная потеря.
Для примера, показывающего, как обучить порождающую соперничающую сеть (GAN), которая генерирует изображения с помощью пользовательской функции потерь, смотрите, Обучают Порождающую соперничающую сеть (GAN).
Когда вы обучаете модель глубокого обучения с пользовательским учебным циклом, программное обеспечение минимизирует потерю относительно настраиваемых параметров. Чтобы минимизировать потерю, программное обеспечение использует градиенты потери относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматического дифференцирования, необходимо задать функцию градиентов модели.
Для модели, заданной как dlnetwork
возразите, создайте функцию формы gradients = modelGradients(dlnet,dlX,T)
, где dlnet
сеть, dlX
сетевой вход, T
содержит цели и gradients
содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).
Для определенной функцией модели создайте функцию формы gradients = modelGradients(parameters,dlX,T)
, где parameters
содержит настраиваемые параметры, dlX
вход модели, T
содержит цели и gradients
содержит возвращенные градиенты. Опционально, можно передать дополнительные аргументы функции градиентов (например, если функция потерь запрашивает дополнительную информацию), или возвратите дополнительные аргументы (например, метрики для графического вывода процесса обучения).
Чтобы узнать больше об определении функций градиентов модели для пользовательских учебных циклов, смотрите Функцию Градиентов Модели Define для Пользовательского Учебного Цикла.
Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval
функция, которая выполняет функцию с автоматическим включенным дифференцированием. Для первого входа dlfeval
, передайте определенный функцией указатель функции градиентов модели. Для следующих входных параметров передайте необходимые переменные для функции градиентов модели. Для выходных параметров dlfeval
функционируйте, задайте те же выходные параметры как функция градиентов модели.
Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции.
Функция | Описание |
---|---|
adamupdate | Обновите параметры с помощью адаптивной оценки момента (Адам) |
rmspropupdate | Обновите параметры с помощью корневого среднеквадратического распространения (RMSProp) |
sgdmupdate | Обновите параметры с помощью стохастического градиентного спуска с импульсом (SGDM) |
dlupdate | Обновите параметры с помощью пользовательской функции |
dlarray
| dlfeval
| dlgradient
| dlnetwork