Задайте пользовательские циклы обучения, функции потерь и сети

Для большинства задач глубокого обучения можно использовать предварительно обученную сеть и адаптировать ее к собственным данным. Для примера, показывающего, как использовать передачу обучения для переобучения сверточной нейронной сети для классификации нового набора изображений, смотрите Train Нейронной сети для глубокого обучения для классификации новых изображений. Также можно создавать и обучать сети с нуля, используя layerGraph объекты с trainNetwork и trainingOptions функций.

Если trainingOptions функция не предоставляет опций обучения, которые вам нужны для вашей задачи, тогда можно создать пользовательский цикл обучения с помощью автоматической дифференциации. Дополнительные сведения см. в разделе «Определение нейронной сети для глубокого обучения для пользовательских циклов обучения».

Если Deep Learning Toolbox™ не предоставляет слои, необходимые для вашей задачи (включая выходные слои, которые задают функции потерь), то можно создать пользовательский слой. Дополнительные сведения см. в разделе Определение пользовательских слоев глубокого обучения. Для функций потерь, которые нельзя задать с помощью слоя выхода, можно задать потери в пользовательском цикле обучения. Дополнительные сведения см. в разделе «Задание функций потерь». Для сетей, которые не могут быть созданы с использованием графиков слоев, можно задать пользовательские сети как функцию. Дополнительные сведения см. в разделе Определение сети как функции модели.

Для получения дополнительной информации о том, какой метод обучения использовать для какой задачи, смотрите Обучите Модель Глубокого Обучения в MATLAB.

Задайте Нейронную сеть для глубокого обучения для пользовательских циклов обучения

Определите сеть как dlnetwork Объект

Для большинства задач можно управлять деталями алгоритма настройки, используя trainingOptions и trainNetwork функций. Если trainingOptions функция не предоставляет опции, необходимые для вашей задачи (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью dlnetwork объект. A dlnetwork объект позволяет вам обучать сеть, заданную как график слоев, используя автоматическую дифференциацию.

Для сетей, заданных в качестве графика слоев, можно создать dlnetwork объект из графика слоев при помощи dlnetwork функция непосредственно.

dlnet = dlnetwork(lgraph);

Список слоев, поддерживаемых dlnetwork объекты см. раздел «Поддерживаемые слои» dlnetwork страница. Пример, показывающий, как обучить сеть с пользовательским расписанием скорости обучения, см. в Train сети с использованием пользовательского цикла обучения.

Определите сеть как функцию модели

Для архитектур, которые не могут быть созданы с использованием графиков слоев (для примера, сиамской сети, которая требует общих весов), можно задать модель как функцию от формы [dlY1,...,dlYM] = model(parameters,dlX1,...,dlXN), где parameters содержит параметры сети, dlX1,...,dlXN соответствует входным данным для N входы модели и dlY1,...,dlYM соответствует M выходы модели. Чтобы обучить модель глубокого обучения, заданную как функция, используйте пользовательский цикл обучения. Для получения примера смотрите Train Network Using Модели Function.

Когда вы задаете модель глубокого обучения как функцию, необходимо вручную инициализировать веса слоев. Для получения дополнительной информации см. «Инициализация настраиваемых параметров для функции модели».

Если вы задаете пользовательскую сеть как функцию, то функция модели должна поддерживать автоматическую дифференциацию. Можно использовать следующие операции глубокого обучения. Перечисленные здесь функции являются только подмножеством. Полный список функций, поддерживающих dlarray вход, см. Список функций с поддержкой dlarray.

ФункцияОписание
avgpoolОперация среднего объединения выполняет понижающую дискретизацию, разделяя вход на области объединения и вычисляя среднее значение каждой области.
batchnormОперация нормализации партии . нормализует входные данные по всем наблюдениям для каждого канала независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшить чувствительность к инициализации сети, используйте нормализацию партии . между сверткой и нелинейными операциями, такими как relu.
crossentropyОперация перекрестной энтропии вычисляет потери перекрестной энтропии между предсказаниями и целевыми значениями для задач классификации с одной меткой и с мультиметками.
crosschannelnormОперация межканальной нормализации использует локальные отклики в разных каналах, чтобы нормализовать каждую активацию. Межканальная нормализация обычно выполняется relu операция. Межканальная нормализация также известна как локальная нормализация отклика.
ctcОперация CTC вычисляет потери коннекционистской временной классификации (CTC) между выровненными последовательностями.
dlconvОперация свертки применяет скользящие фильтры к входным данным. Используйте dlconv функция для свертки глубокого обучения, сгруппированной свертки и разделяемой по каналам свертки.
dltranspconvОперация транспонированной свертки повышает качество карты функций.
embedОперация embed преобразует числовые индексы в числовые векторы, где индексы соответствуют дискретным данным. Используйте вложения для сопоставления дискретных данных, таких как категориальные значения или слова, с числовыми векторами.
fullyconnectОперация полного соединения умножает вход на весовую матрицу и затем добавляет вектор смещения.
groupnormОперация нормализации группы нормализует входные данные между сгруппированными подмножествами каналов для каждого наблюдения независимо. Чтобы ускорить обучение сверточной нейронной сети и уменьшить чувствительность к инициализации сети, используйте нормализацию группы между свертками и нелинейными операциями, такими как relu.
gruОперация стробируемого рекуррентного модуля (GRU) позволяет сети изучать зависимости между временными шагами во временных рядах и данными последовательности.
huberОперация Huber вычисляет потери Huber между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равна 1, это также известно как гладкие потери L1.
instancenormОперация образца нормализации нормализует входные данные по каждому каналу для каждого наблюдения независимо. Чтобы улучшить сходимость настройки сверточной нейронной сети и уменьшить чувствительность к сетевым гиперпараметрам, используйте нормализацию образца между сверткой и нелинейными операциями, такими как relu.
layernormОперация нормализации слоя нормализует входные данные по всем каналам для каждого наблюдения независимо. Чтобы ускорить обучение рецидивирующих и многослойных нейронных сетей перцептрона и уменьшить чувствительность к инициализации сети, используйте нормализацию слоя после выученных операций, таких как LSTM и полностью соедините операции.
leakyreluОперация активации утечки выпрямленного линейного модуля (ReLU) выполняет нелинейную пороговую операцию, где любое входное значение, меньше нуля, умножается на фиксированный масштабный коэффициент.
lstmОперация долгой краткосрочной памяти (LSTM) позволяет сети изучать долгосрочные зависимости между временными шагами во временных рядах и данными последовательности.
maxpoolОперация максимального объединения выполняет понижающую дискретизацию, разделяя вход на области объединения и вычисляя максимальное значение каждой области.
maxunpoolМаксимальная операция отмены охлаждения отключает выход операции максимального объединения путем увеличения дискретизации и заполнения нулями.
mseОперация квадратичной невязки половинного среднего вычисляет половину средней потери квадратичной невязки между сетевыми предсказаниями и целевыми значениями для задач регрессии.
onehotdecode

Операция одноразового декодирования декодирует векторы вероятностей, такие как выход классификационной сети, в метки классификации.

Область входа A может быть dlarray. Если A форматирован, функция игнорирует формат данных.

reluОперация активации выпрямленного линейного модуля (ReLU) выполняет нелинейную пороговую операцию, где любое входное значение, меньше нуля, устанавливается в нуль.
sigmoidОперация активации сигмоида применяет функцию сигмоида к входным данным.
softmaxОперация активации softmax применяет функцию softmax к размерности канала входных данных.

Задайте функции потерь

Когда вы используете пользовательский цикл обучения, необходимо вычислить потери в функции градиентов модели. Используйте значение потерь при вычислении градиентов для обновления весов сети. Для вычисления потерь можно использовать следующие функции.

ФункцияОписание
softmaxОперация активации softmax применяет функцию softmax к размерности канала входных данных.
sigmoidОперация активации сигмоида применяет функцию сигмоида к входным данным.
crossentropyОперация перекрестной энтропии вычисляет потери перекрестной энтропии между предсказаниями и целевыми значениями для задач классификации с одной меткой и с мультиметками.
huberОперация Huber вычисляет потери Huber между сетевыми предсказаниями и целевыми значениями для задач регрессии. Когда 'TransitionPoint' опция равна 1, это также известно как гладкие потери L1.
mseОперация квадратичной невязки половинного среднего вычисляет половину средней потери квадратичной невязки между сетевыми предсказаниями и целевыми значениями для задач регрессии.
ctcОперация CTC вычисляет потери коннекционистской временной классификации (CTC) между выровненными последовательностями.

Также можно использовать пользовательскую функцию потерь путем создания функции формы loss = myLoss(Y,T), где Y и T соответствуют предсказаниям и целям, соответственно, и loss - возвращенная потеря.

Для примера, показывающего, как обучить генеративную состязательную сеть (GAN), которая генерирует изображения с помощью пользовательской функции потерь, смотрите Train Генеративной состязательной сети (GAN).

Обновляйте настраиваемые параметры с помощью автоматической дифференциации

Когда вы обучаете модель глубокого обучения с помощью пользовательского цикла обучения, программное обеспечение минимизирует потери относительно настраиваемых параметров. Чтобы минимизировать потери, программное обеспечение использует градиенты потерь относительно настраиваемых параметров. Чтобы вычислить эти градиенты с помощью автоматической дифференциации, необходимо задать функцию градиентов модели.

Задайте функцию градиентов модели

Для модели, заданной как dlnetwork создайте функцию формы gradients = modelGradients(dlnet,dlX,T), где dlnet является сетью, dlX - входной вход сети, T содержит цели и gradients содержит возвращенные градиенты. Опционально можно передать дополнительные аргументы в функцию gradients (для примера, если функция loss требует дополнительной информации), или вернуть дополнительные аргументы (для примера, метрики для графического изображения процесса обучения).

Для модели, заданной как функция, создайте функцию вида gradients = modelGradients(parameters,dlX,T), где parameters содержит настраиваемые параметры, dlX является входом модели, T содержит цели и gradients содержит возвращенные градиенты. Опционально можно передать дополнительные аргументы в функцию gradients (для примера, если функция loss требует дополнительной информации), или вернуть дополнительные аргументы (для примера, метрики для графического изображения процесса обучения).

Чтобы узнать больше об определении функций градиентов модели для пользовательских циклов обучения, смотрите Задать функцию градиентов модели для Пользовательского цикла обучения.

Обновление настраиваемых параметров

Чтобы вычислить градиенты модели с помощью автоматической дифференциации, используйте dlfeval функция, которая оценивает функцию с включенной автоматической дифференциацией. Для первого входа dlfeval, передайте функцию градиентов модели, заданную как указатель на функцию. Для следующих входов передайте необходимые переменные для функции градиентов модели. Для выходов dlfeval function, задайте те же выходы, что и функция градиентов модели.

Чтобы обновить настраиваемые параметры с помощью градиентов, можно использовать следующие функции.

ФункцияОписание
adamupdateОбновите параметры с помощью адаптивной оценки момента (Адам)
rmspropupdateОбновляйте параметры с помощью корневого среднего квадратного распространения (RMSProp)
sgdmupdateОбновляйте параметры с помощью стохастического градиентного спуска с импульсом (SGDM)
dlupdateОбновляйте параметры с помощью пользовательской функции

См. также

| | |

Похожие темы