В этом примере показано, как обучить Один детектор выстрела (SSD).
Глубокое обучение является мощным методом машинного обучения, который автоматически изучает функции изображений, требуемые для задач обнаружения. Существует несколько методов для обнаружения объектов с помощью глубокого обучения, таких как Faster R-CNN, Вы Только Взгляд Однажды (YOLO v2) и SSD. Этот пример обучает детектор транспортного средства SSD с помощью trainSSDObjectDetector
функция. Для получения дополнительной информации смотрите Обнаружение объектов.
Загрузите предварительно обученный детектор, чтобы избежать необходимости ожидать обучения завершиться. Если вы хотите обучить детектор, установите doTraining
переменная к истине.
doTraining = false; if ~doTraining && ~exist('ssdResNet50VehicleExample_20a.mat','file') disp('Downloading pretrained detector (44 MB)...'); pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/ssdResNet50VehicleExample_20a.mat'; websave('ssdResNet50VehicleExample_20a.mat',pretrainedURL); end
Downloading pretrained detector (44 MB)...
Этот пример использует набор данных небольшого транспортного средства, который содержит 295 изображений. Многие из этих изображений прибывают из Автомобилей Калифорнийского технологического института 1 999 и 2 001 набор данных, доступный в Калифорнийском технологическом институте Вычислительный веб-сайт Видения, созданный Пьетро Пероной и используемый с разрешением. Каждое изображение содержит один или два помеченных экземпляра транспортного средства. Небольшой набор данных полезен для исследования метода обучения SSD, но на практике, более помеченные изображения необходимы, чтобы обучить устойчивый детектор.
unzip vehicleDatasetImages.zip data = load('vehicleDatasetGroundTruth.mat'); vehicleDataset = data.vehicleDataset;
Обучающие данные хранятся в таблице. Первый столбец содержит путь к файлам изображений. Остальные столбцы содержат метки ROI для транспортных средств. Отобразите первые несколько строк данных.
vehicleDataset(1:4,:)
ans=4×2 table
imageFilename vehicle
_________________________________ _________________
{'vehicleImages/image_00001.jpg'} {[220 136 35 28]}
{'vehicleImages/image_00002.jpg'} {[175 126 61 45]}
{'vehicleImages/image_00003.jpg'} {[108 120 45 33]}
{'vehicleImages/image_00004.jpg'} {[124 112 38 36]}
Разделите набор данных в набор обучающих данных для обучения детектор и набор тестов для оценки детектора. Выберите 60% данных для обучения. Используйте остальных для оценки.
rng(0); shuffledIndices = randperm(height(vehicleDataset)); idx = floor(0.6 * length(shuffledIndices) ); trainingData = vehicleDataset(shuffledIndices(1:idx),:); testData = vehicleDataset(shuffledIndices(idx+1:end),:);
Используйте imageDatastore
и boxLabelDatastore
загружать изображение и данные о метке во время обучения и оценки.
imdsTrain = imageDatastore(trainingData{:,'imageFilename'}); bldsTrain = boxLabelDatastore(trainingData(:,'vehicle')); imdsTest = imageDatastore(testData{:,'imageFilename'}); bldsTest = boxLabelDatastore(testData(:,'vehicle'));
Объедините изображение и хранилища данных метки поля.
trainingData = combine(imdsTrain,bldsTrain); testData = combine(imdsTest, bldsTest);
Отобразите одно из учебных изображений и меток поля.
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
Сеть обнаружения объектов SSD может считаться наличием двух подсетей. Сеть извлечения признаков, сопровождаемая сетью обнаружения.
Сеть извлечения признаков обычно является предварительно обученным CNN (см. Предварительно обученные Глубокие нейронные сети (Deep Learning Toolbox) для получения дополнительной информации). Этот пример использует ResNet-50 для извлечения признаков. Другие предварительно обученные сети, такие как MobileNet v2 или ResNet-18 могут также использоваться в зависимости от требований к приложению. Подсеть обнаружения является маленьким CNN по сравнению с сетью извлечения признаков и состоит из нескольких сверточных слоев и слоев, характерных для SSD.
Используйте ssdLayers
функционируйте, чтобы автоматически изменить предварительно обученную сеть ResNet-50 в сеть обнаружения объектов SSD. ssdLayers
требует, чтобы вы задали несколько входных параметров, которые параметрируют сеть SSD, включая сетевой входной размер и количество классов. При выборе сетевого входного размера считайте размер учебных изображений и вычислительную стоимость понесенными путем обработки данных в выбранном размере. Когда выполнимо, выберите сетевой входной размер, который является близко к размеру учебного изображения. Однако, чтобы уменьшать вычислительную стоимость выполнения этого примера, сетевой входной размер выбран, чтобы быть [300 300 3]. Во время обучения, trainSSDObjectDetector
автоматически изменяет размер учебных изображений к сетевому входному размеру.
inputSize = [300 300 3];
Задайте количество классов объектов, чтобы обнаружить.
numClasses = width(vehicleDataset)-1;
Создайте сеть обнаружения объектов SSD.
lgraph = ssdLayers(inputSize, numClasses, 'resnet50');
Можно визуализировать сеть с помощью analyzeNetwork
или DeepNetworkDesigner
от Deep Learning Toolbox™. Обратите внимание на то, что можно также создать пользовательский слой сети слоем SSD. Для получения дополнительной информации смотрите, Создают Сеть Обнаружения объектов SSD.
Увеличение данных используется, чтобы улучшить сетевую точность путем случайного преобразования исходных данных во время обучения. При помощи увеличения данных можно добавить больше разнообразия в обучающие данные, на самом деле не имея необходимость увеличить число помеченных обучающих выборок. Используйте transform
увеличивать обучающие данные
Случайным образом зеркальное отражение изображения и сопоставленного поля помечает горизонтально.
Случайным образом масштабируйте изображение, сопоставленные метки поля.
Цвет изображения дрожания.
Обратите внимание на то, что увеличение данных не применяется к тестовым данным. Идеально, тестовые данные должны быть представительными для исходных данных и оставлены немодифицированными для несмещенной оценки.
augmentedTrainingData = transform(trainingData,@augmentData);
Визуализируйте увеличенные обучающие данные путем чтения того же изображения многократно.
augmentedData = cell(4,1); for k = 1:4 data = read(augmentedTrainingData); augmentedData{k} = insertShape(data{1},'Rectangle',data{2}); reset(augmentedTrainingData); end figure montage(augmentedData,'BorderSize',10)
Предварительно обработайте увеличенные обучающие данные, чтобы подготовиться к обучению.
preprocessedTrainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
Считайте предварительно обработанные обучающие данные.
data = read(preprocessedTrainingData);
Отобразите поля изображения и ограничительные рамки.
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
Используйте trainingOptions
задавать сетевые опции обучения. Установите 'CheckpointPath'
к временному местоположению. Это включает сохранение частично обученных детекторов во время учебного процесса. Если обучение прервано, такой как отключением электроэнергии или системным отказом, можно возобновить обучение с сохраненной контрольной точки.
options = trainingOptions('sgdm', ... 'MiniBatchSize', 16, .... 'InitialLearnRate',1e-1, ... 'LearnRateSchedule', 'piecewise', ... 'LearnRateDropPeriod', 30, ... 'LearnRateDropFactor', 0.8, ... 'MaxEpochs', 300, ... 'VerboseFrequency', 50, ... 'CheckpointPath', tempdir, ... 'Shuffle','every-epoch');
Используйте trainSSDObjectDetector
функция, чтобы обучить детектор объектов SSD, если doTraining
к истине. В противном случае загружает предварительно обученную сеть.
if doTraining % Train the SSD detector. [detector, info] = trainSSDObjectDetector(preprocessedTrainingData,lgraph,options); else % Load pretrained detector for the example. pretrained = load('ssdResNet50VehicleExample_20a.mat'); detector = pretrained.detector; end
Этот пример проверяется на Титане NVIDIA™ X графических процессоров с 12 Гбайт памяти. Если ваш графический процессор имеет меньше памяти, у можно закончиться память. Если это происходит, понизьте 'MiniBatchSize
'использование trainingOptions
функция. Обучение этой сети заняло приблизительно 2 часа с помощью этой настройки. Учебное время варьируется в зависимости от оборудования, которое вы используете.
Как быстрый тест, запустите детектор на одном тестовом изображении.
data = read(testData);
I = data{1,1};
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I, 'Threshold', 0.4);
Отобразите результаты.
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)
Оцените обученный детектор объектов на большом наборе изображений, чтобы измерить уровень. Computer Vision Toolbox™ обеспечивает функции оценки детектора объектов, чтобы измерить общие метрики, такие как средняя точность (evaluateDetectionPrecision
) и средние журналом коэффициенты непопаданий (evaluateDetectionMissRate
). В данном примере используйте среднюю метрику точности, чтобы оценить эффективность. Средняя точность обеспечивает один номер, который включает способность детектора сделать правильные классификации (precision
) и способность детектора найти все соответствующие объекты (recall
).
Примените к тестовым данным то же преобразование предварительной обработки, что и к обучающим данным. Обратите внимание на то, что увеличение данных не применяется к тестовым данным. Данные испытаний должны быть показательными по сравнению с исходными данными и не должны быть изменены для объективной оценки.
preprocessedTestData = transform(testData,@(data)preprocessData(data,inputSize));
Запустите детектор на всех тестовых изображениях.
detectionResults = detect(detector, preprocessedTestData, 'Threshold', 0.4);
Оцените детектор объектов с помощью средней метрики точности.
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults, preprocessedTestData);
Точность/отзыв (PR), который подсвечивает кривая, насколько точный детектор на различных уровнях отзыва. Идеально, точность была бы 1 на всех уровнях отзыва. Использование большего количества данных может помочь улучшить среднюю точность, но может потребовать большего количества учебного графика временной зависимости кривая PR.
figure plot(recall,precision) xlabel('Recall') ylabel('Precision') grid on title(sprintf('Average Precision = %.2f',ap))
Если детектор обучен и оценен, можно сгенерировать код для ssdObjectDetector
использование GPU Coder™. Для получения дополнительной информации смотрите Генерацию кода для Обнаружения объектов при помощи Одного примера Детектора Мультиполя Выстрела.
function B = augmentData(A) % Apply random horizontal flipping, and random X/Y scaling. Boxes that get % scaled outside the bounds are clipped if the overlap is above 0.25. Also, % jitter image color. B = cell(size(A)); I = A{1}; sz = size(I); if numel(sz)==3 && sz(3) == 3 I = jitterColorHSV(I,... 'Contrast',0.2,... 'Hue',0,... 'Saturation',0.1,... 'Brightness',0.2); end % Randomly flip and scale image. tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]); rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput'); B{1} = imwarp(I,tform,'OutputView',rout); % Sanitize boxes, if needed. A{2} = helperSanitizeBoxes(A{2}, sz); % Apply same transform to boxes. [B{2},indices] = bboxwarp(A{2},tform,rout,'OverlapThreshold',0.25); B{3} = A{3}(indices); % Return original data only when all boxes are removed by warping. if isempty(indices) B = A; end end function data = preprocessData(data,targetSize) % Resize image and bounding boxes to the targetSize. sz = size(data{1},[1 2]); scale = targetSize(1:2)./sz; data{1} = imresize(data{1},targetSize(1:2)); % Sanitize boxes, if needed. data{2} = helperSanitizeBoxes(data{2}, sz); % Resize boxes. data{2} = bboxresize(data{2},scale); end
[1] Лю, Вэй, Драгомир Ангуелов, Думитру Эрхэн, Кристиан Сзеджеди, Скотт Рид, Ченг Янг Фу и Александр К. Берг. "SSD: Один детектор мультиполя выстрела". На 14-й европейской Конференции по Компьютерному зрению, ECCV 2016. Springer Verlag, 2016.
combine
| estimateAnchorBoxes
| evaluateDetectionPrecision
| read
| transform
| analyzeNetwork
(Deep Learning Toolbox)