В этом примере показано, как использовать 3-D данные моделирования для обучения семантической сети сегментации и точной настройки ее к реальным данным с помощью генеративных состязательных сетей (GANs).
Этот пример использует данные 3-D симуляции, сгенерированные Driving Scenario Designer и Unreal Engine ®. Для примера, показывающего, как сгенерировать такие данные моделирования, смотрите Depth and Семантическая Сегментация Visualization Using Unreal Engine Simulation (Automated Driving Toolbox). Среда симуляции 3-D генерирует изображения и соответствующие метки наземного пикселя основной истины. Использование данных моделирования избегает процесса аннотации, который является одновременно утомительным и требует большого количества человеческих усилий. Однако модели сдвига в области, обученные только на данных моделирования, плохо работают на реальных наборах данных. Для решения этой проблемы можно использовать адаптацию области, чтобы подстроить обученную модель для работы с реальным набором данных .
Этот пример использует AdaptSegNet [1], сеть, которая адаптирует структуру выходных предсказаний сегментации, которые выглядят одинаково независимо от входной области. Сеть AdaptSegNet основана на модели GAN и состоит из двух сетей, которые обучаются одновременно, чтобы максимизировать эффективность обеих:
Генератор - Сеть обучена генерировать высококачественные результаты сегментации из реальных или моделируемых входных изображений
Дискриминатор - Сеть, которая сравнивает и пытается различить, являются ли предсказания сегментации генератора реальными или моделируемыми данными
Чтобы подстроить модель AdaptSegNet для данных реального мира, этот пример использует подмножество Данные CamVid [2] и адаптирует модель, чтобы сгенерировать высококачественные предсказания сегментации на данных CamVid.
Загрузите предварительно обученную сеть. Предварительно обученная модель позволяет вам запустить весь пример, не дожидаясь завершения обучения. Если вы хотите обучить сеть, установите doTraining
переменная в true
.
doTraining = false; if ~doTraining pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/trainedAdaptSegGANNet.mat'; pretrainedFolder = fullfile(tempdir,'pretrainedNetwork'); pretrainedNetwork = fullfile(pretrainedFolder,'trainedAdaptSegGANNet.mat'); if ~exist(pretrainedNetwork,'file') mkdir(pretrainedFolder); disp('Downloading pretrained network (57 MB)...'); websave(pretrainedNetwork,pretrainedURL); end pretrained = load(pretrainedNetwork); dlnetGenerator = pretrained.dlnetGenerator; end
Загрузите симуляцию и реальные наборы данных с помощью downloadDataset
function, заданная в разделе Support Functions этого примера. The downloadDataset
функция загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы.
Набор данных моделирования был сгенерирован Driving Scenario Designer. Сгенерированные сценарии, которые состоят из 553 фотореалистичных изображений с метками, были визуализированы Unreal Engine. Вы используете этот набор данных для обучения модели.
Реальный набор данных является подмножеством набора данных CamVid из Кембриджского университета. Чтобы адаптировать модель к реальным данным, 69 изображений CamVid. Чтобы оценить обученную модель, вы используете 368 изображений CamVid.
Время загрузки зависит от вашего подключения к Интернету.
simulationDataURL = 'https://ssd.mathworks.com/supportfiles/vision/data/SimulationDrivingDataset.zip'; realImageDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip'; realLabelDataURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip'; simulationDataLocation = fullfile(tempdir,'SimulationData'); realDataLocation = fullfile(tempdir,'RealData'); [simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder, ... realTestImagesFolder, realTestLabelsFolder] = ... downloadDataset(simulationDataLocation,simulationDataURL,realDataLocation,realImageDataURL,realLabelDataURL);
Загруженные файлы включают пиксельные метки для реальной области, но обратите внимание, что вы не используете эти пиксельные метки в процессе обучения. Этот пример использует пиксельные метки действительной области только для вычисления среднего значения пересечения по объединению (IoU), чтобы оценить эффективность обученной модели.
Использование imageDatastore
загрузить симуляцию и реальные наборы данных для обучения. При помощи image datastore можно эффективно загрузить на диск большой набор изображений.
simData = imageDatastore(simulationImagesFolder); realData = imageDatastore(realImagesFolder);
Предварительный просмотр изображений из набора данных моделирования и реального набора данных.
simImage = preview(simData); realImage = preview(realData); montage({simImage,realImage})
Реальные и моделируемые изображения выглядят очень по-другому. Следовательно, модели, обученные на моделируемых данных и оцененные на реальных данных, плохо работают из-за сдвига области.
Загрузите пиксельные данные о метке изображения симуляции при помощи pixelLabelDatastore
(Computer Vision Toolbox). Хранилище datastore метки пикселя инкапсулирует данные о пиксельных метках и идентификатор метки в отображение имен классов.
В данном примере задайте пять классов, полезных для приложения для беспилотного вождения: дорога, фон, тротуар, небо и автомобиль.
classes = [ "Road" "Background" "Pavement" "Sky" "Car" ]; numClasses = numel(classes);
Набор данных моделирования имеет восемь классов. Уменьшите количество классов с восьми до пяти путем группировки классов создания, дерева, сигнала трафика и света из исходного набора данных в один фоновый класс. Верните сгруппированные идентификаторы меток с помощью функции helper simulationPixelLabelIDs
. Эта вспомогательная функция присоединена к примеру как вспомогательный файл.
labelIDs = simulationPixelLabelIDs;
Используйте идентификаторы классов и меток, чтобы создать хранилище datastore о метках пикселей данных моделирования.
simLabels = pixelLabelDatastore(simulationLabelsFolder,classes,labelIDs);
Инициализируйте палитру для сегментированных изображений с помощью функции helper domainAdaptationColorMap
, заданный в разделе Вспомогательные функции.
dmap = domainAdaptationColorMap;
Предварительный просмотр изображения с меткой пикселя путем наложения метки поверх изображения с помощью labeloverlay
(Image Processing Toolbox) функция.
simImageLabel = preview(simLabels);
overlayImageSimulation = labeloverlay(simImage,simImageLabel,'ColorMap',dmap);
figure
imshow(overlayImageSimulation)
labelColorbar(dmap,classes);
Сдвиньте симуляцию и реальные данные, используемые для обучения, в нулевой центр, чтобы центрировать данные вокруг источника, используя transform
функции и preprocessData
вспомогательная функция, заданная в разделе Вспомогательные функции.
preprocessedSimData = transform(simData, @(simdata)preprocessData(simdata)); preprocessedRealData = transform(realData, @(realdata)preprocessData(realdata));
Используйте combine
функция для объединения преобразованного datastore изображений и хранилищ данных меток пикселей области симуляции. Процесс обучения не использует пиксельные метки реальных данных.
combinedSimData = combine(preprocessedSimData,simLabels);
Этот пример изменяет VGG-16 сеть, предварительно обученную в ImageNet, на полностью сверточную сеть. Для увеличения рецептивных полей добавляют расширенные сверточные слои с полосами 2 и 4. Это делает выход карты функций одной восьмой от размера входа. Atrous пространственное пирамидальное объединение (ASPP) используется для предоставления многомасштабной информации и сопровождается resize2dlayer
с коэффициентом повышающей дискретизации 8, чтобы изменить размер выхода на размер входа.
Сеть генератора AdaptSegNet, используемая в этом примере, проиллюстрирована на следующей схеме.
Чтобы получить предварительно обученную VGG-16 сеть, установите vgg16
. Если пакет поддержки не установлен, то программное обеспечение предоставляет ссылку для загрузки.
net = vgg16;
Чтобы сделать VGG-16 сеть подходящей для семантической сегментации, удалите все слои VGG после 'relu4_3'
.
vggLayers = net.Layers(2:24);
Создайте входной слой изображения размера 1280 на 720 на 3 для генератора.
inputSizeGenerator = [1280 720 3]; inputLayer = imageInputLayer(inputSizeGenerator,'Normalization','None','Name','inputLayer');
Создайте полностью сверточные слои сети. Используйте коэффициенты расширения 2 и 4, чтобы увеличить соответствующие поля.
fcnlayers = [ convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros') reluLayer('Name','relu5_1') convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2] ,'Name','conv5_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros') reluLayer('Name','relu5_2') convolution2dLayer([3 3], 360,'DilationFactor',[2 2],'Padding',[2 2 2 2],'Name','conv5_3','WeightsInitializer','narrow-normal','BiasInitializer','zeros') reluLayer('Name','relu5_3') convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4],'Name','conv6_1','WeightsInitializer','narrow-normal','BiasInitializer','zeros') reluLayer('Name','relu6_1') convolution2dLayer([3 3], 480,'DilationFactor',[4 4],'Padding',[4 4 4 4] ,'Name','conv6_2','WeightsInitializer','narrow-normal','BiasInitializer','zeros') reluLayer('Name','relu6_2') ];
Объедините слои и создайте график слоев.
layers = [ inputLayer vggLayers fcnlayers ]; lgraph = layerGraph(layers);
ASPP используется для предоставления многомасштабной информации. Добавьте модуль ASPP к графику слоев с размером фильтра, равным количеству каналов, при помощи addASPPToNetwork
вспомогательная функция, заданная в разделе Вспомогательные функции.
lgraph = addASPPToNetwork(lgraph, numClasses);
Применить resize2dLayer
с коэффициентом повышающей дискретизации 8, чтобы выход совпадал с размером входа.
upSampleLayer = resize2dLayer('Scale',8,'Method','bilinear','Name','resizeLayer'); lgraphGenerator = addLayers(lgraph,upSampleLayer); lgraphGenerator = connectLayers(lgraphGenerator,'additionLayer','resizeLayer');
Визуализируйте сеть генератора на графике.
plot(lgraphGenerator)
title("Generator")
Сеть дискриминатора состоит из пяти сверточных слоев с размером ядра 3 и полосой 2, где количество каналов {64, 128, 256, 512, 1}. Каждый слой сопровождается утечкой слоя ReLU, параметризованной масштабом 0,2, за исключением последнего слоя. resize2dLayer
используется для изменения размера выхода дискриминатора. Обратите внимание, что в этом примере не используется нормализация партии ., поскольку дискриминатор совместно обучается с сетью сегментации с использованием небольшого размера пакета.
Сеть дискриминатора AdaptSegNet в этом примере проиллюстрирована на следующей схеме.
Создайте слой входа изображений размера 1280 на 720 бай- numClasses
что принимает в сегментации предсказания симуляции и вещественных областей.
inputSizeDiscriminator = [1280 720 numClasses];
Создайте полностью сверточные слои и сгенерируйте график слоев дискриминатора.
% Factor for number of channels in convolution layer. numChannelsFactor = 64; % Scale factor to resize the output of the discriminator. resizeScale = 64; % Scalar multiplier for leaky ReLU layers. leakyReLUScale = 0.2; % Create the layers of the discriminator. layers = [ imageInputLayer(inputSizeDiscriminator,'Normalization','none','Name','inputLayer') convolution2dLayer(3,numChannelsFactor,'Stride',2,'Padding',1,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal') leakyReluLayer(leakyReLUScale,'Name','lrelu1') convolution2dLayer(3,numChannelsFactor*2,'Stride',2,'Padding',1,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal') leakyReluLayer(leakyReLUScale,'Name','lrelu2') convolution2dLayer(3,numChannelsFactor*4,'Stride',2,'Padding',1,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal') leakyReluLayer(leakyReLUScale,'Name','lrelu3') convolution2dLayer(3,numChannelsFactor*8,'Stride',2,'Padding',1,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal') leakyReluLayer(leakyReLUScale,'Name','lrelu4') convolution2dLayer(3,1,'Stride',2,'Padding',1,'Name','classifer','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal') resize2dLayer('Scale', resizeScale,'Method','bilinear','Name','resizeLayer'); ]; % Create the layer graph of the discriminator. lgraphDiscriminator = layerGraph(layers);
Визуализируйте сеть дискриминатора на графике.
plot(lgraphDiscriminator)
title("Discriminator")
Задайте эти опции обучения.
Установите общее количество итераций равным 5000
. Тем самым вы обучаете сеть около 10 эпох.
Установите скорость обучения для генератора равной 2.5e-4
.
Установите скорость обучения для дискриминатора равной 1e-4
.
Установите коэффициент регуляризации L2 равным 0.0005
.
Экспоненциально скорость обучения уменьшается на основе формулы . Это уменьшение помогает стабилизировать градиенты при более высоких итерациях. Установите степень 0.9
.
Установите вес состязательных потерь равным 0.001
.
Инициализируйте скорость градиента следующим [ ]
. Это значение используется SGDM, чтобы сохранить скорость градиентов.
Инициализируйте скользящее среднее значение для градиентов параметра следующим [ ]
. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов параметра.
Инициализируйте движущееся среднее значение градиентов квадратов параметров следующим [ ]
. Это значение используется инициализатором Адама, чтобы сохранить среднее значение градиентов квадратов параметра.
Установите размер мини-пакета равным 1
.
numIterations = 5000; learnRateGenBase = 2.5e-4; learnRateDisBase = 1e-4; l2Regularization = 0.0005; power = 0.9; lamdaAdv = 0.001; vel= []; averageGrad = []; averageSqGrad = []; miniBatchSize = 1;
Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Чтобы автоматически обнаружить, доступен ли вам графический процессор, установите executionEnvironment
на "auto"
. Если у вас нет графический процессор или вы не хотите использовать его для обучения, установите executionEnvironment
на "cpu"
. Чтобы гарантировать использование графический процессор для обучения, установите executionEnvironment
на "gpu"
. Для получения информации о поддерживаемых вычислительных возможностях смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
executionEnvironment = "auto";
Создайте minibatchqueue
объект из объединенного datastore области симуляции.
mbqTrainingDataSimulation = minibatchqueue(combinedSimData,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);
Создайте minibatchqueue
объект из входа входных изображений действительной области.
mbqTrainingDataReal = minibatchqueue(preprocessedRealData,"MiniBatchSize",miniBatchSize, ... "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);
Обучите модель с помощью пользовательского цикла обучения. Функция помощника modelGradients
, заданный в разделе Вспомогательные функции этого примера, вычислите градиенты и потери для генератора и дискриминатора. Создайте график процесса обучения с помощью configureTrainingLossPlotter
, приложенный к этому примеру как вспомогательный файл, и обновление процесса обучения с помощью updateTrainingPlots
. Закольцовывайте обучающие данные и обновляйте сетевые параметры при каждой итерации.
Для каждой итерации:
Чтение изображения и информации о метке из minibatchqueue
объект данных моделирования с помощью next
функция.
Считайте информацию об изображении из minibatchqueue
объект реальных данных с помощью next
функция.
Оцените градиенты модели с помощью dlfeval
и modelGradients
вспомогательная функция, заданная в разделе Вспомогательные функции. modelGradients
возвращает градиенты потерь относительно настраиваемых параметров.
Обновите параметры сети генератора с помощью sgdmupdate
функция.
Обновите параметры сети дискриминатора, используя adamupdate
функция.
Обновите график процесса обучения для каждой итерации и отобразите различные вычисленные потери.
if doTraining % Create the dlnetwork object of the generator. dlnetGenerator = dlnetwork(lgraphGenerator); % Create the dlnetwork object of the discriminator. dlnetDiscriminator = dlnetwork(lgraphDiscriminator); % Create the subplots for the generator and discriminator loss. fig = figure; [generatorLossPlotter, discriminatorLossPlotter] = configureTrainingLossPlotter(fig); % Loop through the data for the specified number of iterations. for iter = 1:numIterations % Reset the minibatchqueue of simulation data. if ~hasdata(mbqTrainingDataSimulation) reset(mbqTrainingDataSimulation); end % Retrieve the next mini-batch of simulation data and labels. [dlX,label] = next(mbqTrainingDataSimulation); % Reset the minibatchqueue of real data. if ~hasdata(mbqTrainingDataReal) reset(mbqTrainingDataReal); end % Retrieve the next mini-batch of real data. dlZ = next(mbqTrainingDataReal); % Evaluate the model gradients and loss using dlfeval and the modelGradients function. [gradientGenerator,gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = ... dlfeval(@modelGradients,dlnetGenerator,dlnetDiscriminator,dlX,dlZ,label,lamdaAdv); % Apply L2 regularization. gradientGenerator = dlupdate(@(g,w) g + l2Regularization*w, gradientGenerator, dlnetGenerator.Learnables); % Adjust the learning rate. learnRateGen = piecewiseLearningRate(iter,learnRateGenBase,numIterations,power); learnRateDis = piecewiseLearningRate(iter,learnRateDisBase,numIterations,power); % Update the generator network learnable parameters using the SGDM optimizer. [dlnetGenerator.Learnables, vel] = ... sgdmupdate(dlnetGenerator.Learnables,gradientGenerator,vel,learnRateGen); % Update the discriminator network learnable parameters using the Adam optimizer. [dlnetDiscriminator.Learnables, averageGrad, averageSqGrad] = ... adamupdate(dlnetDiscriminator.Learnables,gradientDiscriminator,averageGrad,averageSqGrad,iter,learnRateDis) ; % Update the training plot with loss values. updateTrainingPlots(generatorLossPlotter,discriminatorLossPlotter,iter, ... double(gather(extractdata(lossSegValue + lamdaAdv * lossAdvValue))),double(gather(extractdata(lossDisValue)))); end % Save the trained model. save('trainedAdaptSegGANNet.mat','dlnetGenerator'); end
Теперь дискриминатор может идентифицировать, является ли вход из симуляции или действительной области. В свою очередь, генератор теперь может генерировать предсказания сегментации, которые аналогичны в симуляции и реальных областях.
Оцените эффективность обученной сети AdaptSegNet, вычислив среднее значение IoU для предсказаний тестовых данных.
Загрузите тестовые данные с помощью imageDatastore
.
realTestData = imageDatastore(realTestImagesFolder);
Набор данных CamVid имеет 32 класса. Используйте realpixelLabelIDs
вспомогательная функция для уменьшения количества классов до пяти, как для набора данных моделирования. The realpixelLabelIDs
Функция helper присоединена к этому примеру как вспомогательный файл.
labelIDs = realPixelLabelIDs;
Использование pixelLabelDatastore
(Computer Vision Toolbox), чтобы загрузить изображения наземной метки истинности для тестовых данных.
realTestLabels = pixelLabelDatastore(realTestLabelsFolder,classes,labelIDs);
Переместите данные в нулевой центр, чтобы центрировать данные вокруг источника, как и для обучающих данных, при помощи transform
функции и preprocessData
вспомогательная функция, заданная в разделе Вспомогательные функции.
preprocessedRealTestData = transform(realTestData, @(realtestdata)preprocessData(realtestdata));
Использование combine
чтобы объединить преобразованный datastore изображений и хранилища данных меток пикселей реальных тестовых данных.
combinedRealTestData = combine(preprocessedRealTestData,realTestLabels);
Создайте minibatchqueue
объект из объединённого datastore тестовых данных .
Задайте "MiniBatchSize"
на 1
для простоты оценки метрик.
mbqimdsTest = minibatchqueue(combinedRealTestData,"MiniBatchSize",1,... "MiniBatchFormat","SSCB","OutputEnvironment",executionEnvironment);
Чтобы сгенерировать массив матричных ячеек неточностей, используйте функцию helper predictSegmentationLabelsOnTestSet
на minibatchqueue
объект тестовых данных. Функция помощника predictSegmentationLabelsOnTestSet
приведен ниже в разделе Вспомогательные функции.
imageSetConfusionMat = predictSegmentationLabelsOnTestSet(dlnetGenerator,mbqimdsTest);
Использование evaluateSemanticSegmentation
(Computer Vision Toolbox), чтобы измерить метрики семантической сегментации в матрице неточностей тестового набора.
metrics = evaluateSemanticSegmentation(imageSetConfusionMat,classes,'Verbose',false);
Чтобы увидеть метрики уровня набора данных, смотрите metrics.DataSetMetrics
.
metrics.DataSetMetrics
ans=1×4 table
GlobalAccuracy MeanAccuracy MeanIoU WeightedIoU
______________ ____________ _______ ___________
0.86883 0.769 0.64487 0.78026
Метрики набора данных обеспечивают высокоуровневый обзор эффективности сети. Чтобы увидеть влияние каждого класса на общую эффективность, смотрите метрики по классам с помощью metrics.ClassMetrics
.
metrics.ClassMetrics
ans=5×2 table
Accuracy IoU
________ _______
Road 0.9147 0.81301
Background 0.93418 0.85518
Pavement 0.33373 0.27105
Sky 0.82652 0.81109
Car 0.83586 0.47399
Эффективность набора данных хороша, но метрики классов показывают, что классы автомобиля и тротуара не сегментированы хорошо. Обучение сети с помощью дополнительных данных может привести к улучшенным результатам.
Запустите обученную сеть на одном тестовом изображении, чтобы проверить сегментированное выходное предсказание.
% Read the image from the test data. data = readimage(realTestData,350); % Perform the preprocessing step of zero shift on the image. processeddata = preprocessData(data); % Convert the data to dlarray. processeddata = dlarray(processeddata,'SSCB'); % Predict the output of the network. [genPrediction, ~] = forward(dlnetGenerator,processeddata); % Get the label, which is the index with the maximum value in the channel dimension. [~, labels] = max(genPrediction,[],3); % Overlay the predicted labels on the image. segmentedImage = labeloverlay(data,uint8(gather(extractdata(labels))),'Colormap',dmap);
Отображение результатов.
figure imshow(segmentedImage); labelColorbar(dmap,classes);
Сравните результаты метки с ожидаемыми основной истины, хранящимися в realTestLabels
. Зеленая и пурпурная области подсвечивают области, где результаты сегментации отличаются от ожидаемой основной истины.
expectedResult = readimage(realTestLabels,350); actual = uint8(gather(extractdata(labels))); expected = uint8(expectedResult); figure imshowpair(actual,expected)
Визуально результаты семантической сегментации хорошо перекрываются для классов "Дорога" , "Небо" и "Здания". Однако результаты плохо перекрываются для классов автомобиля и тротуара.
Функция помощника modelGradients
вычисляет градиенты и состязательные потери для генератора и дискриминатора. Функция также вычисляет потери сегментации для генератора и потери перекрестной энтропии для дискриминатора. Поскольку никакая информация о состоянии не требуется запоминать между итерациями для сетей генератора и дискриминатора, состояния не обновляются.
function [gradientGenerator, gradientDiscriminator, lossSegValue, lossAdvValue, lossDisValue] = modelGradients(dlnetGenerator, dlnetDiscriminator, dlX, dlZ, label, lamdaAdv) % Labels for adversarial training. simulationLabel = 0; realLabel = 1; % Extract the predictions of the simulation from the generator. [genPredictionSimulation, ~] = forward(dlnetGenerator,dlX); % Compute the generator loss. lossSegValue = segmentationLoss(genPredictionSimulation,label); % Extract the predictions of the real data from the generator. [genPredictionReal, ~] = forward(dlnetGenerator,dlZ); % Extract the softmax predictions of the real data from the discriminator. disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal)); % Create a matrix of simulation labels of real prediction size. Y = simulationLabel * ones(size(disPredictionReal)); % Compute the adversarial loss to make the real distribution close to the simulation label. lossAdvValue = mse(disPredictionReal,Y)/numel(Y(:)); % Compute the gradients of the generator with regard to loss. gradientGenerator = dlgradient(lossSegValue + lamdaAdv*lossAdvValue,dlnetGenerator.Learnables); % Extract the softmax predictions of the simulation from the discriminator. disPredictionSimulation = forward(dlnetDiscriminator,softmax(genPredictionSimulation)); % Create a matrix of simulation labels of simulation prediction size. Y = simulationLabel * ones(size(disPredictionSimulation)); % Compute the discriminator loss with regard to simulation class. lossDisValueSimulation = mse(disPredictionSimulation,Y)/numel(Y(:)); % Extract the softmax predictions of the real data from the discriminator. disPredictionReal = forward(dlnetDiscriminator,softmax(genPredictionReal)); % Create a matrix of real labels of real prediction size. Y = realLabel * ones(size(disPredictionReal)); % Compute the discriminator loss with regard to real class. lossDisValueReal = mse(disPredictionReal,Y)/numel(Y(:)); % Compute the total discriminator loss. lossDisValue = lossDisValueSimulation + lossDisValueReal; % Compute the gradients of the discriminator with regard to loss. gradientDiscriminator = dlgradient(lossDisValue,dlnetDiscriminator.Learnables); end
Функция помощника segmentationLoss
вычисляет потери сегментации функции, которые заданы как потери перекрестной энтропии для генератора, используя данные моделирования и его соответствующую основную истину. Функция helper вычисляет потери при помощи crossentropy
функция.
function loss = segmentationLoss(predict, target) % Generate the one-hot encodings of the ground truth. oneHotTarget = onehotencode(categorical(extractdata(target)),4); % Convert the one-hot encoded data to dlarray. oneHotTarget = dlarray(oneHotTarget,'SSBC'); % Compute the softmax output of the predictions. predictSoftmax = softmax(predict); % Compute the cross-entropy loss. loss = crossentropy(predictSoftmax,oneHotTarget,'TargetCategories','exclusive')/(numel(oneHotTarget)/2); end
Функция помощника downloadDataset
загружает как симуляцию, так и реальные наборы данных с указанных URL-адресов в указанные папки, если они не существуют. Функция возвращает пути симуляции, реальных обучающих данных и данных о проверке. Функция загружает весь набор данных CamVid и разбивает данные на обучающие и тестовые наборы с помощью subsetCamVidDatasetFileNames
файл mat, присоединенный к примеру как вспомогательный файл.
function [simulationImagesFolder, simulationLabelsFolder, realImagesFolder, realLabelsFolder,... realTestImagesFolder, realTestLabelsFolder] = ... downloadDataset(simulationDataLocation, simulationDataURL, realDataLocation, realImageDataURL, realLabelDataURL) % Build the training image and label folder location for simulation data. simulationDataZip = fullfile(simulationDataLocation,'SimulationDrivingDataset.zip'); % Get the simulation data if it does not exist. if ~exist(simulationDataZip,'file') mkdir(simulationDataLocation) disp('Downloading the simulation data'); websave(simulationDataZip,simulationDataURL); unzip(simulationDataZip,simulationDataLocation); end simulationImagesFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','images'); simulationLabelsFolder = fullfile(simulationDataLocation,'SimulationDrivingDataset','labels'); camVidLabelsZip = fullfile(realDataLocation,'CamVidLabels.zip'); camVidImagesZip = fullfile(realDataLocation,'CamVidImages.zip'); if ~exist(camVidLabelsZip,'file') || ~exist(camVidImagesZip,'file') mkdir(realDataLocation) disp('Downloading 16 MB CamVid dataset labels...'); websave(camVidLabelsZip, realLabelDataURL); unzip(camVidLabelsZip, fullfile(realDataLocation,'CamVidLabels')); disp('Downloading 587 MB CamVid dataset images...'); websave(camVidImagesZip, realImageDataURL); unzip(camVidImagesZip, fullfile(realDataLocation,'CamVidImages')); end % Build the training image and label folder location for real data. realImagesFolder = fullfile(realDataLocation,'train','images'); realLabelsFolder = fullfile(realDataLocation,'train','labels'); % Build the testing image and label folder location for real data. realTestImagesFolder = fullfile(realDataLocation,'test','images'); realTestLabelsFolder = fullfile(realDataLocation,'test','labels'); % Partition the data into training and test sets if they do not exist. if ~exist(realImagesFolder,'file') || ~exist(realLabelsFolder,'file') || ... ~exist(realTestImagesFolder,'file') || ~exist(realTestLabelsFolder,'file') mkdir(realImagesFolder); mkdir(realLabelsFolder); mkdir(realTestImagesFolder); mkdir(realTestLabelsFolder); % Load the mat file that has the names for testing and training. partitionNames = load('subsetCamVidDatasetFileNames.mat'); % Extract the test images names. imageTestNames = partitionNames.imageTestNames; % Remove the empty cells. imageTestNames = imageTestNames(~cellfun('isempty',imageTestNames)); % Extract the test labels names. labelTestNames = partitionNames.labelTestNames; % Remove the empty cells. labelTestNames = labelTestNames(~cellfun('isempty',labelTestNames)); % Copy the test images to the respective folder. for i = 1:size(imageTestNames,1) labelSource = fullfile(realDataLocation,'CamVidLabels',labelTestNames(i)); imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTestNames(i)); copyfile(imageSource{1}, realTestImagesFolder); copyfile(labelSource{1}, realTestLabelsFolder); end % Extract the train images names. imageTrainNames = partitionNames.imageTrainNames; % Remove the empty cells. imageTrainNames = imageTrainNames(~cellfun('isempty',imageTrainNames)); % Extract the train labels names. labelTrainNames = partitionNames.labelTrainNames; % Remove the empty cells. labelTrainNames = labelTrainNames(~cellfun('isempty',labelTrainNames)); % Copy the train images to the respective folder. for i = 1:size(imageTrainNames,1) labelSource = fullfile(realDataLocation,'CamVidLabels',labelTrainNames(i)); imageSource = fullfile(realDataLocation,'CamVidImages','701_StillsRaw_full',imageTrainNames(i)); copyfile(imageSource{1},realImagesFolder); copyfile(labelSource{1},realLabelsFolder); end end end
Функция помощника addASPPToNetwork
создает слои atrous spatial pyramid uling (ASPP) и добавляет их к входу графика слоев. Функция возвращает график слоев с соединенными с ним слоями ASPP.
function lgraph = addASPPToNetwork(lgraph, numClasses) % Define the ASPP dilation factors. asppDilationFactors = [6,12]; % Define the ASPP filter sizes. asppFilterSizes = [3,3]; % Extract the last layer of the layer graph. lastLayerName = lgraph.Layers(end).Name; % Define the addition layer. addLayer = additionLayer(numel(asppDilationFactors),'Name','additionLayer'); % Add the addition layer to the layer graph. lgraph = addLayers(lgraph,addLayer); % Create the ASPP layers connected to the addition layer % and connect the layer graph. for i = 1: numel(asppDilationFactors) asppConvName = "asppConv_" + string(i); branchFilterSize = asppFilterSizes(i); branchDilationFactor = asppDilationFactors(i); asspLayer = convolution2dLayer(branchFilterSize, numClasses,'DilationFactor', branchDilationFactor,... 'Padding','same','Name',asppConvName,'WeightsInitializer','narrow-normal','BiasInitializer','zeros'); lgraph = addLayers(lgraph,asspLayer); lgraph = connectLayers(lgraph,lastLayerName,asppConvName); lgraph = connectLayers(lgraph,asppConvName,strcat(addLayer.Name,'/',addLayer.InputNames{i})); end end
Функция помощника predictSegmentationLabelsOnTestSet
вычисляет матрицу неточностей предсказанных и основную истину меток с помощью segmentationConfusionMatrix
(Computer Vision Toolbox) функция.
function confusionMatrix = predictSegmentationLabelsOnTestSet(net, minbatchTestData) confusionMatrix = {}; i = 1; while hasdata(minbatchTestData) % Use next to retrieve a mini-batch from the datastore. [dlX, gtlabels] = next(minbatchTestData); % Predict the output of the network. [genPrediction, ~] = forward(net,dlX); % Get the label, which is the index with maximum value in the channel dimension. [~, labels] = max(genPrediction,[],3); % Get the confusion matrix of each image. confusionMatrix{i} = segmentationConfusionMatrix(double(gather(extractdata(labels))),double(gather(extractdata(gtlabels)))); i = i+1; end confusionMatrix = confusionMatrix'; end
Функция помощника piecewiseLearningRate
вычисляет текущую скорость обучения на основе числа итераций.
function lr = piecewiseLearningRate(i, baseLR, numIterations, power) fraction = i/numIterations; factor = (1 - fraction)^power * 1e1; lr = baseLR * factor; end
Функция помощника preprocessData
выполняет сдвиг нулевого центра путем вычитания количества каналов изображения соответствующим средним значением.
function data = preprocessData(data) % Extract respective channels. rc = data(:,:,1); gc = data(:,:,2); bc = data(:,:,3); % Compute the respective channel means. r = mean(rc(:)); g = mean(gc(:)); b = mean(bc(:)); % Shift the data by the mean of respective channel. data = single(data) - single(shiftdim([r g b],-1)); end
[1] Цай, И-Хсуань, Вэй-Чих Хунг, Самуэль Шультер, Кихюк Сон, Мин-Хсуань Ян и Манмохан Чандракер. «Обучение адаптации структурированного выходного пространства для семантической сегментации». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 7472-81. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00780.
[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. Semantic Object Classes in Video: A High-Definition Ground Truth Database (неопр.) (недоступная ссылка). Pattern Recognition Letters 30, № 2 (январь 2009): 88-97. https://doi.org/10.1016/j.patrec.2008.04.005.