Соберите несколько - Выходная сеть для прогноза

Вместо того, чтобы использовать функцию модели в прогнозе, можно собрать сеть в DAGNetwork готовый к прогнозу с помощью functionToLayerGraph и assembleNetwork функции. Это позволяет вам использовать predict функция.

Загрузите функцию модели и параметры

Загрузите параметры модели из файла MAT digitsMIMO.mat. Файл MAT содержит параметры модели в struct под названием parameters, состояние модели в struct под названием state, и имена классов в classNames.

s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;

Функциональный model модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.

Соберите сеть для прогноза

Задайте анонимную функцию с фиксированным набором параметров модели, состояния модели, и установите doTraining опция к false.

doTraining = false;
fun = @(dlX) model(dlX,parameters,doTraining,state);

Преобразуйте функцию модели в график слоя с помощью functionToLayerGraph функция. Создайте переменную dlX это содержит мини-пакет данных с ожидаемым форматом.

X = rand(28,28,1,128,'single');
dlX = dlarray(X,'SSCB');
lgraph = functionToLayerGraph(fun,dlX);

График слоя выводится functionToLayerGraph функция не включает слои ввода и вывода. Добавьте входной слой, слой классификации и слой регрессии к графику слоя с помощью addLayers и connectLayers функции.

layers = imageInputLayer([28 28 1],'Name','input','Normalization','none');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'input','conv_1');
layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'sm_1','coutput');
layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc_1','routput');

Просмотрите график сети.

figure
plot(lgraph)

Соберите сеть с помощью assembleNetwork функция.

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [17×1 nnet.cnn.layer.Layer]
    Connections: [17×2 table]
     InputNames: {'input'}
    OutputNames: {'coutput'  'routput'}

Сделайте прогнозы на новых данных

Загрузите тестовые данные.

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

Чтобы сделать прогнозы с помощью собранной сети, используйте predict функция. Чтобы возвратить категориальные метки для классификации выход, установите 'ReturnCategorical' опция к true.

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);

Оцените точность классификации.

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9954

Оцените точность регрессии.

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
    5.6085

Просмотрите некоторые изображения с их прогнозами. Отобразите предсказанные углы красного цвета и правильные метки зеленого цвета.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end

Функция модели

Модель функции берет входные данные dlX, параметры модели parameters, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или прогноза и сетевого state состояния. Сетевые выходные параметры прогнозы для меток, прогнозы для углов и обновленное сетевое состояние.

function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

% Convolution
W = parameters.conv1.Weights;
B = parameters.conv1.Bias;
dlY = dlconv(dlX,W,B,'Padding',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm1.Offset;
Scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution (Skip connection)
W = parameters.convs.Weights;
B = parameters.convs.Bias;
YSkip = dlconv(dlY,W,B,'Stride',2);

% Convolution
W = parameters.conv2.Weights;
B = parameters.conv2.Bias;
dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm2.Offset;
Scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution
W = parameters.conv3.Weights;
B = parameters.conv3.Bias;
dlY = dlconv(dlY,W,B,'Padding',1);

% Batch normalization, ReLU
Offset = parameters.batchnorm3.Offset;
Scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Addition
dlY = YSkip + dlY;

% Fully connect (angles)
W = parameters.fc1.Weights;
B = parameters.fc1.Bias;
dlY2 = fullyconnect(dlY,W,B);

% Fully connect, softmax (labels)
W = parameters.fc2.Weights;
B = parameters.fc2.Bias;
dlY1 = fullyconnect(dlY,W,B);
dlY1 = softmax(dlY1);

end

Смотрите также

| | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте