train

Обучите неглубокую нейронную сеть

Описание

Эта функция обучает мелкую нейронную сеть. Для глубокого обучения со сверточными или LSTM нейронными сетями смотрите trainNetwork вместо этого.

пример

trainedNet = train(net,X,T,Xi,Ai,EW) обучает сетевую net согласно net.trainFcn и net.trainParam.

[trainedNet,tr] = train(net,X,T,Xi,Ai,EW) также возвращает обучающую запись.

пример

[trainedNet,tr] = train(net,X,T,Xi,Ai,EW,Name,Value) обучает сеть с дополнительными опциями, заданными одним или несколькими аргументами пары "имя-значение".

Примеры

свернуть все

Здесь введите x и целевые t задайте простую функцию, которую можно построить на графике:

x = [0 1 2 3 4 5 6 7 8];
t = [0 0.84 0.91 0.14 -0.77 -0.96 -0.28 0.66 0.99];
plot(x,t,'o')

Вот feedforwardnet создает двухслойную сеть прямого распространения. В сети есть один скрытый слой с десятью нейронами.

net = feedforwardnet(10);
net = configure(net,x,t);
y1 = net(x)
plot(x,t,'o',x,y1,'x')

Сеть обучается и затем повторяется.

net = train(net,x,t);
y2 = net(x)
plot(x,t,'o',x,y1,'x',x,y2,'*')

Этот пример обучает разомкнутой нелинейно-авторегрессивной сети с внешним входом, чтобы смоделировать левитированную магнитную систему, заданную управляющим током x и ответ вертикального положения магнита t, затем моделирует сеть. Функция preparets подготавливает данные перед обучением и симуляцией. Он создает объединенные входы разомкнутого контура сети xo, который содержит оба внешних входных x и предыдущие значения позиционных t. Он также подготавливает состояния задержки xi.

[x,t] = maglev_dataset;
net = narxnet(10);
[xo,xi,~,to] = preparets(net,x,{},t);
net = train(net,xo,to,xi);
y = net(xo,xi)

Эта же система также может быть моделирована в форме с обратной связью.

netc = closeloop(net);
view(netc)
[xc,xi,ai,tc] = preparets(netc,x,{},t);
yc = netc(xc,xi,ai);

Parallel Computing Toolbox™ позволяет Deep Learning Toolbox™ моделировать и обучать сети быстрее и на больших наборах данных, чем это возможно на одном ПК. Параллельное обучение в настоящее время поддерживается только для обучения backpropagation, а не для самоорганизующихся карт.

Здесь обучение и симуляция происходят между параллельными работниками MATLAB.

parpool
[X,T] = vinyl_dataset;
net = feedforwardnet(10);
net = train(net,X,T,'useParallel','yes','showResources','yes');
Y = net(X);

Используйте Композитные значения, чтобы распределить данные вручную и получить результаты как Составные значения. Если данные загружаются по мере их распределения, то в то время как каждый фрагмент набора данных должен помещаться в ОЗУ, набор данных в целом ограничивается только общей ОЗУ всех работников.

[X,T] = vinyl_dataset;
Q = size(X,2);
Xc = Composite;
Tc = Composite;
numWorkers = numel(Xc);
ind = [0 ceil((1:numWorkers)*(Q/numWorkers))];
for i=1:numWorkers
    indi = (ind(i)+1):ind(i+1);
    Xc{i} = X(:,indi);
    Tc{i} = T(:,indi);
end
net = feedforwardnet;
net = configure(net,X,T);
net = train(net,Xc,Tc);
Yc = net(Xc);

Обратите внимание, что в примере выше конфигурация функции использовалась, чтобы задать размерности и настройки обработки входов сети. Обычно это происходит автоматически при вызове train, но при предоставлении составных данных этот шаг должен выполняться вручную с данными, отличными от составных.

Обучать сети можно с помощью текущего графического процессора, если оно поддерживается Parallel Computing Toolbox. Обучение графический процессор в настоящее время поддерживается только для обучения backpropagation, а не для самоорганизующихся карт.

[X,T] = vinyl_dataset;
net = feedforwardnet(10);
net = train(net,X,T,'useGPU','yes');
y = net(X); 

Чтобы поместить данные на графический процессор вручную:

[X,T] = vinyl_dataset;
Xgpu = gpuArray(X);
Tgpu = gpuArray(T);
net = configure(net,X,T);
net = train(net,Xgpu,Tgpu);
Ygpu = net(Xgpu);
Y = gather(Ygpu); 

Обратите внимание, что в примере выше конфигурация функции использовалась, чтобы задать размерности и настройки обработки входов сети. Обычно это происходит автоматически при вызове train, но при предоставлении данных gpuArray этот шаг должен выполняться вручную с данными, отличными от gpuArray.

Чтобы выполнять параллельно, каждый рабочий процесс назначается другому уникальному графическому процессору с дополнительными рабочими процессорами, работающими на центральном процессоре:

net = train(net,X,T,'useParallel','yes','useGPU','yes');
y = net(X);

Использование только рабочих процессов с уникальными графическими процессорами может привести к более высокой скорости, поскольку центральный процессор рабочие могут не успевать.

net = train(net,X,T,'useParallel','yes','useGPU','only');
Y = net(X);

Здесь сеть обучается с контрольными точками, сохраненными со скоростью не более одного раза в две минуты.

[x,t] = vinyl_dataset;
net = fitnet([60 30]);
net = train(net,x,t,'CheckpointFile','MyCheckpoint','CheckpointDelay',120);

После сбоя компьютера последнюю сеть можно восстановить и использовать, чтобы продолжить обучение с точки отказа. Файл контрольной точки включает структурную переменную checkpoint, который включает в себя сеть, обучающую запись, имя файла, время и номер.

[x,t] = vinyl_dataset;
load MyCheckpoint
net = checkpoint.net;
net = train(net,x,t,'CheckpointFile','MyCheckpoint');

Другое использование функции контрольной точки - когда вы останавливаете параллельный сеанс обучения (начинается с 'UseParallel' параметр), хотя Neural Network Training Tool недоступен во время параллельного обучения. В этом случае установите 'CheckpointFile', используйте Ctrl + C, чтобы остановить обучение в любое время, а затем загрузите файл контрольной точки, чтобы получить сетевую и обучающую запись.

Входные параметры

свернуть все

Входная сеть, заданная как network объект. Как создать network объект, использование для примера, feedforwardnet или narxnet.

Входы сети, заданные как R-by- Q матрица или Ni-by- TS массив ячеек, где

  • R - размер входа

  • Q - размер пакета;

  • Ni = net.numInputs

  • TS количество временных шагов

train аргументы могут иметь два формата: матрицы для статических задач и сетей с одинарными входами и выходами и массивы ячеек для нескольких временных интервалов и сетей с несколькими входами и выходами.

  • Матричный формат может использоваться, если необходимо моделировать только один временной шаг (TS = 1). Он удобен для сетей только с одним входом и выходом, но может использоваться с сетями, которые имеют больше. Когда сеть имеет несколько входов, размер матрицы (сумма Ri) -by- Q.

  • Формат массива ячеек является более общим и более удобным для сетей с несколькими входами и выходами, позволяя представлять последовательности входов. Каждый элемент X{i,ts} является Ri-by- Q матрица, где Ri = net.inputs{i}.size.

Если используются составные данные, то 'useParallel' автоматически устанавливается на 'yes'. Функция принимает Композитные данные и возвращает Композитные результаты.

Если используются данные gpuArray, то 'useGPU' автоматически устанавливается на 'yes'. Функция принимает данные gpuArray и возвращает результаты gpuArray

Примечание

Если столбец X содержит хотя бы один NaN, train не использует этот столбец для обучения, проверки или валидации. Если целевое значение в T является NaN, затем train игнорирует эту строку и использует другие строки для обучения, проверки или валидации.

Сетевые цели, заданные как U-by- Q матрица или No-by- TS массив ячеек, где

  • U - размер выхода

  • Q - размер пакета;

  • No = net.numOutputs

  • TS количество временных шагов

train аргументы могут иметь два формата: матрицы для статических задач и сетей с одинарными входами и выходами и массивы ячеек для нескольких временных интервалов и сетей с несколькими входами и выходами.

  • Матричный формат может использоваться, если необходимо моделировать только один временной шаг (TS = 1). Он удобен для сетей только с одним входом и выходом, но может использоваться с сетями, которые имеют больше. Когда сеть имеет несколько входов, размер матрицы (сумма Ui) -by- Q.

  • Формат массива ячеек является более общим и более удобным для сетей с несколькими входами и выходами, позволяя представлять последовательности входов. Каждый элемент T{i,ts} является Ui-by- Q матрица, где Ui = net.outputs{i}.size.

Если используются составные данные, то 'useParallel' автоматически устанавливается на 'yes'. Функция принимает Композитные данные и возвращает Композитные результаты.

Если используются данные gpuArray, то 'useGPU' автоматически устанавливается на 'yes'. Функция принимает данные gpuArray и возвращает результаты gpuArray

Обратите внимание, что T является необязательным и должен использоваться только для сетей, которые требуют целевых объектов.

Примечание

Любой NaN значения во входах X или цели T, рассматриваются как отсутствующие данные. Если столбец X или T содержит, по крайней мере, один NaN, этот столбец не используется для обучения, проверки или валидации.

Начальные условия задержки входа, заданные как Ni-by- ID массив ячеек или R-by- (ID*Q) матрица, где

  • ID = net.numInputDelays

  • Ni = net.numInputs

  • R - размер входа

  • Q - размер пакета;

Для входа массива ячеек, столбцы Xi заказываются из самого старого условия задержки в самое последнее: Xi{i,k} является ли вход i во время ts = k - ID.

Xi также является необязательным и должен использоваться только для сетей с задержками входного сигнала или слоя.

Начальные условия задержки слоя, заданные как Nl-by- LD массив ячеек или a (сумма Si) -by- (LD*Q) матрица, где

  • Nl = net.numLayers

  • LD = net.numLayerDelays

  • Si = net.layers{i}.size

  • Q - размер пакета;

Для входа массива ячеек, столбцы Ai заказываются из самого старого условия задержки в самое последнее: Ai{i,k} - выходной параметр слоя i во время ts = k - LD.

Веса ошибок, заданные как No-by- TS массив ячеек или a (сумма Ui) -by- Q матрица, где

  • No = net.numOutputs

  • TS количество временных шагов

  • Ui = net.outputs{i}.size

  • Q - размер пакета;

Для входов массива ячеек. каждый элемент EW{i,ts} является Ui-by- Q матрица, где

  • Ui = net.outputs{i}.size

  • Q - размер пакета;

Веса ошибок EW может также иметь размер 1 вместо всего или любого из No, TS, Ui или Q. В этом случае EW автоматически расширяется размер, чтобы соответствовать целям T. Это позволяет удобно взвешивать важность в любой размерности (такой как по выборке), имея равную важность для другой (такой как время, с TS=1). Если все размерности равны 1, например, если EW = {1}затем все целевые значения обрабатываются с одинаковой важностью. Это значение по умолчанию EW.

Как отмечалось выше, веса ошибок EW может иметь те же размерности, что и цели T, или иметь некоторые размерности, установленные на 1. Для образца, если EW 1-by- Q, тогда целевые выборки будут иметь разный импорт, но каждый элемент в выборке будет иметь ту же важность. Если EW is (сумма Ui) -by-1, тогда каждый выходной элемент имеет разную важность, со всеми выборками, обработанными с одинаковой важностью.

Аргументы в виде пар имя-значение

Задайте необязательные разделенные разделенными запятой парами Name,Value аргументы. Name - имя аргумента и Value - соответствующее значение. Name должны находиться внутри кавычек. Можно задать несколько аргументов в виде пар имен и значений в любом порядке Name1,Value1,...,NameN,ValueN.

Пример: 'useParallel','yes'

Опция для задания параллельных вычислений, заданная как 'yes' или 'no'.

  • 'no' - Вычисления выполняются на обычном потоке MATLAB. Это значение по умолчанию 'useParallel' настройка.

  • 'yes' - Вычисления выполняются параллельными рабочими, если открыт параллельный пул. В противном случае вычисления происходят на нормальном MATLAB® поток.

Опция для задания вычислений графический процессор, заданная как 'yes', 'no', или 'only'.

  • 'no' - Расчеты выполняются на центральный процессор. Это значение по умолчанию 'useGPU' настройка.

  • 'yes' - Вычисления происходят на текущей gpuDevice если это поддерживаемый графический процессор (См. Parallel Computing Toolbox для требований к графическому процессору). Если текущий gpuDevice не поддерживается, вычисления остаются на центральный процессор. Если 'useParallel' также 'yes' и параллельный пул открыт, затем каждый рабочий процесс с уникальным графическим процессором использует этот графический процессор, другие рабочие выполняют вычисления на своих соответствующих центральных процессорах ядрах.

  • 'only' - Если параллельный пул не открыт, этот параметр аналогичен 'yes'. Если параллельный пул открыт, используются только работники с уникальными графическими процессорами. Однако, если параллельный пул открыт, но поддерживаемые графические процессоры недоступны, вычисления возвращаются к выполнению на всех рабочих центральных процессорах.

Опция для отображения ресурсов, заданная как 'yes' или 'no'.

  • 'no' - Не отображать вычислительные ресурсы, используемые в командной строке. Это значение по умолчанию.

  • 'yes' - Показать в командной строке сводные данные фактически используемых вычислительных ресурсов. Фактические ресурсы могут отличаться от запрашиваемых ресурсов, если запрашиваются параллельные вычисления или вычисления на графическом процессоре, но параллельный пул не открыт или поддерживаемый графический процессор недоступен. Когда используются параллельные работники, описывается режим расчетов каждого работника, включая работников пула, которые не используются.

Сокращение памяти, заданное как положительное целое число.

Для большинства нейронных сетей режимом расчета центрального процессора по умолчанию является скомпилированный алгоритм MEX. Однако для больших сетей вычисления могут происходить в режиме вычисления MATLAB. Это может быть подтверждено с помощью 'showResources'. Если MATLAB используется, а память является проблемой, установка значения опции сокращения на N, больше 1, уменьшает большую часть временного хранилища, необходимого для обучения в N раз, в обмен на более длительное время обучения.

Файл контрольной точки, заданный как вектор символов.

Значение для 'CheckpointFile' можно задать имя файла для сохранения в текущей рабочей папке, путь к файлу в другой папке или пустую строку для отключения сохранений контрольных точек (значение по умолчанию).

Задержка контрольной точки, заданная как неотрицательное целое число.

Опциональный параметр 'CheckpointDelay' пределы, как часто происходит сохранение. Ограничение частоты контрольных точек может улучшить эффективность путем поддержания количества времени сохранения контрольных точек низким по сравнению со временем, потраченным в вычислениях. Оно имеет значение по умолчанию 60, что означает, что сохранение контрольных точек не происходит более одного раза в минуту. Установите значение 'CheckpointDelay' 0, если необходимо, чтобы сохранение контрольных точек происходило только один раз в каждую эпоху.

Выходные аргументы

свернуть все

Обученная сеть, возвращается как network объект.

Обучающая запись (epoch и perf), возвращается как структура, поля которой зависят от функции сетевого обучения (net.NET.trainFcn). Он может включать такие поля, как:

  • Обучение, деление данных и эффективность функции и параметры

  • Индексы деления данных для наборов для обучения, валидации и тестирования

  • Маски деления данных для валидации обучения и тестирования наборов

  • Количество эпох (num_epochs) и лучшая эпоха (best_epoch).

  • Список имен состояний обучения (states).

  • Поля для каждого имени состояния, регистрирующие его значение на протяжении всего обучения

  • Выступления лучшей сети (best_perf, best_vperf, best_tperf)

Алгоритмы

train вызывает функцию, обозначенную net.trainFcn, с использованием значений параметров, обозначенных net.trainParam.

Обычно одна эпоха обучения определяется как одно представление всех входных векторов в сеть. Затем сеть обновляется в соответствии с результатами всех этих презентаций.

Обучение происходит до тех пор, пока не произойдет максимальное количество эпох, не будет достигнута эффективность цель или любое другое условие остановки функции net.trainFcn происходит.

Некоторые функции обучения отходят от этой нормы, представляя только один входной вектор (или последовательность) каждую эпоху. Вектор входа (или последовательность) выбирается случайным образом для каждой эпохи из параллельных векторов входа (или последовательностей). competlayer возвращает сети, которые используют trainru, обучающая функция, которая делает это.

См. также

| | |

Представлено до R2006a