В этом примере показано, как делать предсказания с помощью функции модели путем разделения данных на мини-пакеты.
Для больших наборов данных или при прогнозировании на оборудовании с ограниченной памятью делайте предсказания, разделяя данные на мини-пакеты. При выполнении предсказаний с SeriesNetwork
или DAGNetwork
объекты, predict
функция автоматически разделяет входные данные на мини-пакеты. Для функций модели необходимо разделить данные на мини-пакеты вручную.
Загрузите параметры модели из файла MAT digitsMIMO.mat
. Файл MAT содержит параметры модели в struct с именем parameters
, состояние модели в struct с именем state
и имена классов в classNames
.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
Функция модели model
, перечисленный в конце примера, определяет модель, задавая параметры модели и состояние.
Загрузите данные цифр для предсказания.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); numObservations = numel(imds.Files);
Закольцовывайте мини-пакеты тестовых данных и делайте прогнозы с помощью пользовательского цикла предсказания.
Использование minibatchqueue
для обработки и управления мини-пакетами изображений. Задайте мини-пакет размером 128. Установите для свойства read size изображения datastore значение mini-batch.
Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера), чтобы объединить данные в пакет и нормализовать изображения.
Отформатируйте изображения с помощью размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
.
Делайте предсказания на графическом процессоре, если он доступен. По умолчанию, minibatchqueue
объект преобразует выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
Закольцовывайте минибатчи данных и делайте предсказания, используя predict
функция. Используйте onehotdecode
функция для определения меток классов. Сохраните предсказанные метки классов.
doTraining = false; Y1Predictions = []; Y2Predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. [dlY1Pred,dlY2Pred] = model(parameters,dlX,doTraining,state); % Determine corresponding classes. Y1PredBatch = onehotdecode(dlY1Pred,classNames,1); Y1Predictions = [Y1Predictions Y1PredBatch]; Y2PredBatch = extractdata(dlY2Pred); Y2Predictions = [Y2Predictions Y2PredBatch]; end
Просмотрите некоторые изображения с их предсказаниями.
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Predictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') hold off label = string(Y1Predictions(idx(i))); title("Label: " + label) end
Функция model
принимает параметры модели parameters
, входные данные dlX
, флаг doTraining
который определяет, должна ли модель возвращать выходы для обучения или предсказания, и состояние сети state
. Сеть выводит предсказания для меток, предсказания для углов и обновленное состояние сети.
function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state) % Convolution weights = parameters.conv1.Weights; bias = parameters.conv1.Bias; dlY = dlconv(dlX,weights,bias,'Padding','same'); % Batch normalization, ReLU offset = parameters.batchnorm1.Offset; scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) weights = parameters.convSkip.Weights; bias = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,weights,bias,'Stride',2); offset = parameters.batchnormSkip.Offset; scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); end % Convolution weights = parameters.conv2.Weights; bias = parameters.conv2.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2); % Batch normalization, ReLU offset = parameters.batchnorm2.Offset; scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution weights = parameters.conv3.Weights; bias = parameters.conv3.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same'); % Batch normalization offset = parameters.batchnorm3.Offset; scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect, softmax (labels) weights = parameters.fc1.Weights; bias = parameters.fc1.Bias; dlY1 = fullyconnect(dlY,weights,bias); dlY1 = softmax(dlY1); % Fully connect (angles) weights = parameters.fc2.Weights; bias = parameters.fc2.Bias; dlY2 = fullyconnect(dlY,weights,bias); end
The preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные из входящего массива ячеек и соедините в числовой массив. Конкатенация по четвертой размерности добавляет третью размерность к каждому изображению, чтобы использоваться в качестве размерности синглтонного канала.
Нормализуйте значения пикселей между 0
и 1
.
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
batchnorm
| dlarray
| dlconv
| dlfeval
| dlgradient
| fullyconnect
| minibatchqueue
| onehotdecode
| relu
| sgdmupdate
| softmax