Сделайте предсказания Используя функцию модели

В этом примере показано, как сделать предсказания с помощью функции модели путем разделения данных в мини-пакеты.

Для больших наборов данных, или при предсказании на оборудовании с ограниченной памятью, делают предсказания путем разделения данных в мини-пакеты. При создании предсказаний с SeriesNetwork или DAGNetwork объекты, predict функционируйте автоматически разделяет входные данные в мини-пакеты. Для функций модели необходимо разделить данные в мини-пакеты вручную.

Создайте функцию модели и загрузите параметры

Загрузите параметры модели из файла MAT digitsMIMO.mat. Файл MAT содержит параметры модели в struct под названием parameters, состояние модели в struct под названием state, и имена классов в classNames.

s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;

Функциональный model модели, перечисленный в конце примера, задает модель, учитывая параметры модели и состояние.

Загрузите данные для предсказания

Загрузите данные о цифрах для предсказания.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

numObservations = numel(imds.Files);

Сделайте предсказания

Цикл по мини-пакетам тестовых данных и делает предсказания с помощью пользовательского цикла предсказания.

Используйте minibatchqueue обработать и управлять мини-пакетами изображений. Задайте мини-пакетный размер 128. Установите свойство размера чтения datastore изображений к мини-пакетному размеру.

Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы конкатенировать данные в пакет и нормировать изображения.

  • Отформатируйте изображения с размерностями 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single.

  • Сделайте предсказания на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds,...
    "MiniBatchSize",miniBatchSize,...
    "MiniBatchFcn", @preprocessMiniBatch,...
    "MiniBatchFormat","SSCB");

Цикл по мини-пакетам данных и делает предсказания с помощью predict функция. Используйте onehotdecode функция, чтобы определить метки класса. Сохраните предсказанные метки класса.

doTraining = false;

Y1Predictions = [];
Y2Predictions = [];

% Loop over mini-batches.
while hasdata(mbq)
    
    % Read mini-batch of data.
    dlX = next(mbq);
    
    % Make predictions using the predict function.
    [dlY1Pred,dlY2Pred] = model(parameters,dlX,doTraining,state);
    
    % Determine corresponding classes.
    Y1PredBatch = onehotdecode(dlY1Pred,classNames,1);
    Y1Predictions = [Y1Predictions Y1PredBatch];
    
    Y2PredBatch = extractdata(dlY2Pred);
    Y2Predictions = [Y2Predictions Y2PredBatch];

end

Просмотрите некоторые изображения с их предсказаниями.

idx = randperm(numObservations,9);
figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Predictions(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    hold off
    label = string(Y1Predictions(idx(i)));
    title("Label: " + label)
end

Функция модели

Функциональный model берет параметры модели parameters, входные данные dlX, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.

function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias,'Padding','same');

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same');

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
    
    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect, softmax (labels)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY1 = fullyconnect(dlY,weights,bias);
dlY1 = softmax(dlY1);

% Fully connect (angles)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
dlY2 = fullyconnect(dlY,weights,bias);

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.

  2. Нормируйте пиксельные значения между 0 и 1.

function X = preprocessMiniBatch(data)    
    % Extract image data from cell and concatenate
    X = cat(4,data{:});
    
    % Normalize the images.
    X = X/255;
end

Смотрите также

| | | | | | | | | |

Похожие темы