В этом примере показано, как собрать кратное выходная сеть для предсказания.
Вместо того, чтобы использовать dlnetwork
объект для предсказания, можно собрать сеть в DAGNetwork
готовый к предсказанию с помощью assembleNetwork
функция. Это позволяет вам использовать predict
функция с другими типами данных, такими как хранилища данных.
Загрузите параметры модели из файла MAT dlnetDigits.mat
. Файл MAT содержит dlnetwork
объект, который предсказывает и музыку к категориальным меткам и числовые углы вращения изображений цифр и соответствующие имена классов.
s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;
Извлеките график слоев из dlnetwork
объект с помощью layerGraph
функция.
lgraph = layerGraph(dlnet);
График слоев не включает выходные слои. Добавьте слой классификации и слой регрессии к графику слоев с помощью addLayers
и connectLayers
функции.
layers = classificationLayer('Classes',classNames,'Name','coutput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'softmax','coutput'); layers = regressionLayer('Name','routput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'fc2','routput');
Просмотрите график сети.
figure plot(lgraph)
Соберите сеть с помощью assembleNetwork
функция.
net = assembleNetwork(lgraph)
net = DAGNetwork with properties: Layers: [19x1 nnet.cnn.layer.Layer] Connections: [19x2 table] InputNames: {'in'} OutputNames: {'coutput' 'routput'}
Загрузите тестовые данные.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
Чтобы сделать предсказания с помощью собранной сети, используйте predict
функция. Чтобы возвратить категориальные метки для классификации выход, установите 'ReturnCategorical'
опция к true
.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
Оцените точность классификации.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870
Оцените точность регрессии.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
6.0091
Просмотрите некоторые изображения с их предсказаниями. Отобразите предсказанные углы красного цвета и правильные метки зеленого цвета.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = Y2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end
assembleNetwork
| batchNormalizationLayer
| convolution2dLayer
| fullyConnectedLayer
| predict
| reluLayer
| softmaxLayer