В этом примере показано, как обучить находящиеся в области Сверточные нейронные сети (R-CNN) для целевого распознавания в больших изображениях Радара с синтезированной апертурой (SAR) сцены с помощью Deep Learning Toolbox™ и Parallel Computing Toolbox™.
Deep Learning Toolbox служит основой для разработки и реализации глубоких нейронных сетей с алгоритмами, предварительно обученными моделями и приложениями.
Parallel Computing Toolbox позволяет вам решить в вычислительном отношении и информационно емкие проблемы с помощью многоядерных процессоров, графических процессоров и компьютерных кластеров. Это позволяет вам использовать графические процессоры непосредственно из MATLAB и ускорить возможности расчета, необходимые в алгоритмах глубокого обучения.
Основанные на нейронной сети алгоритмы, показали замечательное достижение в разнообразных областях в пределах от естественного обнаружения сцены к медицинской обработке изображений. Это показало огромное улучшение по сравнению со стандартными алгоритмами обнаружения. Вдохновленный этими продвижениями, исследователи приложили усилия, чтобы применить основанные на глубоком обучении решения поля обработки изображений РСА. В этом примере решение было применено, чтобы решить задачу целевого обнаружения и распознавания. Сеть R-CNN, используемая здесь не только, решает задачу объединяющегося обнаружения и распознавания, но также и обеспечьте эффективное и эффективное производительное решение, которое масштабируется к большим изображениям РСА сцены также.
Этот пример демонстрирует как к:
Загрузите набор данных и предварительно обученную модель.
Загрузите и анализируйте данные изображения.
Определить сетевую архитектуру.
Задайте опции обучения.
Обучите сеть.
Оценка сети.
Чтобы проиллюстрировать этот рабочий процесс, Перемещаясь и Стационарный Целевой Захват и Распознавание набор данных помехи (MSTAR), опубликованный Научно-исследовательской лабораторией Военно-воздушных сил, используется. Набор данных доступен для скачивания здесь. В качестве альтернативы подмножество данных, используемых, чтобы продемонстрировать рабочий процесс, обеспечивается. Цель состоит в том, чтобы разработать модель, которая может обнаружить и распознать цели.
Этот пример использует подмножество набора данных помехи MSTAR, который содержит 300 обучения и 50 изображений помехи тестирования с 5 различными целями. Данные были собраны с помощью датчика X-полосы в режиме центра внимания с 1-футовым разрешением. Данные содержат сельские и городские типы помех. Тип используемой цели является BTR-60 (бронируемый автомобиль), BRDM-2 (боевая машина), ZSU-23/4 (бак), T62 (бак) и SLICY (несколько простая статическая цель геометрической формы). Изображения были получены под углом депрессии 15 градусов. Данные о помехе хранятся в формате изображения PNG, и соответствующие достоверные данные хранится в groundTruthMSTARClutterDataset.mat
файл. Файл содержит 2D информацию об ограничительной рамке для пяти классов, которые являются SLICY, BTR-60, BRDM-2, ZSU-23/4 и T62 для данных об обучении и тестировании соответственно. Размер набора данных составляет 1,6 Гбайт.
Загрузите набор данных с данного URL с помощью helperDownloadMSTARClutterData
функция помощника, заданная в конце этого примера.
outputFolder = pwd;
dataURL = ('https://ssd.mathworks.com/supportfiles/radar/data/MSTAR_ClutterDataset.tar.gz');
helperDownloadMSTARClutterData(outputFolder,dataURL);
В зависимости от вашего Интернет-соединения может занять время процесс загрузки. Код приостанавливает выполнение MATLAB®, пока процесс загрузки не завершен. В качестве альтернативы загрузите набор данных на локальный диск с помощью веб-браузера и извлеките файл. При использовании альтернативного подхода замените outputFolder переменную в примере к местоположению загруженного файла.
Загрузите предварительно обученную сеть с данного URL с помощью helperDownloadPretrainedSARDetectorNet
функция помощника, заданная в конце этого примера. Предварительно обученная модель позволяет вам запускать целый пример, не имея необходимость ожидать обучения завершиться. Чтобы обучить сеть, установите doTrain
переменная к истине.
pretrainedNetURL = ('https://ssd.mathworks.com/supportfiles/radar/data/TrainedSARDetectorNet.tar.gz'); doTrain = false; if ~doTrain helperDownloadPretrainedSARDetectorNet (outputFolder, pretrainedNetURL); end
Загрузите достоверные данные (набор обучающих данных и набор тестов). Эти изображения сгенерированы таким способом, которым это помещает целевые микросхемы наугад местоположение на фоновом изображении помехи. Изображение помехи создается из загруженных необработанных данных. Сгенерированная цель будет использоваться в качестве целей основной истины, чтобы обучить и протестировать сеть.
load('groundTruthMSTARClutterDataset.mat', "trainingData", "testData");
Достоверные данные хранятся в таблице с шестью столбцами, где первый столбец содержит пути к файлу изображения, и второе к шестому столбцу содержит различные целевые ограничительные рамки.
% Display the first few rows of the data set
trainingData(1:4,:)
ans=4×6 table
imageFilename SLICY BTR_60 BRDM_2 ZSU_23_4 T62
______________________________ __________________ __________________ __________________ ___________________ ___________________
"./TrainingImages/Img0001.png" {[ 285 468 28 28]} {[ 135 331 65 65]} {[ 597 739 65 65]} {[ 810 1107 80 80]} {[1228 1089 87 87]}
"./TrainingImages/Img0002.png" {[595 1585 28 28]} {[ 880 162 65 65]} {[308 1683 65 65]} {[1275 1098 80 80]} {[1274 1099 87 87]}
"./TrainingImages/Img0003.png" {[200 1140 28 28]} {[961 1055 65 65]} {[306 1256 65 65]} {[ 661 1412 80 80]} {[ 699 886 87 87]}
"./TrainingImages/Img0004.png" {[ 623 186 28 28]} {[ 536 946 65 65]} {[ 131 245 65 65]} {[1030 1266 80 80]} {[ 151 924 87 87]}
Отобразите одно из учебных изображений и меток поля, чтобы визуализировать данные.
img = imread(trainingData.imageFilename(1)); bbox = reshape(cell2mat(trainingData{1,2:end}),[4,5])'; labels = {'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62'}; annotatedImage = insertObjectAnnotation(img,'rectangle',bbox,labels,... 'TextBoxOpacity',0.9,'FontSize',50); figure imshow(annotatedImage); title('Sample Training image with bounding boxes and labels')
Создайте детектор объектов R-CNN для пяти целей: 'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62'.
objectClasses = {'SLICY', 'BTR_60', 'BRDM_2', 'ZSU_23_4', 'T62'};
Сеть должна смочь классифицировать 5 целей, заданных выше и фоновый класс для того, чтобы быть обученной с помощью trainRCNNObjectDetector
доступный в Deep Learning Toolbox™. 1
добавляется в коде ниже, чтобы включать фоновый класс.
numClassesPlusBackground = numel(objectClasses) + 1;
Итоговый полносвязный слой сети задает количество классов, которые это может классифицировать. Установите итоговый полносвязный слой иметь выходной размер, равный numClassesPlusBackground
.
% Define input size inputSize = [128,128,1]; % Define network layers = createNetwork(inputSize,numClassesPlusBackground);
Теперь эти слоя сети могут использоваться, чтобы обучить основанный на R-CNN детектор с 5 объектами класса.
Используйте trainingOptions
задавать сетевые опции обучения. trainingOptions
использованием значения по умолчанию графический процессор, если вы доступны (требует, Parallel Computing Toolbox™ и CUDA® включили графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
. Чтобы автоматически обнаружить, если вы имеете графический процессор в наличии, установите ExecutionEnvironment
к 'auto
'. Если вы не имеете графического процессора или не хотите использовать один для обучения, устанавливать ExecutionEnvironment
к 'cpu
'. Чтобы гарантировать использование графического процессора для обучения, установите ExecutionEnvironment
к 'gpu
'.
% Set training options options = trainingOptions('sgdm', ... 'MiniBatchSize', 128, ... 'InitialLearnRate', 1e-3, ... 'LearnRateSchedule', 'piecewise', ... 'LearnRateDropFactor', 0.1, ... 'LearnRateDropPeriod', 100, ... 'MaxEpochs', 10, ... 'Verbose', true, ... 'CheckpointPath',tempdir,... 'ExecutionEnvironment','auto');
Используйте trainRCNNObjectDetector
обучать детектор объектов R-CNN если doTrain
верно. В противном случае загрузите предварительно обученную сеть. Если обучение, настройте 'NegativeOverlapRange
'и 'PositiveOverlapRange
'чтобы гарантировать, что обучающие выборки плотно перекрываются с основной истиной,
if doTrain % Train an R-CNN object detector. This will take several minutes detector = trainRCNNObjectDetector(trainingData, layers, options,'PositiveOverlapRange',[0.5 1], 'NegativeOverlapRange', [0.1 0.5]); else % Load a previously trained detector preTrainedMATFile = fullfile(outputFolder,'TrainedSARDetectorNet.mat'); load(preTrainedMATFile); end
Чтобы получить качественную идею функционирования детектора, выберите случайное изображение от набора тестов и запустите его через детектор. Детектор, как ожидают, возвратит набор ограничительных рамок, где он думает, что обнаруженные цели, наряду с баллами, указывающими на доверие к каждому обнаружению.
% Read test image imgIdx = randi(height(testData)); testImage = imread(testData.imageFilename(imgIdx)); % Detect SAR targets in the test image [bboxes,score,label] = detect(detector,testImage,'MiniBatchSize',16);
Чтобы изучить достигнутые результаты, наложите результаты детектора с тестовым изображением. Основной параметр является порогом обнаружения, счетом, выше которого детектор "обнаружил" цель. Более высокий порог приведет к меньшему количеству ложных положительных сторон однако, он также приведет к более ложным отрицательным сторонам.
scoreThreshold = 0.8; % Display the detection results outputImage = testImage; for idx = 1:length (счет) bbox = поля (idx, :); thisScore = счет (idx); if thisScore> scoreThreshold аннотация = sprintf ('%s: (Confidence = %0.2f)', пометьте (idx),... вокруг (thisScore, 2)); outputImage = insertObjectAnnotation (outputImage, 'rectangle', bbox,... аннотация,'TextBoxOpacity',0.9,'FontSize',45,'LineWidth',2); end end f = число; f.Position (3:4) = [860,740]; imshow (outputImage) заголовок'Predicted boxes and labels on test image')
Путем рассмотрения изображений последовательно, может быть изучена эффективность детектора. Чтобы выполнить более строгий анализ с помощью целого набора тестов, запустите набор тестов через детектор.
% Create a table to hold the bounding boxes, scores and labels output by the detector numImages = height(testData); results = table('Size',[numImages 3],... 'VariableTypes',{'cell','cell','cell'},... 'VariableNames',{'Boxes','Scores','Labels'}); % Run detector on each image in the test set and collect results for i = 1:numImages imgFilename = testData.imageFilename{i}; % Read the image I = imread(imgFilename); % Run the detector [bboxes, scores, labels] = detect(detector, I,'MiniBatchSize',16); % Collect the results results.Boxes{i} = bboxes; results.Scores{i} = scores; results.Labels{i} = labels; end
Возможные обнаружения и их ограничительные рамки для всех изображений в наборе тестов могут использоваться, чтобы вычислить Среднюю точность (AP) детектора для каждого класса. AP является средним значением точности детектора на разных уровнях отзыва, таким образом давайте зададим точность и отзыв.
где
- количество истинных положительных сторон (детектор предсказывает цель, когда это присутствует),
- количество ложных положительных сторон (детектор предсказывает цель, когда это не присутствует),
- количество ложных отрицательных сторон (детектору не удается обнаружить цель, когда это присутствует),
Детектор с точностью 1 рассматривается хорошим в обнаружении целей, которые присутствуют, в то время как детектор с отзывом 1 способен избегать ложных обнаружений. Точность и отзыв имеют обратную связь.
Постройте отношение между точностью и отзывом для каждого класса. Среднее значение каждой кривой является AP. Кривые для 0,5 порогов обнаружения построены.
Для получения дополнительной информации см. документацию для evaluateDetectionPrecision
.
% Extract expected bounding box locations from test data expectedResults = testData(:, 2:end); threshold = 0.5; % Evaluate the object detector using average precision metric [ap, recall, precision] = evaluateDetectionPrecision(results, expectedResults,threshold); % Plot precision recall curve f = figure; ax = gca; f.Position(3:4) = [860,740]; xlabel('Recall') ylabel('Precision') grid on; hold on; legend('Location', 'southeast'); title('Precision Vs Recall curve for threshold value 0.5 for different classes'); for i = 1:length(ap) % Plot precision/recall curve plot(ax,recall{i},precision{i},'DisplayName',['Average Precision for class ' trainingData.Properties.VariableNames{i+1} ' is ' num2str(round(ap(i),3))]) end
AP для большинства классов - больше чем 0,9. Из них обученная модель, кажется, борется больше всего в обнаружении целей 'SLICY'. Однако это все еще может достигнуть AP 0,7 для класса.
Этот пример демонстрирует, как обучить R-CNN целевому распознаванию в изображениях РСА. Предварительно обученная сеть достигла точности AP больше чем 0,9.
Функциональный createNetwork
берет в качестве входа размер изображения inputSize
и количество классов numClassesPlusBackground
. Функция возвращает архитектуру нейронной сети свертки.
function layers = createNetwork(inputSize,numClassesPlusBackground) layers = [ imageInputLayer(inputSize) % Input Layer convolution2dLayer(3,32,'Padding','same') % Convolution Layer reluLayer % Relu Layer convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer % Batch normalization Layer reluLayer maxPooling2dLayer(2,'Stride',2) % Max Pooling Layer convolution2dLayer(3,64,'Padding','same') reluLayer convolution2dLayer(3,64,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,128,'Padding','same') reluLayer convolution2dLayer(3,128,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,256,'Padding','same') reluLayer convolution2dLayer(3,256,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(6,512) reluLayer dropoutLayer(0.5) % Dropout Layer fullyConnectedLayer(512) % Fully connected Layer. reluLayer fullyConnectedLayer(numClassesPlusBackground) softmaxLayer % Softmax Layer classificationLayer % Classification Layer ]; end function helperDownloadMSTARClutterData(outputFolder,DataURL) % Download the data set from the given URL to the output folder. radarDataTarFile = fullfile(outputFolder,'MSTAR_ClutterDataset.tar.gz'); if ~exist(radarDataTarFile,'file') disp('Downloading MSTAR Clutter data (1.6 GB)...'); websave(radarDataTarFile,DataURL); untar(radarDataTarFile,outputFolder); end end function helperDownloadPretrainedSARDetectorNet(outputFolder,pretrainedNetURL) % Download the pretrained network. preTrainedMATFile = fullfile(outputFolder,'TrainedSARDetectorNet.mat'); preTrainedZipFile = fullfile(outputFolder,'TrainedSARDetectorNet.tar.gz'); if ~exist(preTrainedMATFile,'file') if ~exist(preTrainedZipFile,'file') disp('Downloading pretrained detector (29.4 MB)...'); websave(preTrainedZipFile,pretrainedNetURL); end untar(preTrainedZipFile,outputFolder); end end
[1] Набор данных MSTAR. https://www.sdms.afrl.af.mil/index.php? collection=mstar