Вместо того, чтобы использовать функцию модели в предсказании, можно собрать сеть в DAGNetwork готовый к предсказанию с помощью functionToLayerGraph и assembleNetwork функции. Это позволяет вам использовать predict функция.
Загрузите параметры модели из файла MAT digitsMIMO.mat. Файл MAT содержит параметры модели в struct под названием parameters, состояние модели в struct под названием state, и имена классов в classNames.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;Функциональный model модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.
Задайте анонимную функцию с фиксированным набором параметров модели, состояния модели, и установите doTraining опция к false.
doTraining = false; fun = @(dlX) model(dlX,parameters,doTraining,state);
Преобразуйте функцию модели в график слоев с помощью functionToLayerGraph функция. Создайте переменную dlX это содержит мини-пакет данных с ожидаемым форматом.
X = rand(28,28,1,128,'single'); dlX = dlarray(X,'SSCB'); lgraph = functionToLayerGraph(fun,dlX);
График слоев выводится functionToLayerGraph функция не включает входные и выходные слои. Добавьте входной слой, слой классификации и слой регрессии к графику слоев с помощью addLayers и connectLayers функции.
layers = imageInputLayer([28 28 1],'Name','input','Normalization','none'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'input','conv_1'); layers = classificationLayer('Classes',classNames,'Name','coutput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'sm_1','coutput'); layers = regressionLayer('Name','routput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'fc_1','routput');
Просмотрите график сети.
figure plot(lgraph)

Соберите сеть с помощью assembleNetwork функция.
net = assembleNetwork(lgraph)
net =
DAGNetwork with properties:
Layers: [18×1 nnet.cnn.layer.Layer]
Connections: [18×2 table]
InputNames: {'input'}
OutputNames: {'coutput' 'routput'}
Загрузите тестовые данные.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
Чтобы сделать предсказания с помощью собранной сети, используйте predict функция. Чтобы возвратить категориальные метки для классификации выход, установите 'ReturnCategorical' опция к true.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);Оцените точность классификации.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9644
Оцените точность регрессии.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
5.8081
Просмотрите некоторые изображения с их предсказаниями. Отобразите предсказанные углы красного цвета и правильные метки зеленого цвета.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = Y2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end

Модель функции берет входные данные dlX, параметры модели parameters, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state) % Convolution W = parameters.conv1.Weights; B = parameters.conv1.Bias; dlY = dlconv(dlX,W,B,'Padding',2); % Batch normalization, ReLU Offset = parameters.batchnorm1.Offset; Scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) W = parameters.convSkip.Weights; B = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,W,B,'Stride',2); Offset = parameters.batchnormSkip.Offset; Scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); end % Convolution W = parameters.conv2.Weights; B = parameters.conv2.Bias; dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2); % Batch normalization, ReLU Offset = parameters.batchnorm2.Offset; Scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution W = parameters.conv3.Weights; B = parameters.conv3.Bias; dlY = dlconv(dlY,W,B,'Padding',1); % Batch normalization Offset = parameters.batchnorm3.Offset; Scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect (angles) W = parameters.fc1.Weights; B = parameters.fc1.Bias; dlY2 = fullyconnect(dlY,W,B); % Fully connect, softmax (labels) W = parameters.fc2.Weights; B = parameters.fc2.Bias; dlY1 = fullyconnect(dlY,W,B); dlY1 = softmax(dlY1); end
assembleNetwork | batchnorm | dlarray | dlconv | fullyconnect | functionToLayerGraph | predict | relu | softmax