Этот пример показывает, как собрать несколько выходных сетей для предсказания.
Вместо использования dlnetwork
объект для предсказания, можно собрать сеть в DAGNetwork
готов к предсказанию с помощью assembleNetwork
функция. Это позволяет вам использовать predict
функция с другими типами данных, такими как datastores.
Загрузите параметры модели из файла MAT dlnetDigits.mat
. Файл MAT содержит dlnetwork
объект, который предсказывает как счета для категориальных меток, так и числовые углы поворота изображений цифр, и соответствующие имена классов.
s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;
Извлеките график слоев из dlnetwork
объект с использованием layerGraph
функция.
lgraph = layerGraph(dlnet);
График слоев не включает выходные слои. Добавьте слой классификации и слой регрессии к графику слоев с помощью addLayers
и connectLayers
функций.
layers = classificationLayer('Classes',classNames,'Name','coutput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'softmax','coutput'); layers = regressionLayer('Name','routput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'fc2','routput');
Просмотр графика сети.
figure plot(lgraph)
Соберите сеть с помощью assembleNetwork
функция.
net = assembleNetwork(lgraph)
net = DAGNetwork with properties: Layers: [19x1 nnet.cnn.layer.Layer] Connections: [19x2 table] InputNames: {'in'} OutputNames: {'coutput' 'routput'}
Загрузите тестовые данные.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
Чтобы делать предсказания с помощью собранной сети, используйте predict
функция. Чтобы вернуть категориальные метки для выхода классификации, установите 'ReturnCategorical'
опция для true
.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
Оцените точность классификации.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870
Оцените точность регрессии.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
6.0091
Просмотрите некоторые изображения с их предсказаниями. Отображение прогнозируемых углов красного цвета и правильных меток зеленого цвета.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = Y2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end
assembleNetwork
| batchNormalizationLayer
| convolution2dLayer
| fullyConnectedLayer
| predict
| reluLayer
| softmaxLayer