Нормализуйте ошибки нескольких выходов

Наиболее распространенной функцией эффективности, используемой для обучения нейронных сетей, является средняя квадратичная невязка (mse). Однако с несколькими выходами, которые имеют различные области значений, обучение со средней квадратичной невязкой имеет тенденцию оптимизировать точность на выходном элементе с более широким диапазоном значений относительно выходного элемента с меньшей областью значений.

Например, здесь два целевых элемента имеют очень разные области значений:

x = -1:0.01:1;
t1 = 100*sin(x);
t2 = 0.01*cos(x);
t = [t1; t2];

Область области значений t1 составляет 200 (от минимум -100 до максимум 100), в то время как область значений t2 всего 0,02 (от -0,01 до 0,01). Область области значений t1 в 10 000 раз больше, чем область значений t2.

Если вы создаете и обучаете нейронную сеть на этом, чтобы минимизировать среднюю квадратичную невязку, обучение способствует относительной точности первого выходного элемента по сравнению со вторым.

net = feedforwardnet(5);
net1 = train(net,x,t);
y = net1(x);

Здесь можно увидеть, что сеть научилась очень хорошо соответствовать первому выходному элементу.

figure(1)
plot(x,y(1,:),x,t(1,:))

Figure contains an axes. The axes contains 2 objects of type line.

Однако функция второго элемента не подходит почти так же.

figure(2)
plot(x,y(2,:),x,t(2,:))

Figure contains an axes. The axes contains 2 objects of type line.

Чтобы одинаково хорошо подогнать оба выходных элемента в относительном смысле, установите normalization параметр эффективности в 'standard'. Это затем вычисляет ошибки для показателей эффективности, как если бы каждый выходной элемент имел область значений 2 (то есть, как если бы значения каждого выходного элемента варьировались от -1 до 1, вместо их различных областей значений).

net.performParam.normalization = 'standard';
net2 = train(net,x,t);
y = net2(x);

Теперь оба элемента выхода подгонки хорошо.

figure(3)
plot(x,y(1,:),x,t(1,:))

Figure contains an axes. The axes contains 2 objects of type line.

figure(4)
plot(x,y(2,:),x,t(2,:))

Figure contains an axes. The axes contains 2 objects of type line.

Для просмотра документации необходимо авторизоваться на сайте