Классификация PNN

Этот пример использует функции NEWPNN и SIM.

Вот три двухэлементных входных вектора X и связанных с ними классов Tc. Мы хотели бы создать y вероятностную нейронную сеть, которая классифицирует эти векторы правильно.

X = [1 2; 2 2; 1 1]';
Tc = [1 2 3];
plot(X(1,:),X(2,:),'.','markersize',30)
for i = 1:3, text(X(1,i)+0.1,X(2,i),sprintf('class %g',Tc(i))), end
axis([0 3 0 3])
title('Three vectors and their classes.')
xlabel('X(1,:)')
ylabel('X(2,:)')

Figure contains an axes. The axes with title Three vectors and their classes. contains 4 objects of type line, text.

Сначала мы преобразуем индексы целевого класса Tc в векторы T. Затем мы проектируем y вероятностную нейронную сеть с помощью NEWPNN. Мы используем y SPREAD значение 1, потому что это y типичное расстояние между входными векторами.

T = ind2vec(Tc);
spread = 1;
net = newpnn(X,T,spread);

Теперь тестируем сеть на проектируемых входных векторах. Мы делаем это, симулируя сеть и преобразуя ее векторные выходы в индексы.

Y = net(X);
Yc = vec2ind(Y);
plot(X(1,:),X(2,:),'.','markersize',30)
axis([0 3 0 3])
for i = 1:3,text(X(1,i)+0.1,X(2,i),sprintf('class %g',Yc(i))),end
title('Testing the network.')
xlabel('X(1,:)')
ylabel('X(2,:)')

Figure contains an axes. The axes with title Testing the network. contains 4 objects of type line, text.

Давайте классифицируем y нового вектора с нашей сетью.

x = [2; 1.5];
y = net(x);
ac = vec2ind(y);
hold on
plot(x(1),x(2),'.','markersize',30,'color',[1 0 0])
text(x(1)+0.1,x(2),sprintf('class %g',ac))
hold off
title('Classifying y new vector.')
xlabel('X(1,:) and x(1)')
ylabel('X(2,:) and x(2)')

Figure contains an axes. The axes with title Classifying y new vector. contains 6 objects of type line, text.

Эта схема показывает, как вероятностная нейронная сеть делит входное пространство на три класса.

x1 = 0:.05:3;
x2 = x1;
[X1,X2] = meshgrid(x1,x2);
xx = [X1(:) X2(:)]';
yy = net(xx);
yy = full(yy);
m = mesh(X1,X2,reshape(yy(1,:),length(x1),length(x2)));
m.FaceColor = [0 0.5 1];
m.LineStyle = 'none';
hold on
m = mesh(X1,X2,reshape(yy(2,:),length(x1),length(x2)));
m.FaceColor = [0 1.0 0.5];
m.LineStyle = 'none';
m = mesh(X1,X2,reshape(yy(3,:),length(x1),length(x2)));
m.FaceColor = [0.5 0 1];
m.LineStyle = 'none';
plot3(X(1,:),X(2,:),[1 1 1]+0.1,'.','markersize',30)
plot3(x(1),x(2),1.1,'.','markersize',30,'color',[1 0 0])
hold off
view(2)
title('The three classes.')
xlabel('X(1,:) and x(1)')
ylabel('X(2,:) and x(2)')

Figure contains an axes. The axes with title The three classes. contains 5 objects of type surface, line.