В этом примере показано, как обучить одиночный детектор выстрела (SSD).
Глубокое обучение является мощным методом машинного обучения, которая автоматически изучает функции изображений, необходимые для задач обнаружения. Существует несколько методов обнаружения объектов с помощью глубокого обучения, таких как Faster R-CNN, You Only Look Once (YOLO v2) и SSD. Этот пример обучает детектор транспортного средства SSD, используя trainSSDObjectDetector
функция. Для получения дополнительной информации смотрите Обнаружение объектов (Computer Vision Toolbox).
Загрузите предварительно обученный детектор, чтобы избежать необходимости ждать завершения обучения. Если вы хотите обучить детектор, установите doTraining
переменная - true.
doTraining = false; if ~doTraining && ~exist('ssdResNet50VehicleExample_20a.mat','file') disp('Downloading pretrained detector (44 MB)...'); pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/ssdResNet50VehicleExample_20a.mat'; websave('ssdResNet50VehicleExample_20a.mat',pretrainedURL); end
Downloading pretrained detector (44 MB)...
Этот пример использует небольшой набор данных о транспортном средстве, который содержит 295 изображений. Многие из этих изображений получены из наборов данных Caltech Cars 1999 и 2001, доступных на веб-сайте Caltech Computational Vision, созданном Пьетро Пероной и используемом с разрешения. Каждое изображение содержит один или два маркированных образца транспортного средства. Небольшой набор данных полезен для исследования процедуры обучения SSD, но на практике для обучения устойчивого детектора необходимо больше маркированных изображений.
unzip vehicleDatasetImages.zip data = load('vehicleDatasetGroundTruth.mat'); vehicleDataset = data.vehicleDataset;
Обучающие данные хранятся в таблице. Первый столбец содержит путь к файлам изображений. Остальные столбцы содержат метки информация только для чтения для транспортных средств. Отображение первых нескольких строк данных.
vehicleDataset(1:4,:)
ans=4×2 table
imageFilename vehicle
_________________________________ _________________
{'vehicleImages/image_00001.jpg'} {[220 136 35 28]}
{'vehicleImages/image_00002.jpg'} {[175 126 61 45]}
{'vehicleImages/image_00003.jpg'} {[108 120 45 33]}
{'vehicleImages/image_00004.jpg'} {[124 112 38 36]}
Разделите набор данных на набор обучающих данных для настройки детектора и тестовый набор для оценки детектора. Выберите 60% данных для обучения. Остальное используйте для оценки.
rng(0); shuffledIndices = randperm(height(vehicleDataset)); idx = floor(0.6 * length(shuffledIndices) ); trainingData = vehicleDataset(shuffledIndices(1:idx),:); testData = vehicleDataset(shuffledIndices(idx+1:end),:);
Использование imageDatastore
и boxLabelDatastore
загрузить изображение и маркировать данные во время обучения и оценки.
imdsTrain = imageDatastore(trainingData{:,'imageFilename'}); bldsTrain = boxLabelDatastore(trainingData(:,'vehicle')); imdsTest = imageDatastore(testData{:,'imageFilename'}); bldsTest = boxLabelDatastore(testData(:,'vehicle'));
Объедините хранилища данных меток изображений и коробок.
trainingData = combine(imdsTrain,bldsTrain); testData = combine(imdsTest, bldsTest);
Отобразите одно из обучающих изображений и коробчатых меток.
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
Сеть обнаружения объектов SSD может рассматриваться как имеющая две подсети. Сеть редукции данных, за которой следует сеть обнаружения.
Сеть редукции данных обычно является предварительно обученной CNN (для получения дополнительной информации см. Глубокие нейронные сети Pretrained). Этот пример использует ResNet-50 для редукции данных. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-18, также могут использоваться в зависимости от требований приложения. Подсеть обнаружения является небольшой CNN по сравнению с сетью редукции данных и состоит из нескольких сверточных слоев и слоев, характерных для SSD.
Используйте ssdLayers
функция для автоматического изменения предварительно обученной ResNet-50 сети в сеть обнаружения объектов SSD. ssdLayers
требуется задать несколько входов, которые параметризируют сеть SSD, включая размер входного сигнала сети и количество классов. При выборе размера входа сети учитывайте размер обучающих изображений и вычислительные затраты, связанные с обработкой данных при выбранном размере. Когда это возможно, выберите размер входа сети, который близок к размеру обучающего изображения. Однако, чтобы уменьшить вычислительные затраты на выполнение этого примера, размер входа сети выбирается равным [300 300 3]. Во время обучения, trainSSDObjectDetector
автоматически изменяет размер обучающих изображений на размер входа сети.
inputSize = [300 300 3];
Задайте количество классов объектов для обнаружения.
numClasses = width(vehicleDataset)-1;
Создайте сеть обнаружения объектов SSD.
lgraph = ssdLayers(inputSize, numClasses, 'resnet50');
Визуализировать сеть можно используя analyzeNetwork
или D eepNetworkDesigner
из Deep Learning Toolbox™. Обратите внимание, что вы также можете создать пользовательский слой сети SSD. Для получения дополнительной информации смотрите Создание сети обнаружения объектов SSD (Computer Vision Toolbox).
Увеличение количества данных используется для повышения точности сети путем случайного преобразования исходных данных во время обучения. При помощи увеличения данных можно добавить больше разнообразия в обучающие данные, не увеличивая на самом деле количество маркированных обучающих выборок. Использование transform
для увеличения обучающих данных
Случайное отражение изображения и связанных прямоугольных меток по горизонтали.
Случайным образом масштабируйте изображение, связанные прямоугольные метки.
Цвет изображения дрожания.
Обратите внимание, что увеличение количества данных не применяется к тестовым данным. В идеале тестовые данные должны быть показательными по сравнению с исходными данными и не должны быть изменены для объективной оценки.
augmentedTrainingData = transform(trainingData,@augmentData);
Визуализируйте дополненные обучающие данные путем чтения одного и того же изображения несколько раз.
augmentedData = cell(4,1); for k = 1:4 data = read(augmentedTrainingData); augmentedData{k} = insertShape(data{1},'Rectangle',data{2}); reset(augmentedTrainingData); end figure montage(augmentedData,'BorderSize',10)
Предварительно обработайте дополненные обучающие данные для подготовки к обучению.
preprocessedTrainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
Считайте предварительно обработанные обучающие данные.
data = read(preprocessedTrainingData);
Отобразите изображение и ограничительные рамки.
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
Использование trainingOptions
для определения опций обучения. Задайте 'CheckpointPath'
во временное место. Это позволяет экономить частично обученные детекторы в процессе обучения. Если обучение прервано, например, отключение степени или отказ системы, можно возобновить обучение с сохраненной контрольной точки.
options = trainingOptions('sgdm', ... 'MiniBatchSize', 16, .... 'InitialLearnRate',1e-1, ... 'LearnRateSchedule', 'piecewise', ... 'LearnRateDropPeriod', 30, ... 'LearnRateDropFactor', 0.8, ... 'MaxEpochs', 300, ... 'VerboseFrequency', 50, ... 'CheckpointPath', tempdir, ... 'Shuffle','every-epoch');
Использование trainSSDObjectDetector
(Computer Vision Toolbox) функция для обучения детектора объектов SSD, если doTraining
к true. В противном случае загружает предварительно обученную сеть
if doTraining % Train the SSD detector. [detector, info] = trainSSDObjectDetector(preprocessedTrainingData,lgraph,options); else % Load pretrained detector for the example. pretrained = load('ssdResNet50VehicleExample_20a.mat'); detector = pretrained.detector; end
Этот пример проверяется на NVIDIA™ графическом процессоре Titan X с 12 ГБ памяти. Если у вашего графического процессора меньше памяти, возможно, у вас закончится память. Если это произойдет, опустите 'MiniBatchSize
'использование trainingOptions
функция. Обучение этой сети заняло приблизительно 2 часов, используя эту настройку. Время обучения варьируется в зависимости от используемого оборудования.
В качестве быстрого теста запустите детектор на одном тестовом изображении.
data = read(testData);
I = data{1,1};
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I, 'Threshold', 0.4);
Отображение результатов.
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)
Оцените обученный детектор объектов на большом наборе изображений, чтобы измерить эффективность. Computer Vision Toolbox™ предоставляет функции оценки детектора объектов, чтобы измерить общие метрики, такие как средняя точность (evaluateDetectionPrecision
) и средние логарифмические коэффициенты пропуска (evaluateDetectionMissRate
). В данном примере используйте среднюю метрику точности для оценки эффективности. Средняя точность обеспечивает одно число, которое включает в себя способность детектора делать правильные классификации (precision
) и способность детектора находить все релевантные объекты (recall
).
Примените к тестовым данным то же преобразование предварительной обработки, что и к обучающим данным. Обратите внимание, что увеличение количества данных не применяется к тестовым данным. Тестовые данные должны быть показательными по сравнению с исходными данными и не должны быть изменены для объективной оценки.
preprocessedTestData = transform(testData,@(data)preprocessData(data,inputSize));
Запустите детектор на всех тестовых изображениях.
detectionResults = detect(detector, preprocessedTestData, 'Threshold', 0.4);
Оцените детектор объектов с помощью средней метрики точности.
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults, preprocessedTestData);
Кривая точности/отзыва (PR) подсвечивает, насколько точен детектор на меняющихся уровнях отзыва. В идеале точность будет равна 1 на всех уровнях отзыва. Использование большего количества данных может помочь улучшить среднюю точность, но может потребовать большего времени обучения. Постройте кривую PR.
figure plot(recall,precision) xlabel('Recall') ylabel('Precision') grid on title(sprintf('Average Precision = %.2f',ap))
После обучения и оценки детектора можно сгенерировать код для ssdObjectDetector
использование GPU Coder™. Для получения дополнительной информации смотрите Генерацию кода для обнаружения объектов с помощью примера Single Shot Multibox Detector (Computer Vision Toolbox).
function B = augmentData(A) % Apply random horizontal flipping, and random X/Y scaling. Boxes that get % scaled outside the bounds are clipped if the overlap is above 0.25. Also, % jitter image color. B = cell(size(A)); I = A{1}; sz = size(I); if numel(sz)==3 && sz(3) == 3 I = jitterColorHSV(I,... 'Contrast',0.2,... 'Hue',0,... 'Saturation',0.1,... 'Brightness',0.2); end % Randomly flip and scale image. tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]); rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput'); B{1} = imwarp(I,tform,'OutputView',rout); % Sanitize boxes, if needed. A{2} = helperSanitizeBoxes(A{2}, sz); % Apply same transform to boxes. [B{2},indices] = bboxwarp(A{2},tform,rout,'OverlapThreshold',0.25); B{3} = A{3}(indices); % Return original data only when all boxes are removed by warping. if isempty(indices) B = A; end end function data = preprocessData(data,targetSize) % Resize image and bounding boxes to the targetSize. sz = size(data{1},[1 2]); scale = targetSize(1:2)./sz; data{1} = imresize(data{1},targetSize(1:2)); % Sanitize boxes, if needed. data{2} = helperSanitizeBoxes(data{2}, sz); % Resize boxes. data{2} = bboxresize(data{2},scale); end
[1] Лю, Вэй, Драгомир Ангуэлов, Думитру Эрхан, Кристиан Сегеди, Скотт Рид, Чэн Ян Фу, и Александр К. Берг. SSD: Single shot multibox detector (неопр.) (недоступная ссылка). 14-я Европейская конференция по компьютерному зрению, ECCV 2016. Springer Verlag, 2016.
analyzeNetwork
| combine
| read
| trainingOptions
| transform
| detect
(Computer Vision Toolbox) | evaluateDetectionPrecision
(Computer Vision Toolbox) | ssdLayers
(Computer Vision Toolbox) | trainSSDObjectDetector
(Computer Vision Toolbox)imageDatastore
| boxLabelDatastore
(Computer Vision Toolbox) | ssdObjectDetector
(Computer Vision Toolbox)