exponenta event banner

Анализ неглубокой производительности нейронной сети после обучения

В этом разделе представлена часть типичного неглубокого рабочего процесса нейронной сети. Дополнительные сведения и другие шаги см. в разделе Многоуровневые неглубокие нейронные сети и обучение обратному распространению. Сведения о том, как контролировать ход обучения в глубоких условиях обучения, см. в разделе Мониторинг хода обучения в глубоких условиях обучения.

После завершения обучения по программе Train and Apply Multilayer Shallow Neural Networks можно проверить производительность сети и определить необходимость внесения каких-либо изменений в процесс обучения, архитектуру сети или наборы данных. Сначала проверьте запись обучения, tr, который был вторым аргументом, возвращенным из обучающей функции.

tr
tr = struct with fields:
        trainFcn: 'trainlm'
      trainParam: [1x1 struct]
      performFcn: 'mse'
    performParam: [1x1 struct]
        derivFcn: 'defaultderiv'
       divideFcn: 'dividerand'
      divideMode: 'sample'
     divideParam: [1x1 struct]
        trainInd: [1x176 double]
          valInd: [1x38 double]
         testInd: [1x38 double]
            stop: 'Validation stop.'
      num_epochs: 9
       trainMask: {[1x252 double]}
         valMask: {[1x252 double]}
        testMask: {[1x252 double]}
      best_epoch: 3
            goal: 0
          states: {1x8 cell}
           epoch: [0 1 2 3 4 5 6 7 8 9]
            time: [1x10 double]
            perf: [1x10 double]
           vperf: [1x10 double]
           tperf: [1x10 double]
              mu: [1x10 double]
        gradient: [1x10 double]
        val_fail: [0 0 0 0 1 2 3 4 5 6]
       best_perf: 12.3078
      best_vperf: 16.6857
      best_tperf: 24.1796

Эта структура содержит всю информацию, касающуюся обучения сети. Например, tr.trainInd, tr.valInd и tr.testInd содержат индексы точек данных, которые использовались в обучающих, валидационных и тестовых наборах соответственно. При необходимости переподготовки сети с использованием того же разделения данных можно установить net.divideFcn кому 'divideInd', net.divideParam.trainInd кому tr.trainInd, net.divideParam.valInd кому tr.valInd, net.divideParam.testInd кому tr.testInd.

tr структура также отслеживает несколько переменных во время обучения, таких как значение функции производительности, величина градиента и т.д. Запись обучения можно использовать для отображения хода выполнения с помощью plotperf команда:

plotperf(tr)

Figure Training Record contains an axes. The axes with title Performance is 6.3064 contains 4 objects of type line. These objects represent Test, Validation, Train.

Собственность tr.best_epoch указывает итерацию, при которой производительность проверки достигла минимума. Обучение продолжалось ещё 6 итераций, прежде чем обучение прекратилось.

Эта цифра не указывает на какие-либо серьезные проблемы с обучением. Кривые проверки и тестирования очень похожи. Если тестовая кривая значительно увеличилась до увеличения валидационной кривой, то возможно, что может произойти некоторое переоборудование.

Следующим шагом при проверке сети является создание регрессионного графика, который показывает взаимосвязь между выходами сети и целевыми объектами. Если бы обучение было идеальным, сетевые выходы и цели были бы точно равными, но на практике отношения редко бывают идеальными. Для примера телесного жира можно создать график регрессии с помощью следующих команд. Первая команда вычисляет ответ обученной сети на все входы в наборе данных. Следующие шесть команд извлекают выходные данные и целевые значения, относящиеся к подмножествам обучения, проверки и тестирования. Окончательная команда создает три графика регрессии для обучения, тестирования и проверки.

bodyfatOutputs = net(bodyfatInputs);
trOut = bodyfatOutputs(tr.trainInd);
vOut = bodyfatOutputs(tr.valInd);
tsOut = bodyfatOutputs(tr.testInd);
trTarg = bodyfatTargets(tr.trainInd);
vTarg = bodyfatTargets(tr.valInd);
tsTarg = bodyfatTargets(tr.testInd);
plotregression(trTarg, trOut, 'Train', vTarg, vOut, 'Validation', tsTarg, tsOut, 'Testing')

Figure Regression (plotregression) contains 3 axes. Axes 1 with title Train: R=0.91107 contains 3 objects of type line. These objects represent Y = T, Fit, Data. Axes 2 with title Validation: R=0.8456 contains 3 objects of type line. These objects represent Y = T, Fit, Data. Axes 3 with title Testing: R=0.87068 contains 3 objects of type line. These objects represent Y = T, Fit, Data.

Три графика представляют данные обучения, проверки и тестирования. Пунктирная линия на каждом графике представляет идеальный результат - выходы = цели. Сплошная линия представляет линию линейной регрессии наилучшего соответствия между выходами и целями. Значение R является показателем взаимосвязи между выходами и целями. Если R = 1, это указывает на наличие точной линейной зависимости между выходами и целями. Если R близок к нулю, то нет линейной зависимости между выходами и целями.

В этом примере данные обучения указывают на хорошую подгонку. Результаты проверки и тестирования также показывают большие значения R. График рассеяния помогает показать, что некоторые точки данных имеют плохие соответствия. Например, в тестовом наборе имеется точка данных, сетевой выход которой близок к 35, в то время как соответствующее целевое значение составляет около 12. Следующим шагом будет исследование этой точки данных, чтобы определить, представляет ли она экстраполяцию (т.е. находится ли она вне набора обучающих данных). Если это так, то он должен быть включен в обучающий набор, и должны быть собраны дополнительные данные для использования в тестовом наборе.

Улучшение результатов

Если сеть недостаточно точна, можно повторить попытку инициализации сети и обучения. Каждый раз, когда вы инициализируете сеть прямой связи, параметры сети различны и могут создавать различные решения.

net = init(net);
net = train(net, bodyfatInputs, bodyfatTargets);

В качестве второго подхода можно увеличить количество скрытых нейронов выше 20. Большее количество нейронов в скрытом слое дает сети больше гибкости, потому что сеть имеет больше параметров, которые она может оптимизировать. (Постепенно увеличивайте размер слоя. Если сделать скрытый слой слишком большим, проблема может быть недостаточно охарактеризована, и сеть должна оптимизировать больше параметров, чем существуют векторы данных для ограничения этих параметров.)

Третий вариант - попробовать другую функцию обучения. Обучение байесовской регуляризации с trainbrнапример, иногда может обеспечить лучшую способность к обобщению, чем использование ранней остановки.

Наконец, попробуйте использовать дополнительные данные обучения. Предоставление дополнительных данных для сети, скорее всего, приведет к созданию сети, которая хорошо обобщает новые данные.