В этом примере показано, как собрать несколько выходных сетей для прогнозирования.
Вместо использования dlnetwork объект для прогнозирования, можно собрать сеть в DAGNetwork готов к прогнозированию с использованием assembleNetwork функция. Это позволяет использовать predict с другими типами данных, такими как хранилища данных.
Загрузка параметров модели из файла MAT dlnetDigits.mat. Файл MAT содержит dlnetwork объект, который предсказывает как оценки для категориальных меток и числовых углов поворота изображений цифр, так и соответствующие имена классов.
s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;Извлеките график слоев из dlnetwork с использованием layerGraph функция.
lgraph = layerGraph(dlnet);
График слоев не включает выходные слои. Добавление классификационного слоя и регрессионного слоя к графу слоев с помощью addLayers и connectLayers функции.
layers = classificationLayer('Classes',classNames,'Name','coutput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'softmax','coutput'); layers = regressionLayer('Name','routput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'fc2','routput');
Просмотр графика сети.
figure plot(lgraph)

Соберите сеть с помощью assembleNetwork функция.
net = assembleNetwork(lgraph)
net =
DAGNetwork with properties:
Layers: [19x1 nnet.cnn.layer.Layer]
Connections: [19x2 table]
InputNames: {'in'}
OutputNames: {'coutput' 'routput'}
Загрузите данные теста.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
Чтобы сделать прогнозы с использованием собранной сети, используйте predict функция. Чтобы вернуть категориальные метки для выходных данных классификации, установите значение 'ReturnCategorical' опция для true.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);Оцените точность классификации.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870
Оцените точность регрессии.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
6.0091
Просмотрите некоторые изображения с их прогнозами. Отобразите прогнозируемые углы красным цветом, а правильные метки - зеленым.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = Y2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end

assembleNetwork | batchNormalizationLayer | convolution2dLayer | fullyConnectedLayer | predict | reluLayer | softmaxLayer