Обнаружение объектов Используя глубокое обучение YOLO v3

В этом примере показано, как обучить детектор объектов YOLO v3.

Глубокое обучение является мощным методом машинного обучения, который можно использовать, чтобы обучить устойчивые детекторы объектов. Несколько методов для обнаружения объектов существуют, включая Faster R-CNN, вы только смотрите однажды (YOLO) v2, и один детектор выстрела (SSD). В этом примере показано, как обучить детектор объектов YOLO v3. YOLO v3 улучшает YOLO v2 путем добавления обнаружения в нескольких шкалах, чтобы помочь обнаружить меньшие объекты. Функция потерь, используемая для обучения, разделена на среднеквадратическую ошибку для регрессии ограничительной рамки и бинарную перекрестную энтропию для предметной классификации, чтобы помочь улучшить точность обнаружения.

Примечание: Этот пример требует Модели Computer Vision Toolbox™ для обнаружения объектов YOLO v3. Можно установить Модель Computer Vision Toolbox для обнаружения объектов YOLO v3 из Add-On Explorer. Для получения дополнительной информации об установке дополнений, смотрите, Получают и Управляют Дополнениями.

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть с помощью функции помощника downloadPretrainedYOLOv3Detector избегать необходимости ожидать обучения завершиться. Если вы хотите обучить сеть, установите doTraining переменная к true.

doTraining = false;

if ~doTraining
    preTrainedDetector = downloadPretrainedYOLOv3Detector();    
end

Загрузка данных

Этот пример использует маленький набор маркированных данных, который содержит 295 изображений. Многие из этих изображений прибывают из Автомобилей Калифорнийского технологического института 1 999 и 2 001 набор данных, доступный в Калифорнийском технологическом институте Вычислительный веб-сайт Видения, созданный Пьетро Пероной и используемый с разрешением. Каждое изображение содержит один или два помеченных экземпляра транспортного средства. Небольшой набор данных полезен для исследования метода обучения YOLO v3, но на практике, более помеченные изображения необходимы, чтобы обучить устойчивую сеть.

Разархивируйте изображения транспортного средства и загрузите достоверные данные транспортного средства.

unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;

% Add the full path to the local vehicle data folder.
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

Примечание: В случае нескольких классов, данные могут также организованный как три столбца, где первый столбец содержит имена файла образа с путями, второй столбец содержит ограничительные рамки, и третий столбец должен быть вектором ячейки, который содержит имена метки, соответствующие каждой ограничительной рамке. Для получения дополнительной информации о том, как расположить ограничительные рамки и метки, смотрите boxLabelDatastore (Computer Vision Toolbox).

Все ограничительные рамки должны быть в форме [x y width height]. Этот вектор задает левый верхний угол и размер ограничительной рамки в пикселях.

Разделите набор данных в набор обучающих данных для того, чтобы обучить сеть и набор тестов для оценки сети. Используйте 60% данных для набора обучающих данных и остальных для набора тестов.

rng(0);
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));
trainingDataTbl = vehicleDataset(shuffledIndices(1:idx), :);
testDataTbl = vehicleDataset(shuffledIndices(idx+1:end), :);

Создайте datastore изображений для загрузки изображений.

imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
imdsTest = imageDatastore(testDataTbl.imageFilename);

Создайте datastore для ограничительных рамок основной истины.

bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));

Объедините изображение и хранилища данных метки поля.

trainingData = combine(imdsTrain, bldsTrain);
testData = combine(imdsTest, bldsTest);

Используйте validateInputData обнаружить недопустимые изображения, ограничительные рамки или маркирует i.e.,

  • Выборки с недопустимым форматом изображения или содержащий NaNs

  • Ограничительные рамки, содержащие zeros/NaNs/Infs/empty

  • Метки Missing/non-categorical.

Значения ограничительных рамок должны быть конечными, положительными, недробными, non-NaN и должны быть в границе изображения с положительной высотой и шириной. Любые недопустимые выборки должны или быть отброшены или зафиксированы для соответствующего обучения.

validateInputData(trainingData);
validateInputData(testData);

Увеличение данных

Увеличение данных используется, чтобы улучшить сетевую точность путем случайного преобразования исходных данных во время обучения. При помощи увеличения данных можно добавить больше разнообразия в обучающие данные, на самом деле не имея необходимость увеличить число помеченных обучающих выборок.

Используйте transform функция, чтобы применить пользовательские увеличения данных к обучающим данным. augmentData функция помощника, перечисленная в конце примера, применяет следующие увеличения к входным данным.

  • Увеличение колебания цвета на HSV-пробеле

  • Случайный горизонтальный щелчок

  • Случайное масштабирование на 10 процентов

augmentedTrainingData = transform(trainingData, @augmentData);

Считайте то же изображение четыре раза и отобразите увеличенные обучающие данные.

% Visualize the augmented images.
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1,1}, 'Rectangle', data{1,2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData, 'BorderSize', 10)

Задайте детектор объектов YOLO v3

Детектор YOLO v3 в этом примере основан на SqueezeNet и использует сеть извлечения признаков в SqueezeNet со сложением двух голов обнаружения в конце. Вторая голова обнаружения является дважды размером первой головы обнаружения, таким образом, это лучше способно обнаружить маленькие объекты. Обратите внимание на то, что можно задать любое количество глав обнаружения различных размеров на основе размера объектов, которые вы хотите обнаружить. Использование детектора YOLO v3 поля привязки, оцененные с помощью обучающих данных, чтобы иметь лучшее начальное уголовное прошлое, соответствующее типу набора данных и помочь детектору учиться предсказывать поля точно. Для получения информации о полях привязки смотрите Поля Привязки для Обнаружения объектов (Computer Vision Toolbox).

Сеть YOLO v3, существующая в детекторе YOLO v3, проиллюстрирована в следующей схеме.

Можно использовать Deep Network Designer, чтобы создать сеть, показанную в схеме.

Задайте сетевой входной размер. При выборе сетевого входного размера считайте минимальный размер требуемым запустить саму сеть, размер учебных изображений и вычислительную стоимость, понесенную путем обработки данных в выбранном размере. Когда это возможно, выберите размер входного сигнала сети, который близок к размеру обучающего изображения и больше, чем размер входного сигнала, необходимый для сети. Чтобы уменьшать вычислительную стоимость выполнения примера, задайте сетевой входной размер [227 227 3].

networkInputSize = [227 227 3];

Во-первых, используйте transform предварительно обрабатывать обучающие данные для вычисления полей привязки, когда учебные изображения, используемые в этом примере, больше, чем 227 227 и отличаются по размеру. Задайте количество привязок как 6, чтобы достигнуть хорошего компромисса между количеством привязок и означать IoU. Используйте estimateAnchorBoxes функционируйте, чтобы оценить поля привязки. Для получения дополнительной информации при оценке полей привязки, смотрите Оценочные Поля Привязки От Обучающих данных (Computer Vision Toolbox). В случае использования предварительно обученного детектора объектов YOLOv3 должны быть заданы поля привязки, вычисленные на тот конкретный обучающий набор данных. Обратите внимание на то, что процесс оценки не детерминирован. Чтобы препятствовать тому, чтобы предполагаемые поля привязки изменились при настройке других гиперпараметров, устанавливает случайный seed до оценки с помощью rng.

rng(0)
trainingDataForEstimation = transform(trainingData, @(data)preprocessData(data, networkInputSize));
numAnchors = 6;
[anchors, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)
anchors = 6×2

    41    34
   163   130
    98    93
   144   125
    33    24
    69    66

meanIoU = 0.8507

Задайте anchorBoxes использовать в обоих головы обнаружения. anchorBoxes массив ячеек [Mx1], где M обозначает количество голов обнаружения. Каждая голова обнаружения состоит из матрицы [Nx2] anchors, где N является количеством привязок, чтобы использовать. Выберите anchorBoxes поскольку каждое обнаружение направляется на основе размера карты функции. Используйте больший anchors в более низкой шкале и меньшем anchors в более высокой шкале. Для этого отсортируйте anchors с большими полями привязки сначала и присвоением первые три к первому обнаружению направляются и следующие три во вторую голову обнаружения.

area = anchors(:, 1).*anchors(:, 2);
[~, idx] = sort(area, 'descend');
anchors = anchors(idx, :);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    };

Загрузите сеть SqueezeNet, предварительно обученную на наборе данных Imagenet, и затем задайте имена классов. Можно также принять решение загрузить различную предварительно обученную сеть, обученную на наборе данных COCO, таком как tiny-yolov3-coco или darknet53-coco или набор данных Imagenet, такой как MobileNet-v2 или ResNet-18. YOLO v3 выполняет лучше и обучается быстрее, когда вы используете предварительно обученную сеть.

baseNetwork = squeezenet;
classNames = trainingDataTbl.Properties.VariableNames(2:end);

Затем создайте yolov3ObjectDetector объект путем добавления источника сети обнаружения. Выбор оптимального источника сети обнаружения требует метода проб и ошибок, и можно использовать analyzeNetwork найти имена потенциального источника сети обнаружения в сети. В данном примере используйте fire9-concat и fire5-concat слои как DetectionNetworkSource.

yolov3Detector = yolov3ObjectDetector(baseNetwork, classNames, anchorBoxes, 'DetectionNetworkSource', {'fire9-concat', 'fire5-concat'});

В качестве альтернативы вместо сети, созданной выше использования SqueezeNet, другие предварительно обученные архитектуры YOLOv3, обученные с помощью больших наборов данных как MS-COCO, могут использоваться, чтобы передать, изучают детектор на задаче обнаружения пользовательского объекта. Передача обучения может быть понята путем изменения имен классов и anchorBoxes.

Предварительно обработайте обучающие данные

Предварительно обработайте увеличенные обучающие данные, чтобы подготовиться к обучению. preprocess (Computer Vision Toolbox) метод в yolov3ObjectDetector (Computer Vision Toolbox), применяет следующие операции предварительной обработки к входным данным.

  • Измените размер изображений к сетевому входному размеру путем поддержания соотношения сторон.

  • Масштабируйте пиксели изображения в области значений [0 1].

preprocessedTrainingData = transform(augmentedTrainingData, @(data)preprocess(yolov3Detector, data));

Считайте предварительно обработанные обучающие данные.

data = read(preprocessedTrainingData);

Отобразите изображение с ограничительными рамками.

I = data{1,1};
bbox = data{1,2};
annotatedImage = insertShape(I, 'Rectangle', bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)

Сбросьте datastore.

reset(preprocessedTrainingData);

Задайте опции обучения

Задайте эти опции обучения.

  • Определите номер эпох, чтобы быть 80.

  • Установите мини-пакетный размер как 8. Устойчивое обучение может быть возможным с уровнями высшего образования, когда выше мини-пакетный размер является used. Несмотря на то, что, это должно быть установлено в зависимости от доступной памяти.

  • Установите скорость обучения на 0,001.

  • Установите период прогрева как 1000 итерации. Этот параметр обозначает количество итераций, чтобы увеличить скорость обучения экспоненциально на основе формулы learningRate×(iterationwarmupPeriod)4. Это помогает в стабилизации градиентов на уровнях высшего образования.

  • Установитесь коэффициент регуляризации L2 на 0,0005.

  • Задайте порог штрафа как 0,5. Оштрафованы обнаружения, которые перекрывают меньше чем 0,5 с основной истиной.

  • Инициализируйте скорость градиента как []. Это используется SGDM, чтобы сохранить скорость градиентов.

numEpochs = 80;
miniBatchSize = 8;
learningRate = 0.001;
warmupPeriod = 1000;
l2Regularization = 0.0005;
penaltyThreshold = 0.5;
velocity = [];

Обучите модель

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения информации о поддерживаемом вычислите возможности, смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

Используйте minibatchqueue функционируйте, чтобы разделить предварительно обработанные обучающие данные в пакеты с функцией поддержки createBatchData который возвращает пакетные изображения и ограничительные рамки, объединенные с соответствующими идентификаторами класса. Для более быстрой экстракции пакетных данных для обучения, dispatchInBackground должен быть установлен в "истинный", который гарантирует использование параллельного пула.

minibatchqueue автоматически обнаруживает доступность графического процессора. Если вы не имеете графического процессора или не хотите использовать один для обучения, устанавливать OutputEnvironment параметр к "cpu".

if canUseParallelPool
   dispatchInBackground = true;
else
   dispatchInBackground = false;
end

mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
        "MiniBatchSize", miniBatchSize,...
        "MiniBatchFcn", @(images, boxes, labels) createBatchData(images, boxes, labels, classNames), ...
        "MiniBatchFormat", ["SSCB", ""],...
        "DispatchInBackground", dispatchInBackground,...
        "OutputCast", ["", "double"]);

Создайте использование плоттера процесса обучения, поддерживающее функциональный configureTrainingProgressPlotter видеть график, в то время как обучение детектор возражают с пользовательским учебным циклом.

Наконец, задайте пользовательский учебный цикл. Для каждой итерации:

  • Считайте данные из minibatchqueue. Если это больше не имеет данных, сбросьте minibatchqueue и перестановка.

  • Оцените градиенты модели с помощью dlfeval и modelGradients функция. Функциональный modelGradients, перечисленный как функция поддержки, возвращает градиенты потери относительно настраиваемых параметров в net, соответствующая мини-пакетная потеря и состояние текущего пакета.

  • Примените фактор затухания веса к градиентам к регуляризации для большего количества устойчивого обучения.

  • Определите скорость обучения на основе итераций с помощью piecewiseLearningRateWithWarmup поддерживание функции.

  • Обновите параметры детектора с помощью sgdmupdate функция.

  • Обновите state параметры детектора со скользящим средним значением.

  • Отобразите скорость обучения, общую сумму убытков и отдельные потери (потеря поля, объектная потеря и потеря класса) для каждой итерации. Они могут использоваться, чтобы интерпретировать, как соответствующие потери изменяются в каждой итерации. Например, внезапный скачок в потере поля после немногих итераций подразумевает, что существует Inf или NaNs в предсказаниях.

  • Обновите график процесса обучения.

Обучение может также быть отключено, если потеря насыщала в течение нескольких эпох.

if doTraining
    
    % Create subplots for the learning rate and mini-batch loss.
    fig = figure;
    [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(fig);

    iteration = 0;
    % Custom training loop.
    for epoch = 1:numEpochs
          
        reset(mbqTrain);
        shuffle(mbqTrain);
        
        while(hasdata(mbqTrain))
            iteration = iteration + 1;
           
            [XTrain, YTrain] = next(mbqTrain);
            
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, state, lossInfo] = dlfeval(@modelGradients, yolov3Detector, XTrain, YTrain, penaltyThreshold);
    
            % Apply L2 regularization.
            gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, yolov3Detector.Learnables);
    
            % Determine the current learning rate value.
            currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
    
            % Update the detector learnable parameters using the SGDM optimizer.
            [yolov3Detector.Learnables, velocity] = sgdmupdate(yolov3Detector.Learnables, gradients, velocity, currentLR);
    
            % Update the state parameters of dlnetwork.
            yolov3Detector.State = state;
              
            % Display progress.
            displayLossInfo(epoch, iteration, currentLR, lossInfo);  
                
            % Update training plot with new points.
            updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
        end        
    end
else
    yolov3Detector = preTrainedDetector;
end

Оцените модель

Computer Vision Toolbox™ обеспечивает функции оценки детектора объектов, чтобы измерить общие метрики, такие как средняя точность (evaluateDetectionPrecision) и средние журналом коэффициенты непопаданий (evaluateDetectionMissRate). В этом примере используется средняя метрика точности. Средняя точность обеспечивает один номер, который включает способность детектора сделать правильные классификации (точность) и способность детектора найти все соответствующие объекты (отзыв).

results = detect(yolov3Detector,testData,'MiniBatchSize',8);

% Evaluate the object detector using Average Precision metric.
[ap,recall,precision] = evaluateDetectionPrecision(results,testData);

Кривая отзыва точности (PR) показывает, насколько точный детектор на различных уровнях отзыва. Идеально, точность 1 на всех уровнях отзыва.

% Plot precision-recall curve.
figure
plot(recall,precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))

Обнаружьте Объекты Используя YOLO v3

Используйте детектор для обнаружения объектов.

% Read the datastore.
data = read(testData);

% Get the image.
I = data{1};

[bboxes,scores,labels] = detect(yolov3Detector,I);

% Display the detections on image.
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);

figure
imshow(I)

Вспомогательные Функции

Функция градиентов модели

Функциональный modelGradients берет yolov3ObjectDetector объект, мини-пакет входных данных XTrain с соответствующими основными блоками истинности YTrain, заданный порог штрафа как входные параметры и возвращает градиенты потери относительно настраиваемых параметров в yolov3ObjectDetector, соответствующая мини-пакетная информация о потере и состояние текущего пакета.

Функция градиентов модели вычисляет общую сумму убытков и градиенты путем выполнения этих операций.

  • Сгенерируйте предсказания от входного пакета изображений с помощью forward метод.

  • Соберите предсказания на центральном процессоре для постобработки.

  • Преобразуйте предсказания от координат ячейки сетки YOLO v3 до координат ограничительной рамки, чтобы позволить легкое сравнение с достоверными данными при помощи anchorBoxGenerator метод yolov3ObjectDetector.

  • Сгенерируйте цели для расчета потерь при помощи конвертированных предсказаний и достоверных данных. Эти цели сгенерированы для положений ограничительной рамки (x, y, ширина, высота), объектное доверие и вероятности класса. Смотрите функцию поддержки generateTargets.

  • Вычисляет среднеквадратическую ошибку предсказанных координат ограничительной рамки с целевыми полями. Смотрите функцию поддержки bboxOffsetLoss.

  • Определяет бинарную перекрестную энтропию предсказанной объектной оценки достоверности с оценкой достоверности целевого объекта. Смотрите функцию поддержки objectnessLoss.

  • Определяет бинарную перекрестную энтропию предсказанного класса объекта с целью. Смотрите функцию поддержки classConfidenceLoss.

  • Вычисляет общую сумму убытков как сумму всех потерь.

  • Вычисляет градиенты learnables относительно общей суммы убытков.

function [gradients, state, info] = modelGradients(detector, XTrain, YTrain, penaltyThreshold)
inputImageSize = size(XTrain,1:2);

% Gather the ground truths in the CPU for post processing
YTrain = gather(extractdata(YTrain));

% Extract the predictions from the detector.
[gatheredPredictions, YPredCell, state] = forward(detector, XTrain);

% Generate target for predictions from the ground truth data.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions,...
    YTrain, inputImageSize, detector.AnchorBoxes, penaltyThreshold);

% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;

info.boxLoss = boxLoss;
info.objLoss = objLoss;
info.clsLoss = clsLoss;
info.totalLoss = totalLoss;

% Compute gradients of learnables with regard to loss.
gradients = dlgradient(totalLoss, detector.Learnables);
end

function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
% Mean squared error for bounding box position.
lossX = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,1),boxDeltaTarget(:,1),boxMaskTarget(:,1),boxErrorScaleTarget));
lossY = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,2),boxDeltaTarget(:,2),boxMaskTarget(:,1),boxErrorScaleTarget));
lossW = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,3),boxDeltaTarget(:,3),boxMaskTarget(:,1),boxErrorScaleTarget));
lossH = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,4),boxDeltaTarget(:,4),boxMaskTarget(:,1),boxErrorScaleTarget));
boxLoss = lossX+lossY+lossW+lossH;
end

function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
% Binary cross-entropy loss for objectness score.
objLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),objectnessPredCell,objectnessDeltaTarget,boxMaskTarget(:,2)));
end

function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
% Binary cross-entropy loss for class confidence score.
clsLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),classPredCell,classTarget,boxMaskTarget(:,3)));
end

Увеличение и функции обработки данных

function data = augmentData(A)
% Apply random horizontal flipping, and random X/Y scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter image color.

data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            'Contrast',0.0,...
            'Hue',0.1,...
            'Saturation',0.2,...
            'Brightness',0.2);
    end
    
    % Randomly flip image.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
    I = imwarp(I,tform,'OutputView',rout);
    
    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,'OverlapThreshold',0.25);
    labels = labels(indices);
    
    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I, bboxes, labels};
    end
end
end


function data = preprocessData(data, targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.

for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);
    
    % Convert an input image with single channel to 3 channels.
    if numel(imgSize) < 3 
        I = repmat(I,1,1,3);
    end
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);
    
    data(ii, 1:2) = {I, bboxes};
end
end

function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames)
% Returns images combined along the batch dimension in XTrain and
% normalized bounding boxes concatenated with classIDs in YTrain

% Concatenate images along the batch dimension.
XTrain = cat(4, data{:,1});

% Get class IDs from the class names.
classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
[~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);

% Append the label indexes and training image size to scaled bounding boxes
% and create a single cell array of responses.
combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
len = max( cellfun(@(x)size(x,1), combinedResponses ) );
paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
YTrain = cat(4, paddedBBoxes{:,1});
end

Функция расписания скорости обучения

function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs)
% The piecewiseLearningRateWithWarmup function computes the current
% learning rate based on the iteration number.
persistent warmUpEpoch;

if iteration <= warmupPeriod
    % Increase the learning rate for number of iterations in warmup period.
    currentLR = learningRate * ((iteration/warmupPeriod)^4);
    warmUpEpoch = epoch;
elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch))
    % After warm up period, keep the learning rate constant if the remaining number of epochs is less than 60 percent. 
    currentLR = learningRate;
    
elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch))
    % If the remaining number of epochs is more than 60 percent but less
    % than 90 percent multiply the learning rate by 0.1.
    currentLR = learningRate*0.1;
    
else
    % If remaining epochs are more than 90 percent multiply the learning
    % rate by 0.01.
    currentLR = learningRate*0.01;
end

end

Служебные функции

function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(f)
% Create the subplots to display the loss and learning rate.
figure(f);
clf
subplot(2,1,1);
ylabel('Learning Rate');
xlabel('Iteration');
learningRatePlotter = animatedline;
subplot(2,1,2);
ylabel('Total Loss');
xlabel('Iteration');
lossPlotter = animatedline;
end

function displayLossInfo(epoch, iteration, currentLR, lossInfo)
% Display loss information for each iteration.
disp("Epoch : " + epoch + " | Iteration : " + iteration + " | Learning Rate : " + currentLR + ...
   " | Total Loss : " + double(gather(extractdata(lossInfo.totalLoss))) + ...
   " | Box Loss : " + double(gather(extractdata(lossInfo.boxLoss))) + ...
   " | Object Loss : " + double(gather(extractdata(lossInfo.objLoss))) + ...
   " | Class Loss : " + double(gather(extractdata(lossInfo.clsLoss))));
end

function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss)
% Update loss and learning rate plots.
addpoints(lossPlotter, iteration, double(extractdata(gather(totalLoss))));
addpoints(learningRatePlotter, iteration, currentLR);
drawnow
end

function detector = downloadPretrainedYOLOv3Detector()
% Download a pretrained yolov3 detector.
if ~exist('yolov3SqueezeNetVehicleExample_21aSPKG.mat', 'file')
    if ~exist('yolov3SqueezeNetVehicleExample_21aSPKG.zip', 'file')
        disp('Downloading pretrained detector...');
        pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/yolov3SqueezeNetVehicleExample_21aSPKG.zip';
        websave('yolov3SqueezeNetVehicleExample_21aSPKG.zip', pretrainedURL);
    end
    unzip('yolov3SqueezeNetVehicleExample_21aSPKG.zip');
end
pretrained = load("yolov3SqueezeNetVehicleExample_21aSPKG.mat");
detector = pretrained.detector;
end

Ссылки

[1] Redmon, Джозеф и Али Фархади. “YOLOv3: Инкрементное Улучшение”. Предварительно распечатайте, представленный 8 апреля 2018. https://arxiv.org/abs/1804.02767.

Смотрите также

Приложения

Функции

Объекты

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте