Соберите несколько - Выходная сеть для предсказания

В этом примере показано, как собрать кратное выходная сеть для предсказания.

Вместо того, чтобы использовать dlnetwork объект для предсказания, можно собрать сеть в DAGNetwork готовый к предсказанию с помощью assembleNetwork функция. Это позволяет вам использовать predict функция с другими типами данных, такими как хранилища данных.

Загрузите функцию модели и параметры

Загрузите параметры модели из файла MAT dlnetDigits.mat. Файл MAT содержит dlnetwork объект, который предсказывает и музыку к категориальным меткам и числовые углы вращения изображений цифр и соответствующие имена классов.

s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;

Соберите сеть для предсказания

Извлеките график слоев из dlnetwork объект с помощью layerGraph функция.

lgraph = layerGraph(dlnet);

График слоев не включает выходные слои. Добавьте слой классификации и слой регрессии к графику слоев с помощью addLayers и connectLayers функции.

layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'softmax','coutput');

layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc2','routput');

Просмотрите график сети.

figure
plot(lgraph)

Figure contains an axes. The axes contains an object of type graphplot.

Соберите сеть с помощью assembleNetwork функция.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}

Сделайте предсказания на новых данных

Загрузите тестовые данные.

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

Чтобы сделать предсказания с помощью собранной сети, используйте predict функция. Чтобы возвратить категориальные метки для классификации выход, установите 'ReturnCategorical' опция к true.

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);

Оцените точность классификации.

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870

Оцените точность регрессии.

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
    6.0091

Просмотрите некоторые изображения с их предсказаниями. Отобразите предсказанные углы красного цвета и правильные метки зеленого цвета.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end

Figure contains 9 axes. Axes 1 with title Label: 8 contains 3 objects of type image, line. Axes 2 with title Label: 9 contains 3 objects of type image, line. Axes 3 with title Label: 1 contains 3 objects of type image, line. Axes 4 with title Label: 9 contains 3 objects of type image, line. Axes 5 with title Label: 6 contains 3 objects of type image, line. Axes 6 with title Label: 0 contains 3 objects of type image, line. Axes 7 with title Label: 2 contains 3 objects of type image, line. Axes 8 with title Label: 5 contains 3 objects of type image, line. Axes 9 with title Label: 9 contains 3 objects of type image, line.

Смотрите также

| | | | | |

Похожие темы