Сборка сети с несколькими выходами для предсказания

Этот пример показывает, как собрать несколько выходных сетей для предсказания.

Вместо использования dlnetwork объект для предсказания, можно собрать сеть в DAGNetwork готов к предсказанию с помощью assembleNetwork функция. Это позволяет вам использовать predict функция с другими типами данных, такими как datastores.

Загрузка функции модели и параметров

Загрузите параметры модели из файла MAT dlnetDigits.mat. Файл MAT содержит dlnetwork объект, который предсказывает как счета для категориальных меток, так и числовые углы поворота изображений цифр, и соответствующие имена классов.

s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;

Сборка сети для предсказания

Извлеките график слоев из dlnetwork объект с использованием layerGraph функция.

lgraph = layerGraph(dlnet);

График слоев не включает выходные слои. Добавьте слой классификации и слой регрессии к графику слоев с помощью addLayers и connectLayers функций.

layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'softmax','coutput');

layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc2','routput');

Просмотр графика сети.

figure
plot(lgraph)

Figure contains an axes. The axes contains an object of type graphplot.

Соберите сеть с помощью assembleNetwork функция.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}

Делайте предсказания на новых данных

Загрузите тестовые данные.

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

Чтобы делать предсказания с помощью собранной сети, используйте predict функция. Чтобы вернуть категориальные метки для выхода классификации, установите 'ReturnCategorical' опция для true.

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);

Оцените точность классификации.

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870

Оцените точность регрессии.

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
    6.0091

Просмотрите некоторые изображения с их предсказаниями. Отображение прогнозируемых углов красного цвета и правильных меток зеленого цвета.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end

Figure contains 9 axes. Axes 1 with title Label: 8 contains 3 objects of type image, line. Axes 2 with title Label: 9 contains 3 objects of type image, line. Axes 3 with title Label: 1 contains 3 objects of type image, line. Axes 4 with title Label: 9 contains 3 objects of type image, line. Axes 5 with title Label: 6 contains 3 objects of type image, line. Axes 6 with title Label: 0 contains 3 objects of type image, line. Axes 7 with title Label: 2 contains 3 objects of type image, line. Axes 8 with title Label: 5 contains 3 objects of type image, line. Axes 9 with title Label: 9 contains 3 objects of type image, line.

См. также

| | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте