exponenta event banner

Составлять прогнозы с помощью dlnetwork Объект

В этом примере показано, как делать прогнозы с помощью dlnetwork путем разделения данных на мини-пакеты.

Для больших наборов данных или при прогнозировании на оборудовании с ограниченной памятью сделайте прогнозы, разбив данные на мини-пакеты. При составлении прогнозов с помощью SeriesNetwork или DAGNetwork объекты, predict функция автоматически разбивает входные данные на мини-пакеты. Для dlnetwork объекты, необходимо разделить данные на мини-пакеты вручную.

Груз dlnetwork Объект

Загрузка обученного dlnetwork объект и соответствующие классы.

s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;

Загрузить данные для прогнозирования

Загрузите данные цифр для прогнозирования.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true);

Делать прогнозы

Закольцовывать мини-пакеты тестовых данных и делать прогнозы с использованием пользовательского цикла прогнозирования.

Использовать minibatchqueue для обработки и управления мини-партиями изображений. Укажите размер мини-пакета 128. Задайте для свойства read size хранилища данных образа размер мини-пакета.

Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определено в конце этого примера) для объединения данных в пакет и нормализации изображений.

  • Форматирование изображений с размерами 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single.

  • Сделайте прогнозы на GPU, если они доступны. По умолчанию minibatchqueue объект преобразует выходные данные в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds,...
    "MiniBatchSize",miniBatchSize,...
    "MiniBatchFcn", @preprocessMiniBatch,...
    "MiniBatchFormat","SSCB");

Закольцовывать мини-пакеты данных и делать прогнозы с помощью predict функция. Используйте onehotdecode для определения меток класса. Сохраните прогнозируемые метки класса.

numObservations = numel(imds.Files);
YPred = strings(1,numObservations);

predictions = [];

% Loop over mini-batches.
while hasdata(mbq)
    
    % Read mini-batch of data.
    dlX = next(mbq);
       
    % Make predictions using the predict function.
    dlYPred = predict(dlnet,dlX);
   
    % Determine corresponding classes.
    predBatch = onehotdecode(dlYPred,classes,1);
    predictions = [predictions predBatch];
  
end

Визуализируйте некоторые прогнозы.

idx = randperm(numObservations,9);

figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});    
    label = predictions(idx(i));
    imshow(I)
    title("Label: " + string(label))
  
end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные из входящего массива ячеек и объедините их в числовой массив. При конкатенации над четвертым размером к каждому изображению добавляется третий размер, который будет использоваться в качестве размера одиночного канала.

  2. Нормализация значений пикселов между 0 и 1.

function X = preprocessMiniBatch(data)    
    % Extract image data from cell and concatenate
    X = cat(4,data{:});
    
    % Normalize the images.
    X = X/255;
end

См. также

| | | |

Связанные темы