Сделайте прогнозы Используя функцию модели

В этом примере показано, как сделать прогнозы с помощью функции модели путем разделения данных в мини-пакеты.

Для больших наборов данных, или при предсказании на оборудовании с ограниченной памятью, делают прогнозы путем разделения данных в мини-пакеты. При создании прогнозов с SeriesNetwork или DAGNetwork объекты, predict функционируйте автоматически разделяет входные данные в мини-пакеты. Для функций модели необходимо разделить данные в мини-пакеты вручную.

Создайте функцию модели и загрузите параметры

Загрузите параметры модели из файла MAT digitsMIMO.mat. Файл MAT содержит параметры модели в struct под названием parameters, состояние модели в struct под названием state, и имена классов в classNames.

s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;

Функциональный model модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.

Загрузите данные для прогноза

Загрузите данные о цифрах для прогноза.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Сделайте прогнозы

Цикл по мини-пакетам тестовых данных и делает прогнозы с помощью пользовательского цикла прогноза.

Для каждого мини-пакета:

  • Преобразуйте данные в dlarray объекты с базовым одним типом и указывают, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Сделайте прогнозы на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

  • Сделайте прогнозы путем вызывания функции модели с doTraining набор опции к false.

  • Определите метки класса путем нахождения максимальных баллов.

miniBatchSize = 128;
executionEnvironment = "auto";
doTraining = false;

imds.ReadSize = miniBatchSize;
numObservations = numel(imds.Files);
Y1Pred = strings(1,numObservations);
Y2Pred = zeros(1,numObservations);
i = 1;

% Loop over mini-batches.
while hasdata(imds)
    
    % Read mini-batch of data.
    data = read(imds);
    X = cat(4,data{:});
    
    % Normalize the images.
    X = single(X)/255;
    
    % Convert mini-batch of data to dlarray.
    dlX = dlarray(X,'SSCB');
    
    % If making predictions on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX = gpuArray(dlX);
    end
    
    % Make predictions using the predict function.
    [dlY1Pred,dlY2Pred] = model(dlX,parameters,doTraining,state);
    
    % Determine corresponding classes.
    [~,idxTop] = max(extractdata(dlY1Pred),[],1);
    idxMiniBatch = i:min((i+miniBatchSize-1),numObservations);
    Y1Pred(idxMiniBatch) = classNames(idxTop);
    
    Y2Pred(idxMiniBatch) = gather(extractdata(dlY2Pred));
    
    i = i + miniBatchSize;
end

Просмотрите некоторые изображения с их прогнозами.

idx = randperm(numel(imds.Files),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
        
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end

Функция модели

Модель функции берет входные данные dlX, параметры модели parameters, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или прогноза и сетевого state состояния. Сетевые выходные параметры прогнозы для меток, прогнозы для углов и обновленное сетевое состояние.

function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state)

% Convolution
W = parameters.conv1.Weights;
B = parameters.conv1.Bias;
dlY = dlconv(dlX,W,B,'Padding',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm1.Offset;
Scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
W = parameters.convSkip.Weights;
B = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,W,B,'Stride',2);

Offset = parameters.batchnormSkip.Offset;
Scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance);
end

% Convolution
W = parameters.conv2.Weights;
B = parameters.conv2.Bias;
dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2);

% Batch normalization, ReLU
Offset = parameters.batchnorm2.Offset;
Scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end
dlY = relu(dlY);

% Convolution
W = parameters.conv3.Weights;
B = parameters.conv3.Bias;
dlY = dlconv(dlY,W,B,'Padding',1);

% Batch normalization
Offset = parameters.batchnorm3.Offset;
Scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect (angles)
W = parameters.fc1.Weights;
B = parameters.fc1.Bias;
dlY2 = fullyconnect(dlY,W,B);

% Fully connect, softmax (labels)
W = parameters.fc2.Weights;
B = parameters.fc2.Bias;
dlY1 = fullyconnect(dlY,W,B);
dlY1 = softmax(dlY1);

end

Смотрите также

| | | | | | | |

Похожие темы