В этом примере показано, как сделать прогнозы с помощью функции модели путем разделения данных в мини-пакеты.
Для больших наборов данных, или при предсказании на оборудовании с ограниченной памятью, делают прогнозы путем разделения данных в мини-пакеты. При создании прогнозов с SeriesNetwork
или DAGNetwork
объекты, predict
функционируйте автоматически разделяет входные данные в мини-пакеты. Для функций модели необходимо разделить данные в мини-пакеты вручную.
Загрузите параметры модели из файла MAT digitsMIMO.mat
. Файл MAT содержит параметры модели в struct под названием parameters
, состояние модели в struct под названием state
, и имена классов в classNames
.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
Функциональный model
модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.
Загрузите данные о цифрах для прогноза.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
Цикл по мини-пакетам тестовых данных и делает прогнозы с помощью пользовательского цикла прогноза.
Для каждого мини-пакета:
Преобразуйте данные в dlarray
объекты с базовым одним типом и указывают, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Сделайте прогнозы на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.
Сделайте прогнозы путем вызывания функции модели с doTraining
набор опции к false
.
Определите метки класса путем нахождения максимальных баллов.
miniBatchSize = 128; executionEnvironment = "auto"; doTraining = false; imds.ReadSize = miniBatchSize; numObservations = numel(imds.Files); Y1Pred = strings(1,numObservations); Y2Pred = zeros(1,numObservations); i = 1; % Loop over mini-batches. while hasdata(imds) % Read mini-batch of data. data = read(imds); X = cat(4,data{:}); % Normalize the images. X = single(X)/255; % Convert mini-batch of data to dlarray. dlX = dlarray(X,'SSCB'); % If making predictions on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Make predictions using the predict function. [dlY1Pred,dlY2Pred] = model(dlX,parameters,doTraining,state); % Determine corresponding classes. [~,idxTop] = max(extractdata(dlY1Pred),[],1); idxMiniBatch = i:min((i+miniBatchSize-1),numObservations); Y1Pred(idxMiniBatch) = classNames(idxTop); Y2Pred(idxMiniBatch) = gather(extractdata(dlY2Pred)); i = i + miniBatchSize; end
Просмотрите некоторые изображения с их прогнозами.
idx = randperm(numel(imds.Files),9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Pred(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') hold off label = string(Y1Pred(idx(i))); title("Label: " + label) end
Модель функции берет входные данные dlX
, параметры модели parameters
, флаг doTraining
который задает, должен ли к модели возвратить выходные параметры для обучения или прогноза и сетевого state
состояния. Сетевые выходные параметры прогнозы для меток, прогнозы для углов и обновленное сетевое состояние.
function [dlY1,dlY2,state] = model(dlX,parameters,doTraining,state) % Convolution W = parameters.conv1.Weights; B = parameters.conv1.Bias; dlY = dlconv(dlX,W,B,'Padding',2); % Batch normalization, ReLU Offset = parameters.batchnorm1.Offset; Scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) W = parameters.convSkip.Weights; B = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,W,B,'Stride',2); Offset = parameters.batchnormSkip.Offset; Scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,Offset,Scale,trainedMean,trainedVariance); end % Convolution W = parameters.conv2.Weights; B = parameters.conv2.Bias; dlY = dlconv(dlY,W,B,'Padding',1,'Stride',2); % Batch normalization, ReLU Offset = parameters.batchnorm2.Offset; Scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution W = parameters.conv3.Weights; B = parameters.conv3.Bias; dlY = dlconv(dlY,W,B,'Padding',1); % Batch normalization Offset = parameters.batchnorm3.Offset; Scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,Offset,Scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect (angles) W = parameters.fc1.Weights; B = parameters.fc1.Bias; dlY2 = fullyconnect(dlY,W,B); % Fully connect, softmax (labels) W = parameters.fc2.Weights; B = parameters.fc2.Bias; dlY1 = fullyconnect(dlY,W,B); dlY1 = softmax(dlY1); end
batchnorm
| dlarray
| dlconv
| dlfeval
| dlgradient
| fullyconnect
| relu
| sgdmupdate
| softmax