Анализируйте эффективность неглубокой нейронной сети после обучения

Эта тема представляет часть типичного рабочего процесса неглубокой нейронной сети. Для получения дополнительной информации и других шагов см. Multilayer Shallow Neural Networks and Backpropagation Training. Чтобы узнать о том, как контролировать процесс обучения, смотрите Monitor Deep Learning Training Progress.

Когда обучение в Train and Apply Multilayer Shallow Neural Networks завершено, можно проверить эффективность сети и определить, нужно ли вносить какие-либо изменения в процесс обучения, сетевую архитектуру или наборы данных. Сначала проверьте обучающую запись, tr, который был вторым аргументом, возвращенным из функции обучения.

tr
tr = struct with fields:
        trainFcn: 'trainlm'
      trainParam: [1x1 struct]
      performFcn: 'mse'
    performParam: [1x1 struct]
        derivFcn: 'defaultderiv'
       divideFcn: 'dividerand'
      divideMode: 'sample'
     divideParam: [1x1 struct]
        trainInd: [1x176 double]
          valInd: [1x38 double]
         testInd: [1x38 double]
            stop: 'Validation stop.'
      num_epochs: 9
       trainMask: {[1x252 double]}
         valMask: {[1x252 double]}
        testMask: {[1x252 double]}
      best_epoch: 3
            goal: 0
          states: {1x8 cell}
           epoch: [0 1 2 3 4 5 6 7 8 9]
            time: [1x10 double]
            perf: [1x10 double]
           vperf: [1x10 double]
           tperf: [1x10 double]
              mu: [1x10 double]
        gradient: [1x10 double]
        val_fail: [0 0 0 0 1 2 3 4 5 6]
       best_perf: 12.3078
      best_vperf: 16.6857
      best_tperf: 24.1796

Эта структура содержит всю информацию, касающуюся обучения сети. Для примера, tr.trainInd, tr.valInd и tr.testInd содержат индексы точек данных, которые использовались в наборах обучения, валидации и тестах, соответственно. Если вы хотите переобучить сеть с помощью того же деления данных, можно задать net.divideFcn на 'divideInd', net.divideParam.trainInd на tr.trainInd, net.divideParam.valInd на tr.valInd, net.divideParam.testInd на tr.testInd.

The tr структура также отслеживает несколько переменных в ходе обучения, таких как значение функции эффективности, величина градиента и т.д. Можно использовать обучающую запись для построения графика прогресса эффективности с помощью plotperf команда:

plotperf(tr)

Figure Training Record contains an axes. The axes with title Performance is 6.3064 contains 4 objects of type line. These objects represent Test, Validation, Train.

Свойство tr.best_epoch указывает итерацию, при которой эффективность валидации достигла минимума. Обучение продолжалось ещё 6 итераций, прежде чем обучение прекратилось.

Этот рисунок не указывает на какие-либо серьезные проблемы с обучением. Кривые валидации и тестирования очень похожи. Если бы тестовая кривая значительно увеличилась до увеличения валидационной кривой, то возможно, что могло произойти некоторое сверхподбор кривой.

Следующим шагом в проверке сети является создание регрессионного графика, который показывает связь между выходами сети и целями. Если бы обучение было идеальным, выходы сети и целевые показатели были бы точно равными, но отношения редко совершенны на практике. Для примера жира тела мы можем создать регрессионный график с помощью следующих команд. Первая команда вычисляет обученную сетевую характеристику на все входы в наборе данных. Следующие шесть команд извлекают выходы и цели, которые относятся к подмножествам обучения, валидации и тестирования. Итоговая команда создает три регрессионных графика для обучения, проверки и валидации.

bodyfatOutputs = net(bodyfatInputs);
trOut = bodyfatOutputs(tr.trainInd);
vOut = bodyfatOutputs(tr.valInd);
tsOut = bodyfatOutputs(tr.testInd);
trTarg = bodyfatTargets(tr.trainInd);
vTarg = bodyfatTargets(tr.valInd);
tsTarg = bodyfatTargets(tr.testInd);
plotregression(trTarg, trOut, 'Train', vTarg, vOut, 'Validation', tsTarg, tsOut, 'Testing')

Figure Regression (plotregression) contains 3 axes. Axes 1 with title Train: R=0.91107 contains 3 objects of type line. These objects represent Y = T, Fit, Data. Axes 2 with title Validation: R=0.8456 contains 3 objects of type line. These objects represent Y = T, Fit, Data. Axes 3 with title Testing: R=0.87068 contains 3 objects of type line. These objects represent Y = T, Fit, Data.

Три графика представляют данные обучения, валидации и проверки. Штриховая линия на каждом графике представляет идеальный результат - выходы = цели. Сплошная линия представляет оптимальную линейную регрессионую линию между выходами и целями. Значение R является показателем связи между выходами и целями. Если R = 1, это указывает, что существует точная линейная связь между выходами и целями. Если R близок к нулю, то нет линейной зависимости между выходами и целями.

В данном примере обучающие данные указывают на хорошую подгонку. Результаты валидации и тестирования также показывают большие значения R. Этот график поля точек помогает показать, что некоторые точки данных имеют плохую подгонку. Для примера в тестовом наборе есть точка данных, выход сети которой близок к 35, в то время как соответствующее целевое значение составляет около 12. Следующим шагом будет исследование этой точки данных, чтобы определить, представляет ли она экстраполяцию (т.е. находится ли она вне обучающего набора обучающих данных). Если это так, то он должен быть включен в набор обучающих данных, и должны быть собраны дополнительные данные для использования в тестовом наборе.

Улучшение результатов

Если сеть недостаточно точна, можно попробовать снова инициализировать сеть и обучение. Каждый раз, когда вы инициализируете сеть прямого распространения, параметры сети различны и могут привести к различным решениям.

net = init(net);
net = train(net, bodyfatInputs, bodyfatTargets);

В качестве второго подхода можно увеличить количество скрытых нейронов выше 20. Большее количество нейронов в скрытом слое дает сети большую гибкость, потому что сеть имеет больше параметров, которые она может оптимизировать. (Постепенно увеличивайте размер слоя. Если вы сделаете скрытый слой слишком большим, вы можете вызвать недостаточную характеристику проблемы, и сеть должна оптимизировать больше параметров, чем есть векторы данных, чтобы ограничить эти параметры.)

Третья опция - попробовать другую функцию обучения. Байесовское обучение регуляризации с trainbrдля примера иногда может привести к лучшим возможностям обобщения, чем использование ранней остановки.

Наконец, попробуйте использовать дополнительные обучающие данные. Предоставление дополнительных данных для сети с большей вероятностью создаст сеть, которая хорошо обобщает новые данные.