Лоцируйте 3-D обнаружение объектов Используя глубокое обучение PointPillars

В этом примере показано, как обучить сеть PointPillars для обнаружения объектов в облаках точек.

Данные об облаке точек получены множеством датчиков, таких как датчики лидара, радарные датчики и камеры глубины. Эти датчики получают 3-D информацию о положении об объектах в сцене, которая полезна для многих приложений в автономном управлении автомобилем и дополненной реальности. Однако учебные устойчивые детекторы с данными об облаке точек сложны из-за разреженности данных на объект, объектные поглощения газов и шум датчика. Методы глубокого обучения, как показывали, обратились ко многим из этих проблем путем изучения устойчивых представлений функции непосредственно от данных об облаке точек. Один метод глубокого обучения для 3-D обнаружения объектов является PointPillars [1]. Используя подобную архитектуру к PointNet, извлечения сети PointPillars плотные, устойчивые функции от разреженных облаков точек вызвали столбы, затем используют 2D нейронную сеть для глубокого обучения с модифицированной сетью обнаружения объектов SSD, чтобы получить объединенную ограничительную рамку и предсказания класса.

Этот пример использует набор данных магистральных сцен, собранных с помощью Изгнания датчик лидара OS1.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть с данного URL с помощью функции помощника downloadPretrainedPointPillarsNet, заданный в конце этого примера. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться. Если вы хотите обучить сеть, установите doTraining переменная к true.

outputFolder = fullfile(tempdir,'WPI');
pretrainedNetURL = 'https://ssd.mathworks.com/supportfiles/lidar/data/trainedPointPillars.zip';

doTraining = false;
if ~doTraining
    net = downloadPretrainedPointPillarsNet(outputFolder, pretrainedNetURL);
end

Загрузка данных

Загрузите магистральный набор данных сцен с данного URL с помощью функции помощника downloadWPIData, заданный в конце этого примера. Набор данных содержит организованные сканы облака точек лидара 1 617 полных полных представлений магистральных сцен и соответствующих меток для объектов автомобиля и грузовика.

lidarURL = 'https://www.mathworks.com/supportfiles/lidar/data/WPI_LidarData.tar.gz';
lidarData = downloadWPIData(outputFolder, lidarURL);

Примечание: В зависимости от вашего Интернет-соединения, процесс загрузки может занять время. Код приостанавливает выполнение MATLAB®, пока процесс загрузки не завершен. В качестве альтернативы можно загрузить набор данных на локальный диск с помощью веб-браузера и извлечь файл. Если вы делаете так, изменяете outputFolder переменная в коде к местоположению загруженного файла.

Загрузите 3-D метки ограничительной рамки.

load('WPI_LidarGroundTruth.mat','bboxGroundTruth');
Labels = timetable2table(bboxGroundTruth);
Labels = Labels(:,2:end);

Отобразите облако полной точки наблюдения.

figure
ax = pcshow(lidarData{1,1}.Location);
set(ax,'XLim',[-50 50],'YLim',[-40 40]);
zoom(ax,2.5);
axis off;

Предварительная Обработка Данных

Эта полная точка наблюдения обрезок в качестве примера облака к облакам точек вида спереди с помощью стандартных параметров [1]. Эти параметры помогают выбрать, размер входа передал сети. Выбирание меньшей области значений облаков точек вдоль x, y, и оси z помогает обнаружить объекты, которые более близки началу координат, и также уменьшает полное учебное время сети.

xMin = 0.0;     % Minimum value along X-axis.
yMin = -39.68;  % Minimum value along Y-axis.
zMin = -5.0;    % Minimum value along Z-axis.
xMax = 69.12;   % Maximum value along X-axis.
yMax = 39.68;   % Maximum value along Y-axis.
zMax = 5.0;     % Maximum value along Z-axis.
xStep = 0.16;   % Resolution along X-axis.
yStep = 0.16;   % Resolution along Y-axis.
dsFactor = 2.0; % Downsampling factor.

% Calculate the dimensions for pseudo-image.
Xn = round(((xMax - xMin) / xStep));
Yn = round(((yMax - yMin) / yStep));

% Define pillar extraction parameters.
gridParams = {{xMin,yMin,zMin},{xMax,yMax,zMax},{xStep,yStep,dsFactor},{Xn,Yn}};

Обрежьте вид спереди облаков точек с помощью createFrontViewFromLidarData функция помощника, присоединенная к этому примеру как вспомогательный файл и значения калибровки фотоаппарата лидара. Пропустите этот шаг, если ваш набор данных содержит облака точек вида спереди.

% Load the calibration parameters.
fview = load('calibrationValues.mat');
[inputPointCloud, boxLabels] = createFrontViewFromLidarData(lidarData, Labels, gridParams, fview); 
Processing data 100% complete

Отобразите обрезанное облако точек.

figure
ax1 = pcshow(inputPointCloud{1,1}.Location);
gtLabels = boxLabels.car(1,:);
showShape('cuboid', gtLabels{1,1}, 'Parent', ax1, 'Opacity', 0.1, 'Color', 'green','LineWidth',0.5);
zoom(ax1,2);

Создайте объекты FileDatastore и BoxLabelDatastore для обучения

Разделите набор данных в набор обучающих данных для того, чтобы обучить сеть и набор тестов для оценки сети. Выберите 70% данных для обучения. Используйте остальных для оценки.

rng(1);
shuffledIndices = randperm(size(inputPointCloud,1));
idx = floor(0.7 * length(shuffledIndices));

trainData = inputPointCloud(shuffledIndices(1:idx),:);
testData = inputPointCloud(shuffledIndices(idx+1:end),:);

trainLabels = boxLabels(shuffledIndices(1:idx),:);
testLabels = boxLabels(shuffledIndices(idx+1:end),:);

Для быстрого доступа с хранилищами данных сохраните обучающие данные как файлы PCD. Пропустите этот шаг, если ваш формат данных облака точек поддерживается pcread (Computer Vision Toolbox) функция.

dataLocation = fullfile(outputFolder,'InputData');
saveptCldToPCD(trainData,dataLocation);
Processing data 100% complete

Создайте fileDatastore загружать файлы PCD с помощью pcread (Computer Vision Toolbox) функция.

lds = fileDatastore(dataLocation,'ReadFcn',@(x) pcread(x));

Создайте boxLabelDatastore (Computer Vision Toolbox) для загрузки 3-D меток ограничительной рамки.

bds = boxLabelDatastore(trainLabels);

Используйте combine функционируйте, чтобы объединить облака точек и 3-D метки ограничительной рамки в один datastore для обучения.

cds = combine(lds,bds);

Увеличение данных

Этот пример использует увеличение достоверных данных и несколько других методов увеличения глобальных данных, чтобы добавить больше разнообразия в обучающие данные и соответствующие поля.

Считайте и отобразите облако точек перед увеличением.

augData = read(cds);
augptCld = augData{1,1};
augLabels = augData{1,2};
figure;
ax2 = pcshow(augptCld.Location);
showShape('cuboid', augLabels, 'Parent', ax2, 'Opacity', 0.1, 'Color', 'green','LineWidth',0.5);
zoom(ax2,2);

reset(cds);

Используйте generateGTDataForAugmentation функция помощника, присоединенная к этому примеру как вспомогательный файл, чтобы извлечь все ограничительные рамки основной истины из обучающих данных.

gtData = generateGTDataForAugmentation(trainData,trainLabels);

Используйте groundTruthDataAugmentation функция помощника, присоединенная к этому примеру как вспомогательный файл, чтобы случайным образом добавить постоянное число автомобильных объектов класса к каждому облаку точек.

Используйте transform функция, чтобы применить основную истину и пользовательские увеличения данных к обучающим данным.

cdsAugmented = transform(cds,@(x) groundTruthDataAugmenation(x,gtData));

Кроме того, примените следующие увеличения данных к каждому облаку точек.

  • Случайное зеркальное отражение вдоль оси X

  • Случайное масштабирование на 5 процентов

  • Случайное вращение вдоль оси z от [-pi/4, пи/4]

  • Случайный перевод [0.2, 0.2, 0.1] метры вдоль оси XYZ соответственно

cdsAugmented = transform(cdsAugmented,@(x) augmentData(x));

Отобразитесь увеличенное облако точек наряду с основной истиной увеличило поля.

augData = read(cdsAugmented);
augptCld = augData{1,1};
augLabels = augData{1,2};
figure;
ax3 = pcshow(augptCld(:,1:3));
showShape('cuboid', augLabels, 'Parent', ax3, 'Opacity', 0.1, 'Color', 'green','LineWidth',0.5);
zoom(ax3,2);

reset(cdsAugmented);

Извлеките информацию о столбе из облаков точек

Преобразуйте 3-D облако точек в 2D представление, чтобы применить 2D архитектуру свертки к облакам точек для более быстрой обработки. Используйте transform функция с createPillars функция помощника, присоединенная к этому примеру как вспомогательный файл, чтобы создать функции столба и индексы столба от облаков точек. Функция помощника выполняет следующие операции:

  • Дискретизируйте 3-D облака точек в равномерно расположенные с интервалами сетки в x-y плоскости, чтобы создать набор, вертикальные столбцы вызвали столбы.

  • Выберите видные столбы (P) на основе числа точек на столб (N).

  • Вычислите расстояние до среднего арифметического всех точек в столбе.

  • Вычислите смещение из центра столба.

  • Используйте x, y, z местоположение, интенсивность, расстояние и возместите значения, чтобы создать девять размерных (9-D) векторов для каждой точки в столбе.

% Define number of prominent pillars.
P = 12000; 

% Define number of points per pillar.
N = 100;   
cdsTransformed = transform(cdsAugmented,@(x) createPillars(x,gridParams,P,N));

Задайте сеть PointPillars

Сеть PointPillars использует упрощенную версию сети PointNet, которая берет функции столба в качестве входа. Для каждой функции столба линейный слой применяется сопровождаемый слоями ReLU и нормализацией партии. Наконец, макс. объединяющая операция по каналам применяется, чтобы получить закодированные функции высокого уровня. Эти закодированные функции рассеиваются назад к исходным местоположениям столба, чтобы создать псевдоизображение с помощью пользовательского слоя helperscatterLayer, присоединенный к этому примеру как вспомогательный файл. Псевдоизображение затем обрабатывается с 2D сверточной магистралью, сопровождаемой различными головами обнаружения SSD, чтобы предсказать 3-D ограничительные рамки наряду с его классами.

Задайте размерности поля привязки на основе классов, чтобы обнаружить. Как правило, эти размерности являются средними значениями всех значений ограничительной рамки в наборе обучающих данных [1]. Поля привязки заданы в формате {длина, ширина, высота, z-центр, угол отклонения от курса}.

anchorBoxes = {{3.9, 1.6, 1.56, -1.78, 0}, {3.9, 1.6, 1.56, -1.78, pi/2}};
numAnchors = size(anchorBoxes,2);
classNames = trainLabels.Properties.VariableNames;

Затем создайте сеть PointPillars с помощью функции помощника pointpillarNetwork, присоединенный к этому примеру как вспомогательный файл.

lgraph = pointpillarNetwork(numAnchors,gridParams,P,N);

Задайте опции обучения

Задайте следующие опции обучения.

  • Определите номер эпох к 160.

  • Установите мини-пакетный размер как 2. Это должно быть установлено в зависимости от доступной памяти.

  • Установите скорость обучения на 0,0002.

  • Установите learnRateDropPeriod к 15. Этот параметр обозначает номер эпох, чтобы пропустить скорость обучения после на основе формулы learningRate×(iteration%learnRateDropPeriod)×learnRateDropFactor.

  • Установите learnRateDropFactor к 0,8. Этот параметр обозначает уровень, которым можно пропустить скорость обучения после каждого learnRateDropPeriod.

  • Установитесь коэффициент затухания градиента на 0,9.

  • Установитесь коэффициент затухания градиента в квадрате на 0,999.

  • Инициализируйте среднее значение градиентов к []. Это используется оптимизатором Адама.

  • Инициализируйте среднее значение градиентов в квадрате к []. Это используется оптимизатором Адама.

numEpochs = 160;
miniBatchSize = 2;
learningRate = 0.0002;
learnRateDropPeriod = 15;
learnRateDropFactor = 0.8;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
trailingAvg = [];
trailingAvgSq = [];

Обучите модель

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше. Чтобы автоматически обнаружить, если вы имеете графический процессор в наличии, установите executionEnvironment к "auto". Если вы не имеете графического процессора или не хотите использовать один для обучения, устанавливать executionEnvironment к "cpu". Чтобы гарантировать использование графического процессора для обучения, установите executionEnvironment к "gpu".

Затем создайте minibatchqueue загружать данные в пакетах miniBatchSize во время обучения.

executionEnvironment = "auto";
if canUseParallelPool
    dispatchInBackground = true;
else
    dispatchInBackground = false;
end

mbq = minibatchqueue(cdsTransformed,3,...
                     "MiniBatchSize",miniBatchSize,...
                     "OutputEnvironment",executionEnvironment,...
                     "MiniBatchFcn",@(features,indices,boxes,labels) createBatchData(features,indices,boxes,labels,classNames),...
                     "MiniBatchFormat",["SSCB","SSCB",""],...
                     "DispatchInBackground",dispatchInBackground);

Чтобы обучить сеть с пользовательским учебным циклом и включить автоматическое дифференцирование, преобразуйте график слоев в dlnetwork объект. Затем создайте плоттер процесса обучения с помощью функции помощника configureTrainingProgressPlotter, заданный в конце этого примера.

Наконец, задайте пользовательский учебный цикл. Для каждой итерации:

  • Считайте облака точек и основные блоки истинности от minibatchqueue использование next функция.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция. modelGradients функция помощника, заданная в конце примера, возвращает градиенты потери относительно настраиваемых параметров в net, соответствующая мини-пакетная потеря и состояние текущего пакета.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Обновите параметры состояния net.

  • Обновите график процесса обучения.

if doTraining
    % Convert layer graph to dlnetwork.
    net = dlnetwork(lgraph);
    
    % Initialize plot.
    fig = figure;
    lossPlotter = configureTrainingProgressPlotter(fig);    
    iteration = 0;
    
    % Custom training loop.
    for epoch = 1:numEpochs
        
        % Reset datastore.
        reset(mbq);
        
        while(hasdata(mbq))
            iteration = iteration + 1;
            
            % Read batch of data.
            [pillarFeatures, pillarIndices, boxLabels] = next(mbq);
                        
            % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
            [gradients, loss, state] = dlfeval(@modelGradients, net, pillarFeatures, pillarIndices, boxLabels,...
                                                gridParams, anchorBoxes, executionEnvironment);
            
            % Do not update the network learnable parameters if NaN values
            % are present in gradients or loss values.
            if checkForNaN(gradients,loss)
                continue;
            end
                    
            % Update the state parameters of dlnetwork.
            net.State = state;
            
            % Update the network learnable parameters using the Adam
            % optimizer.
            [net.Learnables, trailingAvg, trailingAvgSq] = adamupdate(net.Learnables, gradients, ...
                                                               trailingAvg, trailingAvgSq, iteration,...
                                                               learningRate,gradientDecayFactor, squaredGradientDecayFactor);
            
            % Update training plot with new points.         
            addpoints(lossPlotter, iteration,double(gather(extractdata(loss))));
            title("Training Epoch " + epoch +" of " + numEpochs);
            drawnow;
        end
                
        % Update the learning rate after every learnRateDropPeriod.
        if mod(epoch,learnRateDropPeriod) == 0
            learningRate = learningRate * learnRateDropFactor;
        end
    end
end

Оцените модель

Computer Vision Toolbox™ обеспечивает функции оценки детектора объектов, чтобы измерить общие метрики, такие как средняя точность (evaluateDetectionAOS (Computer Vision Toolbox)). В данном примере используйте среднюю метрику точности. Средняя точность обеспечивает один номер, который включает способность детектора сделать правильные классификации (точность) и способность детектора найти все соответствующие объекты (отзыв).

Оцените обученный dlnetwork объект net на тестовых данных путем выполнения этих шагов.

  • Задайте порог доверия, чтобы использовать только обнаружения с оценками достоверности выше этого значения.

  • Задайте порог перекрытия, чтобы удалить перекрывающиеся обнаружения.

  • Используйте generatePointPillarDetections функция помощника, присоединенная к этому примеру как вспомогательный файл, чтобы получить ограничительные рамки, возражает оценкам достоверности и меткам класса.

  • Вызовите evaluateDetectionAOS (Computer Vision Toolbox) с detectionResults и groundTruthData в качестве аргументов.

numInputs = numel(testData);

% Generate rotated rectangles from the cuboid labels.
bds = boxLabelDatastore(testLabels);
groundTruthData = transform(bds,@(x) createRotRect(x));

% Set the threshold values.
nmsPositiveIoUThreshold = 0.5;
confidenceThreshold = 0.25;
overlapThreshold = 0.1;

% Set numSamplesToTest to numInputs to evaluate the model on the entire
% test data set.
numSamplesToTest = 50;
detectionResults = table('Size',[numSamplesToTest 3],...
                        'VariableTypes',{'cell','cell','cell'},...
                        'VariableNames',{'Boxes','Scores','Labels'});

for num = 1:numSamplesToTest
    ptCloud = testData{num,1};
    
    [box,score,labels] = generatePointPillarDetections(net,ptCloud,anchorBoxes,gridParams,classNames,confidenceThreshold,...
                                            overlapThreshold,P,N,executionEnvironment);
 
    % Convert the detected boxes to rotated rectangles format.
    if ~isempty(box)
        detectionResults.Boxes{num} = box(:,[1,2,4,5,7]);
    else
        detectionResults.Boxes{num} = box;
    end
    detectionResults.Scores{num} = score;
    detectionResults.Labels{num} = labels;
end

metrics = evaluateDetectionAOS(detectionResults,groundTruthData,nmsPositiveIoUThreshold)
metrics=1×5 table
             AOS        AP       OrientationSimilarity      Precision           Recall    
           _______    _______    _____________________    ______________    ______________

    car    0.69396    0.69396       {125×1 double}        {125×1 double}    {125×1 double}

Обнаружьте объекты Используя столбы точки

Используйте сеть для обнаружения объектов:

  • Считайте облако точек из тестовых данных.

  • Используйте generatePointPillarDetections функция помощника, присоединенная к этому примеру как вспомогательный файл, чтобы получить предсказанные ограничительные рамки и оценки достоверности.

  • Отобразите облако точек с ограничительными рамками.

ptCloud = testData{3,1};
gtLabels = testLabels{3,1}{1};

% Display the point cloud.
figure;
ax4 = pcshow(ptCloud.Location);

% The generatePointPillarDetections function detects the bounding boxes, scores for a
% given point cloud.
confidenceThreshold = 0.5;
overlapThreshold = 0.1;
[box,score,labels] = generatePointPillarDetections(net,ptCloud,anchorBoxes,gridParams,classNames,confidenceThreshold,...
                      overlapThreshold,P,N,executionEnvironment);

% Display the detections on the point cloud.
showShape('cuboid', box, 'Parent', ax4, 'Opacity', 0.1, 'Color', 'red','LineWidth',0.5);hold on;
showShape('cuboid', gtLabels, 'Parent', ax4, 'Opacity', 0.1, 'Color', 'green','LineWidth',0.5);
zoom(ax4,2);

Функции помощника

Градиенты модели

Функциональный modelGradients берет в качестве входа dlnetwork объект net и мини-пакет входных данных pillarFeautures и pillarIndices с соответствующими основными блоками истинности, полями привязки и параметрами сетки. Функция возвращает градиенты потери относительно настраиваемых параметров в net, соответствующая мини-пакетная потеря и состояние текущего пакета.

Функция градиентов модели вычисляет общую сумму убытков и градиенты путем выполнения этих операций.

  • Извлеките предсказания из сети с помощью forward функция.

  • Сгенерируйте цели для расчета потерь при помощи достоверных данных, параметров сетки и полей привязки.

  • Вычислите функцию потерь для всех шести предсказаний от сети.

  • Вычислите общую сумму убытков как сумму всех потерь.

  • Вычислите градиенты learnables относительно общей суммы убытков.

function [gradients, loss, state] = modelGradients(net, pillarFeatures, pillarIndices, boxLabels, gridParams, anchorBoxes,...
                                                   executionEnvironment)
      
    % Extract the predictions from the network.
    YPredictions = cell(size(net.OutputNames));
    [YPredictions{:}, state] = forward(net,pillarIndices,pillarFeatures);
    
    % Generate target for predictions from the ground truth data.
    YTargets = generatePointPillarTargets(YPredictions, boxLabels, pillarIndices, gridParams, anchorBoxes);
    YTargets = cellfun(@ dlarray,YTargets,'UniformOutput', false);
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        YTargets = cellfun(@ gpuArray,YTargets,'UniformOutput', false);
    end
     
    [angLoss, occLoss, locLoss, szLoss, hdLoss, clfLoss] = computePointPillarLoss(YPredictions, YTargets);
    
    % Compute the total loss.
    loss = angLoss + occLoss + locLoss + szLoss + hdLoss + clfLoss;
    
    % Compute the gradients of the learnables with regard to the loss.
    gradients = dlgradient(loss,net.Learnables);
 
end

function [pillarFeatures, pillarIndices, labels] = createBatchData(features, indices, groundTruthBoxes, groundTruthClasses, classNames)
% Returns pillar features and indices combined along the batch dimension
% and bounding boxes concatenated along batch dimension in labels.
    
    % Concatenate features and indices along batch dimension.
    pillarFeatures = cat(4, features{:,1});
    pillarIndices = cat(4, indices{:,1});
    
    % Get class IDs from the class names.
    classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
    [~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);
    
    % Append the class indices and create a single array of responses.
    combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
    len = max(cellfun(@(x)size(x,1), combinedResponses));
    paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
    labels = cat(4, paddedBBoxes{:,1});
end

function lidarData = downloadWPIData(outputFolder, lidarURL)
% Download the data set from the given URL into the output folder.

    lidarDataTarFile = fullfile(outputFolder,'WPI_LidarData.tar.gz');
    if ~exist(lidarDataTarFile, 'file')
        mkdir(outputFolder);
        
        disp('Downloading WPI Lidar driving data (760 MB)...');
        websave(lidarDataTarFile, lidarURL);
        untar(lidarDataTarFile,outputFolder);
    end
    
    % Extract the file.
    if ~exist(fullfile(outputFolder, 'WPI_LidarData.mat'), 'file')
        untar(lidarDataTarFile,outputFolder);
    end
    load(fullfile(outputFolder, 'WPI_LidarData.mat'),'lidarData');
    lidarData = reshape(lidarData,size(lidarData,2),1);
end

function net = downloadPretrainedPointPillarsNet(outputFolder, pretrainedNetURL)
% Download the pretrained PointPillars detector.

    preTrainedMATFile = fullfile(outputFolder,'trainedPointPillarsNet.mat');
    preTrainedZipFile = fullfile(outputFolder,'trainedPointPillars.zip');
    
    if ~exist(preTrainedMATFile,'file')
        if ~exist(preTrainedZipFile,'file')
            disp('Downloading pretrained detector (8.3 MB)...');
            websave(preTrainedZipFile, pretrainedNetURL);
        end
        unzip(preTrainedZipFile, outputFolder);   
    end
    pretrainedNet = load(preTrainedMATFile);
    net = pretrainedNet.net;       
end

function lossPlotter = configureTrainingProgressPlotter(f)
% The configureTrainingProgressPlotter function configures training
% progress plots for various losses.
    figure(f);
    clf
    ylabel('Total Loss');
    xlabel('Iteration');
    lossPlotter = animatedline;
end

function retValue = checkForNaN(gradients,loss)
% Based on experiments it is found that the last convolution head
% 'occupancy|conv2d' contains NaNs as the gradients. This function checks
% whether gradient values contain NaNs. Add other convolution
% head values to the condition if NaNs are present in the gradients. 
    gradValue = gradients.Value((gradients.Layer == 'occupancy|conv2d') & (gradients.Parameter == 'Bias'));
    if (sum(isnan(extractdata(loss)),'all') > 0) || (sum(isnan(extractdata(gradValue{1,1})),'all') > 0)
        retValue = true;
    else
        retValue = false;
    end
end

Ссылки

[1] Ленг, Алекс Х., Sourabh Vora, Хольгер Цезарь, Лубин Чжоу, Цзюн Ян и Оскар Бейджбом. "PointPillars: Быстрые Энкодеры для Обнаружения объектов От Облаков точек". На 2019 Конференциях IEEE/CVF по Компьютерному зрению и Распознаванию образов (CVPR), 12689-12697. Лонг-Бич, CA, США: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01298.