Вместо того, чтобы использовать функцию модели в предсказании, можно собрать сеть в DAGNetwork
готовый к предсказанию с помощью functionToLayerGraph
и assembleNetwork
функции. Это позволяет вам использовать predict
функция.
Загрузите параметры модели из файла MAT digitsMIMO.mat
. Файл MAT содержит параметры модели в struct под названием parameters
, состояние модели в struct под названием state
, и имена классов в classNames
.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
Функциональный model
модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.
Задайте анонимную функцию с фиксированным набором параметров модели, состояния модели, и установите doTraining
опция к false
.
doTraining = false; fun = @(dlX) model(dlX,parameters,doTraining,state);
Преобразуйте функцию модели в график слоев с помощью functionToLayerGraph
функция. Создайте переменную dlX
это содержит мини-пакет данных с ожидаемым форматом.
X = rand(28,28,1,128,'single'); dlX = dlarray(X,'SSCB'); lgraph = functionToLayerGraph(fun,dlX);
График слоев выводится functionToLayerGraph
функция не включает входные и выходные слои. Добавьте входной слой, слой классификации и слой регрессии к графику слоев с помощью addLayers
и connectLayers
функции.
layers = imageInputLayer([28 28 1],'Name','input','Normalization','none'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'input','conv_1'); layers = classificationLayer('Classes',classNames,'Name','coutput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'sm_1','coutput'); layers = regressionLayer('Name','routput'); lgraph = addLayers(lgraph,layers); lgraph = connectLayers(lgraph,'fc_1','routput');
Просмотрите график сети.
figure plot(lgraph)
Соберите сеть с помощью assembleNetwork
функция.
net = assembleNetwork(lgraph)
net = DAGNetwork with properties: Layers: [18×1 nnet.cnn.layer.Layer] Connections: [18×2 table] InputNames: {'input'} OutputNames: {'coutput' 'routput'}
Загрузите тестовые данные.
[XTest,Y1Test,Y2Test] = digitTest4DArrayData;
Чтобы сделать предсказания с помощью собранной сети, используйте predict
функция. Чтобы возвратить категориальные метки для классификации выход, установите 'ReturnCategorical'
опция к true
.
[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);
Оцените точность классификации.
accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9644
Оцените точность регрессии.
angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
5.8081
Просмотрите некоторые изображения с их предсказаниями. Отобразите предсказанные углы красного цвета и правильные метки зеленого цвета.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') thetaValidation = Y2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end
Модель функции берет входные данные dlX, параметры модели parameters
, флаг doTraining
который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state
состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state) % Convolution W = parameters.conv1.Weights; B = parameters.conv1.Bias; dlY = dlconv(dlX,W,B,'Padding',2); % Batch normalization, ReLU Offset = parameters.batchnorm1.Offset; Scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) W = parameters.convSkip.Weights; B = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,W,B,'Stride',2); Offset = parameters.batchnormSkip.Offset; Scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); end % Convolution W = parameters.conv2.Weights; B = parameters.conv2.Bias; dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2); % Batch normalization, ReLU Offset = parameters.batchnorm2.Offset; Scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution W = parameters.conv3.Weights; B = parameters.conv3.Bias; dlY = dlconv(dlY,W,B,'Padding',1); % Batch normalization Offset = parameters.batchnorm3.Offset; Scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect (angles) W = parameters.fc1.Weights; B = parameters.fc1.Bias; dlY2 = fullyconnect(dlY,W,B); % Fully connect, softmax (labels) W = parameters.fc2.Weights; B = parameters.fc2.Bias; dlY1 = fullyconnect(dlY,W,B); dlY1 = softmax(dlY1); end
assembleNetwork
| batchnorm
| dlarray
| dlconv
| fullyconnect
| functionToLayerGraph
| predict
| relu
| softmax