В этом примере показано, как сделать предсказания с помощью функции модели путем разделения данных в мини-пакеты.
Для больших наборов данных, или при предсказании на оборудовании с ограниченной памятью, делают предсказания путем разделения данных в мини-пакеты. При создании предсказаний с SeriesNetwork
или DAGNetwork
объекты, predict
функционируйте автоматически разделяет входные данные в мини-пакеты. Для функций модели необходимо разделить данные в мини-пакеты вручную.
Загрузите параметры модели из файла MAT digitsMIMO.mat
. Файл MAT содержит параметры модели в struct под названием parameters
, состояние модели в struct под названием state
, и имена классов в classNames
.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
Функциональный model
модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.
Загрузите данные о цифрах для предсказания.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); numObservations = numel(imds.Files);
Цикл по мини-пакетам тестовых данных и делает предсказания с помощью пользовательского цикла предсказания.
Используйте minibatchqueue
обработать и управлять мини-пакетами изображений. Задайте мини-пакетный размер 128. Установите свойство размера чтения datastore изображений к мини-пакетному размеру.
Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера), чтобы конкатенировать данные в пакет и нормировать изображения.
Отформатируйте изображения с размерностями 'SSCB'
(пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
.
Сделайте предсказания на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue
объект преобразует выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
Цикл по мини-пакетам данных и делает предсказания с помощью predict
функция. Используйте onehotdecode
функция, чтобы определить метки класса. Сохраните предсказанные метки класса.
doTraining = false; Y1Predictions = []; Y2Predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. [dlY1Pred,dlY2Pred] = model(parameters,dlX,doTraining,state); % Determine corresponding classes. Y1PredBatch = onehotdecode(dlY1Pred,classNames,1); Y1Predictions = [Y1Predictions Y1PredBatch]; Y2PredBatch = extractdata(dlY2Pred); Y2Predictions = [Y2Predictions Y2PredBatch]; end
Просмотрите некоторые изображения с их предсказаниями.
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Predictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') hold off label = string(Y1Predictions(idx(i))); title("Label: " + label) end
Функциональный model
берет параметры модели parameters
, входные данные dlX
, флаг doTraining
который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state
состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.
function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state) % Convolution weights = parameters.conv1.Weights; bias = parameters.conv1.Bias; dlY = dlconv(dlX,weights,bias,'Padding','same'); % Batch normalization, ReLU offset = parameters.batchnorm1.Offset; scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) weights = parameters.convSkip.Weights; bias = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,weights,bias,'Stride',2); offset = parameters.batchnormSkip.Offset; scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); end % Convolution weights = parameters.conv2.Weights; bias = parameters.conv2.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2); % Batch normalization, ReLU offset = parameters.batchnorm2.Offset; scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution weights = parameters.conv3.Weights; bias = parameters.conv3.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same'); % Batch normalization offset = parameters.batchnorm3.Offset; scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect, softmax (labels) weights = parameters.fc1.Weights; bias = parameters.fc1.Bias; dlY1 = fullyconnect(dlY,weights,bias); dlY1 = softmax(dlY1); % Fully connect (angles) weights = parameters.fc2.Weights; bias = parameters.fc2.Bias; dlY2 = fullyconnect(dlY,weights,bias); end
preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.
Нормируйте пиксельные значения между 0
и 1
.
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
batchnorm
| dlarray
| dlconv
| dlfeval
| dlgradient
| fullyconnect
| minibatchqueue
| onehotdecode
| relu
| sgdmupdate
| softmax