Обучите классификационную сеть для классификации объекта в 3-D облаке точек

Этот пример демонстрирует подход, описанный в [1], в котором данные облака точек предварительно обрабатываются в вокселированную кодировку, а затем используются непосредственно с простой 3-D сверточной архитектурой нейронной сети для выполнения классификации объектов. В более поздних подходах, таких как [2], кодирования данных облака точек могут быть более сложными и могут быть изучены кодировки, которые обучаются сквозь пальцы наряду с сетью, выполняющей задачу классификации/обнаружения объектов/сегментации. Однако общий шаблон перехода от нерегулярных неупорядоченных точек к сетчатой структуре, которая может подаваться в конвнеты, остается одинаковым во всех этих оценках.

Импорт и анализ данных

В этом примере мы работаем с Sydney Urban Objects Dataset. В этом примере мы используем складки 1-3 из данных в качестве набора обучающих данных и складку 4 в качестве набора валидации.

dataPath = downloadSydneyUrbanObjects(tempdir);
dsTrain = loadSydneyUrbanObjectsData(dataPath,[1 2 3]);
dsVal = loadSydneyUrbanObjectsData(dataPath,4);

Проанализируйте набор обучающих данных, чтобы понять метки, существующие в данных и общем распределении меток.

dsLabels = transform(dsTrain,@(data) data{2});
labels = readall(dsLabels);
figure
histogram(labels)

Из гистограммы очевидно, что существует проблема дисбаланса классов в обучающих данных, в которых некоторые классы объектов любят Car и Pedestrian встречаются гораздо чаще, чем менее частые классы, такие как Ute.

Трубопровод увеличения данных

Чтобы избежать избыточной подгонки и добавить робастность классификатору, некоторое количество рандомизированного увеличения данных, как правило, является хорошей идеей при обучении сети. Функции randomAffine2d и pctransform облегчают определение рандомизированных аффинных преобразований на данных облака точек. Кроме того, мы добавляем некоторое рандомизированное дрожание по точкам к каждой точке в каждом облаке точек. Функция augmentPointCloudData включено в раздел вспомогательных функций ниже.

dsTrain = transform(dsTrain,@augmentPointCloudData);

Проверьте, что увеличение количества данных облака точек выглядит разумно.

dataOut = preview(dsTrain);
figure
pcshow(dataOut{1});
title(dataOut{2});

Далее мы добавляем простое преобразование вокселизации к каждому входу облака точек как обсуждалось в предыдущем примере, чтобы преобразовать наш вход облако точек в псевдоизображение, которое может использоваться со сверточной нейронной сетью. Используйте простую сетку заполнения.

dsTrain = transform(dsTrain,@formOccupancyGrid);
dsVal = transform(dsVal,@formOccupancyGrid);

Исследуйте выборку последнего вокселизированного тома, который мы будем отправлять в сеть, чтобы убедиться, что вокселиксация работает правильно.

data = preview(dsTrain);
figure
p = patch(isosurface(data{1},0.5));
p.FaceColor = 'red';
p.EdgeColor = 'none';
daspect([1 1 1])
view(45,45)
camlight; 
lighting phong
title(data{2});

Определите сетевую архитектуру

В этом примере мы используем простую архитектуру классификации 3-D, как описано в [1].

layers = [image3dInputLayer([32 32 32],'Name','inputLayer','Normalization','none'),...
    convolution3dLayer(5,32,'Stride',2,'Name','Conv1'),...
    leakyReluLayer(0.1,'Name','leakyRelu1'),...
    convolution3dLayer(3,32,'Stride',1,'Name','Conv2'),...
    leakyReluLayer(0.1,'Name','leakyRulu2'),...
    maxPooling3dLayer(2,'Stride',2,'Name','maxPool'),...
    fullyConnectedLayer(128,'Name','fc1'),...
    reluLayer('Name','relu'),...
    dropoutLayer(0.5,'Name','dropout1'),...
    fullyConnectedLayer(14,'Name','fc2'),...
    softmaxLayer('Name','softmax'),...
    classificationLayer('Name','crossEntropyLoss')];

voxnet = layerGraph(layers);
figure
plot(voxnet);

Setup опций обучения

Используйте стохастический градиентный спуск с импульсом с кусочно-линейной корректировкой графика скорости обучения. Этот пример выполнялся на графическом процессоре TitanX, для графических процессоров с меньшим объемом памяти может потребоваться уменьшить размер пакета. Хотя 3D конвейеры имеют преимущество концептуальной простоты, они имеют недостаток больших объемов использования памяти во время обучения.

miniBatchSize = 32;
dsLength = length(dsTrain.UnderlyingDatastore.Files);
iterationsPerEpoch = floor(dsLength/miniBatchSize);
dropPeriod = floor(8000/iterationsPerEpoch);

options = trainingOptions('sgdm','InitialLearnRate',0.01,'MiniBatchSize',miniBatchSize,...
    'LearnRateSchedule','Piecewise',...
    'LearnRateDropPeriod',dropPeriod,...
    'ValidationData',dsVal,'MaxEpochs',60,...
    'DispatchInBackground',false,...
    'Shuffle','never');

Обучите сеть

voxnet = trainNetwork(dsTrain,voxnet,options);
Training on single GPU.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        9.38% |       20.65% |       2.6408 |       2.6300 |          0.0100 |
|       4 |          50 |       00:00:25 |       31.25% |       29.03% |       2.2892 |       2.2954 |          0.0100 |
|       8 |         100 |       00:00:45 |       37.50% |       37.42% |       1.9256 |       2.0372 |          0.0100 |
|      12 |         150 |       00:01:05 |       53.12% |       47.10% |       1.6398 |       1.7396 |          0.0100 |
|      16 |         200 |       00:01:24 |       43.75% |       55.48% |       1.9551 |       1.5172 |          0.0100 |
|      20 |         250 |       00:01:44 |       40.62% |       61.29% |       1.7413 |       1.3598 |          0.0100 |
|      24 |         300 |       00:02:04 |       50.00% |       60.00% |       1.4652 |       1.2962 |          0.0100 |
|      27 |         350 |       00:02:23 |       43.75% |       64.52% |       1.5017 |       1.1762 |          0.0100 |
|      31 |         400 |       00:02:42 |       53.12% |       69.03% |       1.2488 |       1.1132 |          0.0100 |
|      35 |         450 |       00:03:02 |       50.00% |       69.03% |       1.3160 |       1.0272 |          0.0100 |
|      39 |         500 |       00:03:23 |       59.38% |       69.03% |       1.1753 |       1.1366 |          0.0100 |
|      43 |         550 |       00:03:44 |       56.25% |       65.81% |       1.1546 |       1.1086 |          0.0100 |
|      47 |         600 |       00:04:03 |       68.75% |       65.81% |       0.9808 |       1.0251 |          0.0100 |
|      50 |         650 |       00:04:22 |       65.62% |       69.68% |       1.1245 |       1.0136 |          0.0100 |
|      54 |         700 |       00:04:42 |       62.50% |       65.16% |       1.2860 |       1.0934 |          0.0100 |
|      58 |         750 |       00:05:01 |       59.38% |       68.39% |       1.2466 |       1.0271 |          0.0100 |
|      60 |         780 |       00:05:13 |       56.25% |       64.52% |       1.1676 |       1.0798 |          0.0100 |
|======================================================================================================================|

Оценка сети

Следуя структуре [1], этот пример формирует только набор обучения и валидации из объектов Sydney Urban. Оцените эффективность обученной сети с помощью валидации, поскольку она не использовалась для обучения сети.

valLabelSet = transform(dsVal,@(data) data{2});
valLabels = readall(valLabelSet);
outputLabels = classify(voxnet,dsVal);
accuracy = nnz(outputLabels == valLabels) / numel(outputLabels);
disp(accuracy)
    0.6452

Просмотрите матрицу неточностей, чтобы изучить точность по различным категориям меток

figure
plotconfusion(valLabels,outputLabels)

Дисбаланс меток, отмеченный в наборе обучающих данных, является проблемой точности классификации. График неточностей иллюстрирует более высокую точность и отзыв для пешехода, наиболее распространенного класса, чем для менее распространенных классов, таких как фургон. Поскольку цель этого примера состоит в том, чтобы демонстрировать подход к обучению базовой классификационной сети с данными облака точек, возможные следующие шаги, которые могут быть предприняты для улучшения классификации эффективности таких как повторная дискретизация набора обучающих данных или достижение лучшего баланса меток или использование функции потерь, более устойчивой к метке дисбаланса (например, взвешенная перекрестная энтропия), не будут исследованы.

Ссылки

1) Voxnet: 3D сверточная нейронная сеть для алгоритма распознавания в реальном времени, Daniel Maturana, Sebastian Scherer, 2015 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)

2) PointPillars: быстрые энкодеры для обнаружения объектов из облаков точек, Alex H. Lang, Sourabh Vora, et al, CVPR 2019

3) Сидней Urban Objects Dataset, Alastair Quadros, Джеймс Андервуд, Бертран Дуиллард, Сидней Urban Objects

Вспомогательные функции

function datasetPath = downloadSydneyUrbanObjects(dataLoc)

if nargin == 0
    dataLoc = pwd();
end

dataLoc = string(dataLoc);

url = "http://www.acfr.usyd.edu.au/papers/data/";
name = "sydney-urban-objects-dataset.tar.gz";

if ~exist(fullfile(dataLoc,'sydney-urban-objects-dataset'),'dir')
    disp('Downloading Sydney Urban Objects Dataset...');
    untar(fullfile(url,name),dataLoc);
end

datasetPath = dataLoc.append('sydney-urban-objects-dataset');

end

function ds = loadSydneyUrbanObjectsData(datapath,folds)
% loadSydneyUrbanObjectsData Datastore with point clouds and
% associated categorical labels for Sydney Urban Objects dataset.
%
% ds = loadSydneyUrbanObjectsData(datapath) constructs a datastore that
% represents point clouds and associated categories for the Sydney Urban
% Objects dataset. The input, datapath, is a string or char array which
% represents the path to the root directory of the Sydney Urban Objects
% Dataset.
%
% ds = loadSydneyUrbanObjectsData(___,folds) optionally allows
% specification of desired folds that you wish to be included in the
% output ds. For example, [1 2 4] specifies that you want the first,
% second, and fourth folds of the Dataset. Default: [1 2 3 4].

if nargin < 2
    folds = 1:4;
end

datapath = string(datapath);
path = fullfile(datapath,'objects',filesep);

% For now, include all folds in Datastore
foldNames{1} = importdata(fullfile(datapath,'folds','fold0.txt'));
foldNames{2} = importdata(fullfile(datapath,'folds','fold1.txt'));
foldNames{3} = importdata(fullfile(datapath,'folds','fold2.txt'));
foldNames{4} = importdata(fullfile(datapath,'folds','fold3.txt'));
names = foldNames(folds);
names = vertcat(names{:});

fullFilenames = append(path,names);
ds = fileDatastore(fullFilenames,'ReadFcn',@extractTrainingData,'FileExtensions','.bin');

% Shuffle
ds.Files = ds.Files(randperm(length(ds.Files)));

end

function dataOut = extractTrainingData(fname)

[pointData,intensity] = readbin(fname);

[~,name] = fileparts(fname);
name = string(name);
name = extractBefore(name,'.');
name = replace(name,'_',' ');

labelNames = ["4wd","building","bus","car","pedestrian","pillar",...
    "pole","traffic lights","traffic sign","tree","truck","trunk","ute","van"];

label = categorical(name,labelNames);

dataOut = {pointCloud(pointData,'Intensity',intensity),label};

end

function [pointData,intensity] = readbin(fname)
% readbin Read point and intensity data from Sydney Urban Object binary
% files.

% names = ['t','intensity','id',...
%          'x','y','z',...
%          'azimuth','range','pid']
% 
% formats = ['int64', 'uint8', 'uint8',...
%            'float32', 'float32', 'float32',...
%            'float32', 'float32', 'int32']

fid = fopen(fname, 'r');
c = onCleanup(@() fclose(fid));
    
fseek(fid,10,-1); % Move to the first X point location 10 bytes from beginning
X = fread(fid,inf,'single',30);
fseek(fid,14,-1);
Y = fread(fid,inf,'single',30);
fseek(fid,18,-1);
Z = fread(fid,inf,'single',30);

fseek(fid,8,-1);
intensity = fread(fid,inf,'uint8',33);

pointData = [X,Y,Z];

end

function dataOut = formOccupancyGrid(data)

grid = pcbin(data{1},[32 32 32]);
occupancyGrid = zeros(size(grid),'single');
for ii = 1:numel(grid)
    occupancyGrid(ii) = ~isempty(grid{ii});
end
label = data{2};
dataOut = {occupancyGrid,label};

end

function dataOut = augmentPointCloudData(data)

ptCloud = data{1};
label = data{2};

% Apply randomized rotation about Z axis.
tform = randomAffine3d('Rotation',@() deal([0 0 1],360*rand),'Scale',[0.98,1.02],'XReflection',true,'YReflection',true); % Randomized rotation about z axis
ptCloud = pctransform(ptCloud,tform);

% Apply jitter to each point in point cloud
amountOfJitter = 0.01;
numPoints = size(ptCloud.Location,1);
D = zeros(size(ptCloud.Location),'like',ptCloud.Location);
D(:,1) = diff(ptCloud.XLimits)*rand(numPoints,1);
D(:,2) = diff(ptCloud.YLimits)*rand(numPoints,1);
D(:,3) = diff(ptCloud.ZLimits)*rand(numPoints,1);
D = amountOfJitter.*D;
ptCloud = pctransform(ptCloud,D);

dataOut = {ptCloud,label};

end