Увеличьте данные об облаке точек для глубокого обучения

Этот пример демонстрирует, как установить основной рандомизированный трубопровод увеличения данных при работе с данными об облаке точек в основанных на глубоком обучении рабочих процессах. Увеличение данных почти всегда желательно при работе с глубоким обучением, потому что это помогает уменьшать сверхподбор кривой во время обучения и может добавить робастность в типы преобразований данных, которые не могут быть хорошо представлены в исходных обучающих данных.

Импортируйте данные об облаке точек

dataPath = downloadSydneyUrbanObjects(tempdir);
dsTrain = loadSydneyUrbanObjectsData(dataPath);
dataOut = preview(dsTrain)
dataOut=1×2 cell array
    {1×1 pointCloud}    {[4wd]}

Datastore dsTrain дает к pointCloud возразите и связанная скалярная категориальная метка для каждого наблюдения.

figure
pcshow(dataOut{1});
title(dataOut{2});

Задайте трубопровод увеличения

Функция преобразования datastore является удобным инструментом для определения трубопроводов увеличения.

dsAugmented = transform(dsTrain,@augmentPointCloud);

augmentPointCloud функция, показанная ниже, применяет рандомизированное вращение, однородную шкалу, рандомизированное отражение через x-и оси Y, и рандомизировала на дрожание точки к каждому наблюдению с помощью randomAffine3d функционируйте, чтобы создать рандомизированные аффинные преобразования и pctransform функция, чтобы применить эти преобразования к каждому облаку точки ввода.

dataOut = preview(dsAugmented)
dataOut=1×2 cell array
    {1×1 pointCloud}    {[4wd]}

Это всегда - хорошая идея визуально смотреть данные, которые выходят из любого увеличения, которое сделано на обучающих данных, чтобы убедиться, что данные смотрят как ожидалось. Облако точек ниже совпадает с оригиналом, показанным ранее, но с рандомизированным аффинным деформированием с на добавленное дрожание точки.

figure
pcshow(dataOut{1});
title(dataOut{2});

Получившийся TransformedDatastore и dsAugmented может быть передан функциям глубокого обучения включая trainNetwork, predict, и classify для использования в обучении и выводе.

Вспомогательные Функции

function datasetPath = downloadSydneyUrbanObjects(dataLoc)

if nargin == 0
    dataLoc = pwd();
end

dataLoc = string(dataLoc);

url = "http://www.acfr.usyd.edu.au/papers/data/";
name = "sydney-urban-objects-dataset.tar.gz";

if ~exist(fullfile(dataLoc,'sydney-urban-objects-dataset'),'dir')
    disp('Downloading Sydney Urban Objects Dataset...');
    untar(url+name,dataLoc);
end

datasetPath = dataLoc.append('sydney-urban-objects-dataset');

end

function ds = loadSydneyUrbanObjectsData(datapath,folds)
% loadSydneyUrbanObjectsData Datastore with point clouds and
% associated categorical labels for Sydney Urban Objects dataset.
%
% ds = loadSydneyUrbanObjectsData(datapath) constructs a datastore that
% represents point clouds and associated categories for the Sydney Urban
% Objects dataset. The input, datapath, is a string or char array which
% represents the path to the root directory of the Sydney Urban Objects
% Dataset.
%
% ds = loadSydneyUrbanObjectsData(___,folds) optionally allows
% specification of desired folds that you wish to be included in the
% output ds. For example, [1 2 4] specifies that you want the first,
% second, and fourth folds of the Dataset. Default: [1 2 3 4].

if nargin < 2
    folds = 1:4;
end

datapath = string(datapath);
path = fullfile(datapath,'objects',filesep);

% For now, include all folds in Datastore
foldNames{1} = importdata(fullfile(datapath,'folds','fold0.txt'));
foldNames{2} = importdata(fullfile(datapath,'folds','fold1.txt'));
foldNames{3} = importdata(fullfile(datapath,'folds','fold2.txt'));
foldNames{4} = importdata(fullfile(datapath,'folds','fold3.txt'));
names = foldNames(folds);
names = vertcat(names{:});

fullFilenames = append(path,names);
ds = fileDatastore(fullFilenames,'ReadFcn',@extractTrainingData,'FileExtensions','.bin');

end

function dataOut = extractTrainingData(fname)

[pointData,intensity] = readbin(fname);

[~,name] = fileparts(fname);
name = string(name);
name = extractBefore(name,'.');

labelNames = ["4wd","bench","bicycle","biker",...
    "building","bus","car","cyclist","excavator","pedestrian","pillar",...
    "pole","post","scooter","ticket_machine","traffic_lights","traffic_sign",...
    "trailer","trash","tree","truck","trunk","umbrella","ute","van","vegetation"];

label = categorical(name,labelNames);

dataOut = {pointCloud(pointData,'Intensity',intensity),label};

end

function [pointData,intensity] = readbin(fname)
% readbin Read point and intensity data from Sydney Urban Object binary
% files.

% names = ['t','intensity','id',...
%          'x','y','z',...
%          'azimuth','range','pid']
% 
% formats = ['int64', 'uint8', 'uint8',...
%            'float32', 'float32', 'float32',...
%            'float32', 'float32', 'int32']

fid = fopen(fname, 'r');
c = onCleanup(@() fclose(fid));
    
fseek(fid,10,-1); % Move to the first X point location 10 bytes from beginning
X = fread(fid,inf,'single',30);
fseek(fid,14,-1);
Y = fread(fid,inf,'single',30);
fseek(fid,18,-1);
Z = fread(fid,inf,'single',30);

fseek(fid,8,-1);
intensity = fread(fid,inf,'uint8',33);

pointData = [X,Y,Z];

end

function dataOut = augmentPointCloud(data)

ptCloud = data{1};
label = data{2};

% Apply randomized rotation about Z axis.
tform = randomAffine3d('Rotation',@() deal([0 0 1],360*rand),'Scale',[0.98,1.02],'XReflection',true,'YReflection',true); % Randomized rotation about z axis
ptCloud = pctransform(ptCloud,tform);

% Apply jitter to each point in point cloud
amountOfJitter = 0.01;
numPoints = size(ptCloud.Location,1);
D = zeros(size(ptCloud.Location),'like',ptCloud.Location);
D(:,1) = diff(ptCloud.XLimits)*rand(numPoints,1);
D(:,2) = diff(ptCloud.YLimits)*rand(numPoints,1);
D(:,3) = diff(ptCloud.ZLimits)*rand(numPoints,1);
D = amountOfJitter.*D;
ptCloud = pctransform(ptCloud,D);

dataOut = {ptCloud,label};

end