exponenta event banner

Определение пользовательских циклов обучения, функций потери и сетей

Для выполнения большинства задач глубокого обучения можно использовать предварительно подготовленную сеть и адаптировать ее к собственным данным. Пример, показывающий, как использовать transfer learning для переподготовки сверточной нейронной сети для классификации нового набора изображений, см. в разделе Train Deep Learning Network to Classify New Images. Кроме того, можно создавать и обучать сети с нуля с помощью layerGraph объекты с trainNetwork и trainingOptions функции.

Если trainingOptions функция не предоставляет возможности обучения, необходимые для выполнения задачи, после чего можно создать индивидуальный цикл обучения с помощью автоматического дифференцирования. Дополнительные сведения см. в разделе Определение сети глубокого обучения для пользовательских циклов обучения.

Если Deep Learning Toolbox™ не предоставляет слои, необходимые для выполнения задачи (включая выходные слои, определяющие функции потери), можно создать пользовательский слой. Дополнительные сведения см. в разделе Определение пользовательских слоев глубокого обучения. Для функций потери, которые не могут быть заданы с помощью выходного уровня, можно указать потерю в пользовательском учебном цикле. Дополнительные сведения см. в разделе Определение функций потерь. Для сетей, которые не могут быть созданы с помощью графиков слоев, можно определить пользовательские сети как функцию. Дополнительные сведения см. в разделе Определение сети как функции модели.

Дополнительные сведения о том, какой метод обучения использовать для какой задачи, см. в разделе Обучение модели глубокого обучения в MATLAB.

Определение сети глубокого обучения для пользовательских циклов обучения

Определить сеть как dlnetwork Объект

Для большинства задач можно управлять подробностями алгоритма обучения с помощью trainingOptions и trainNetwork функции. Если trainingOptions функция не предоставляет опции, необходимые для выполнения задачи (например, индивидуальный график обучения), после чего можно определить собственный индивидуальный цикл обучения с помощью dlnetwork объект. A dlnetwork объект позволяет обучить сеть, заданную как график слоев, с помощью автоматического дифференцирования.

Для сетей, указанных как график слоев, можно создать dlnetwork объект из графа слоев с помощью dlnetwork непосредственно функция.

dlnet = dlnetwork(lgraph);

Список слоев, поддерживаемых dlnetwork см. раздел «Поддерживаемые слои» dlnetwork страница. Пример обучения сети с использованием пользовательского графика обучения см. в разделе Обучение сети с использованием пользовательского цикла обучения.

Определение сети как функции модели

Для архитектур, которые не могут быть созданы с использованием графиков слоев (например, сиамская сеть, требующая общих весов), можно определить модель как функцию формы [dlY1,...,dlYM] = model(parameters,dlX1,...,dlXN), где parameters содержит параметры сети, dlX1,...,dlXN соответствует входным данным для N входные данные модели и dlY1,...,dlYM соответствует M выходные данные модели. Для обучения модели глубокого обучения, определенной как функция, используйте индивидуальный цикл обучения. Пример см. в разделе Сеть поездов с использованием функции модели.

При определении модели глубокого обучения как функции необходимо вручную инициализировать веса слоев. Дополнительные сведения см. в разделе Инициализация обучаемых параметров для функции модели.

Если пользовательская сеть определена как функция, функция модели должна поддерживать автоматическое дифференцирование. Можно использовать следующие операции глубокого обучения. Перечисленные здесь функции являются только подмножеством. Полный список функций, поддерживающих dlarray ввод, см. Список функций с поддержкой dlarray.

ФункцияОписание
avgpoolОперация среднего объединения выполняет понижающую дискретизацию путем разделения входных данных на области объединения и вычисления среднего значения каждой области.
batchnormОперация пакетной нормализации нормализует входные данные по всем наблюдениям для каждого канала независимо. Для ускорения обучения сверточной нейронной сети и снижения чувствительности к инициализации сети используйте пакетную нормализацию между сверткой и нелинейными операциями, такими как relu.
crossentropyОперация перекрестной энтропии вычисляет потери перекрестной энтропии между предсказаниями сети и целевыми значениями для задач классификации с одной и несколькими метками.
crosschannelnormОперация нормализации кросс-канала использует локальные ответы в различных каналах для нормализации каждой активации. Кросс-канальная нормализация обычно следует за relu операция. Кросс-канальная нормализация также известна как локальная нормализация ответа.
ctcОперация CTC вычисляет потери временной классификации соединения (CTC) между неориентированными последовательностями.
dlconvОперация свертки применяет скользящие фильтры к входным данным. Используйте dlconv функция для свертки глубокого обучения, свертки с группировкой и свертки с разделением по каналам.
dltranspconvТранспонированная операция свертки увеличивает число карт элементов.
embedОперация встраивания преобразует числовые индексы в числовые векторы, где индексы соответствуют дискретным данным. Встраивание используется для отображения дискретных данных, таких как категориальные значения или слова, в числовые векторы.
fullyconnectОперация полного соединения умножает входной сигнал на весовую матрицу и затем добавляет вектор смещения.
groupnormОперация групповой нормализации нормализует входные данные по сгруппированным подмножествам каналов для каждого наблюдения независимо. Для ускорения обучения сверточной нейронной сети и снижения чувствительности к инициализации сети используйте групповую нормализацию между сверткой и нелинейными операциями, такими как relu.
gruРабота стробируемого повторяющегося блока (ГРУ) позволяет сети узнать зависимости между временными шагами во временных рядах и данными последовательности.
huberОперация Huber вычисляет потери Huber между предсказаниями сети и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция 1, это также известно как плавная потеря L1.
instancenormОперация нормализации экземпляра нормализует входные данные по каждому каналу для каждого наблюдения независимо. Чтобы улучшить сходимость тренировки сверточной нейронной сети и снизить чувствительность к гиперпараметрам сети, используйте нормализацию экземпляра между сверткой и нелинейными операциями, такими как relu.
layernormОперация нормализации уровня нормализует входные данные по всем каналам для каждого наблюдения независимо. Чтобы ускорить обучение повторяющихся и многослойных нейронных сетей перцептрона и снизить чувствительность к инициализации сети, используйте нормализацию уровня после обучаемых операций, таких как LSTM, и полностью соединяйте операции.
leakyreluОперация активации выпрямленного линейного блока с утечкой (ReLU) выполняет нелинейную пороговую операцию, где любое входное значение меньше нуля умножается на фиксированный масштабный коэффициент.
lstmДлительная кратковременная память (LSTM) позволяет сети узнать долгосрочные зависимости между временными шагами во временных рядах и данными последовательности.
maxpoolОперация максимального объединения в пул выполняет понижающую дискретизацию путем разделения входных данных на области объединения и вычисления максимального значения каждой области.
maxunpoolОперация максимальной распаковки распаковывает выходные данные операции максимальной распаковки путем увеличения дискретизации и заполнения нулями.
mseОперация вычисления среднеквадратичной ошибки вычисляет среднеквадратичную потерю ошибки между предсказаниями сети и целевыми значениями для задач регрессии.
onehotdecode

Операция однократного декодирования декодирует векторы вероятности, такие как выходные данные сети классификации, в метки классификации.

Вход A может быть dlarray. Если A форматируется, функция игнорирует формат данных.

reluОперация активации выпрямленного линейного блока (ReLU) выполняет нелинейную пороговую операцию, где любое входное значение меньше нуля устанавливается равным нулю.
sigmoidОперация активации сигмоида применяет сигмоидальную функцию к входным данным.
softmaxОперация активации softmax применяет функцию softmax к размерности канала входных данных.

Определение функций потерь

При использовании пользовательского цикла обучения необходимо вычислить потери в функции градиентов модели. Используйте значение потерь при вычислении градиентов для обновления весов сети. Для вычисления потерь можно использовать следующие функции.

ФункцияОписание
softmaxОперация активации softmax применяет функцию softmax к размерности канала входных данных.
sigmoidОперация активации сигмоида применяет сигмоидальную функцию к входным данным.
crossentropyОперация перекрестной энтропии вычисляет потери перекрестной энтропии между предсказаниями сети и целевыми значениями для задач классификации с одной и несколькими метками.
huberОперация Huber вычисляет потери Huber между предсказаниями сети и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция 1, это также известно как плавная потеря L1.
mseОперация вычисления среднеквадратичной ошибки вычисляет среднеквадратичную потерю ошибки между предсказаниями сети и целевыми значениями для задач регрессии.
ctcОперация CTC вычисляет потери временной классификации соединения (CTC) между неориентированными последовательностями.

Кроме того, можно использовать пользовательскую функцию потерь, создав функцию формы loss = myLoss(Y,T), где Y и T соответствуют предсказаниям сети и целям, соответственно, и loss является возвращенной потерей.

Пример обучения генеративной состязательной сети (GAN), которая генерирует образы с помощью пользовательской функции потерь, см. в разделе Подготовка генеративной состязательной сети (GAN).

Обновление обучаемых параметров с помощью автоматической дифференциации

При обучении модели глубокого обучения с помощью пользовательского цикла обучения программное обеспечение минимизирует потери по отношению к обучаемым параметрам. Для минимизации потерь программное обеспечение использует градиенты потерь относительно обучаемых параметров. Чтобы рассчитать эти градиенты с помощью автоматического дифференцирования, необходимо определить функцию градиентов модели.

Определение функции градиентов модели

Для модели, указанной как dlnetwork объект, создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet - сеть, dlX - сетевой вход, T содержит целевые объекты, и gradients содержит возвращенные градиенты. Дополнительно можно передать дополнительные аргументы функции градиентов (например, если функция потерь требует дополнительной информации) или вернуть дополнительные аргументы (например, метрики для построения графика хода обучения).

Для модели, указанной как функция, создайте функцию формы gradients = modelGradients(parameters,dlX,T), где parameters содержит обучаемые параметры, dlX - входной сигнал модели, T содержит целевые объекты, и gradients содержит возвращенные градиенты. Дополнительно можно передать дополнительные аргументы функции градиентов (например, если функция потерь требует дополнительной информации) или вернуть дополнительные аргументы (например, метрики для построения графика хода обучения).

Дополнительные сведения об определении функций градиентов модели для пользовательских циклов обучения см. в разделе Определение функции градиентов модели для пользовательского цикла обучения.

Обновить обучаемые параметры

Чтобы оценить градиенты модели с помощью автоматического дифференцирования, используйте dlfeval функция, которая оценивает функцию с включенным автоматическим дифференцированием. Для первого ввода dlfevalпередайте функцию градиентов модели, заданную как дескриптор функции. Для следующих входных данных передайте требуемые переменные для функции градиентов модели. Для выходных данных dlfeval укажите те же выходные данные, что и функция градиентов модели.

Для обновления обучаемых параметров с помощью градиентов можно использовать следующие функции.

ФункцияОписание
adamupdateОбновление параметров с использованием адаптивной оценки момента (Адам)
rmspropupdateОбновление параметров с использованием среднеквадратичного распространения корня (RMSProp)
sgdmupdateОбновление параметров с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновление параметров с помощью пользовательской функции

См. также

| | |

Связанные темы