Этот пример показывает, как анализировать эстетическое качество изображений с помощью сверточной нейронной сети Neural Image Assessment (NIMA) (CNN).
Метрики качества изображений обеспечивают объективную меру качества изображений. Эффективная метрика обеспечивает количественные счета, которые хорошо коррелируют с субъективным восприятием качества человеческим наблюдателем. Метрики качества позволяют сравнить алгоритмы обработки изображений.
NIMA [1] является методом без ссылки, который предсказывает качество изображения, не полагаясь на первозданное эталонное изображение, которое часто недоступно. NIMA использует CNN, чтобы предсказать распределение счетов качества для каждого изображения.
Загрузите предварительно обученную нейронную сеть NIMA с помощью функции helper downloadTrainedNIMANet
. Функция helper присоединена к примеру как вспомогательный файл. Эта модель предсказывает распределение счетов качества для каждого изображения в области значений [1, 10], где 1 и 10 являются самыми низкими и самыми высокими возможными значениями для счета, соответственно. Высокий счет указывает на хорошее качество изображения.
imageDir = fullfile(tempdir,"LIVEInTheWild"); if ~exist(imageDir,'dir') mkdir(imageDir); end trainedNIMA_url = 'https://ssd.mathworks.com/supportfiles/image/data/trainedNIMA.zip'; downloadTrainedNIMANet(trainedNIMA_url,imageDir); load(fullfile(imageDir,'trainedNIMA.mat'));
Можно оценить эффективность модели NIMA, сравнив предсказанные счета для качественного и более низкого качества изображения.
Прочтите качественное изображение в рабочую область.
imOriginal = imread('kobi.png');
Уменьшите эстетическое качество изображения путем применения Гауссова размытия. Отобразите оригинальное изображение и размытое изображение в монтаже. Субъективно эстетическое качество размытого изображения хуже, чем качество оригинального изображения.
imBlur = imgaussfilt(imOriginal,5); montage({imOriginal,imBlur})
Спрогнозируйте распределение счетов качества NIMA для двух изображений, используя predictNIMAScore
вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл.
The predictNIMAScore
функция возвращает среднее и стандартное отклонение распределения счетов NIMA для изображения. Прогнозируемый средний счет является мерой качества изображения. Стандартное отклонение счетов может быть рассмотрено как мера доверительного уровня предсказанного среднего счета.
[meanOriginal,stdOriginal] = predictNIMAScore(dlnet,imOriginal); [meanBlur,stdBlur] = predictNIMAScore(dlnet,imBlur);
Отображение изображений вместе со средним и стандартным отклонением распределений счетов, предсказанных моделью NIMA. The
Модель NIMA правильно предсказывает счета для этих изображений, которые согласуются с субъективной визуальной оценкой.
figure t = tiledlayout(1,2); displayImageAndScoresForNIMA(t,imOriginal,meanOriginal,stdOriginal,"Original Image") displayImageAndScoresForNIMA(t,imBlur,meanBlur,stdBlur,"Blurred Image")
Остальная часть этого примера показывает, как обучить и оценить модель NIMA.
Этот пример использует набор данных LIVE In the Wild [2], который является базой данных субъективного качества изображений в открытом доступе. Набор данных содержит 1162 фотографии, захваченные мобильными устройствами, с 7 дополнительными изображениями, предоставленными для обучения людей-бомбардиров. Каждое изображение оценивается в среднем 175 индивидуумы по шкале [1, 100]. Набор данных обеспечивает среднее и стандартное отклонение субъективных счетов для каждого изображения.
Загрузите набор данных, следуя инструкциям, описанным в LIVE In the Wild Image Quality Challenge Database. Извлеките данные в директорию, заданную imageDir
переменная. Когда экстракция успешна, imageDir
содержит две директории: Data
и Images
.
Получите файл пути к изображениям.
imageData = load(fullfile(imageDir,'Data','AllImages_release.mat')); imageData = imageData.AllImages_release; nImg = length(imageData); imageList(1:7) = fullfile(imageDir,'Images','trainingImages',imageData(1:7)); imageList(8:nImg) = fullfile(imageDir,'Images',imageData(8:end));
Создайте datastore, которое управляет данными изображений.
imds = imageDatastore(imageList);
Загрузите средние и стандартные данные об отклонениях, соответствующие изображениям.
meanData = load(fullfile(imageDir,'Data','AllMOS_release.mat')); meanData = meanData.AllMOS_release; stdData = load(fullfile(imageDir,'Data','AllStdDev_release.mat')); stdData = stdData.AllStdDev_release;
При необходимости отобразите несколько выборочных изображений из набора данных с соответствующими средним и стандартными значениями отклонения.
figure t = tiledlayout(1,3); idx1 = 785; displayImageAndScoresForNIMA(t,readimage(imds,idx1), ... meanData(idx1),stdData(idx1),"Image "+imageData(idx1)) idx2 = 203; displayImageAndScoresForNIMA(t,readimage(imds,idx2), ... meanData(idx2),stdData(idx2),"Image "+imageData(idx2)) idx3 = 777; displayImageAndScoresForNIMA(t,readimage(imds,idx3), ... meanData(idx3),stdData(idx3),"Image "+imageData(idx3))
Предварительно обработайте изображения, изменив размеры их к 256 на 256 пикселей.
rescaleSize = [256 256]; imds = transform(imds,@(x)imresize(x,rescaleSize));
Модель NIMA требует распределения человеческих счетов, но набор данных LIVE предоставляет только среднее и стандартное отклонение распределения. Приблизите базовое распределение для каждого изображения в наборе данных LIVE с помощью createNIMAScoreDistribution
вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл.
The createNIMAScoreDistribution
пересчитывает счета в область значений [1, 10], затем генерирует максимальное энтропийное распределение счетов из среднего и стандартного значений отклонения.
newMaxScore = 10; prob = createNIMAScoreDistribution(meanData,stdData); cumProb = cumsum(prob,2);
Создайте arrayDatastore
который управляет распределениями счетов.
probDS = arrayDatastore(cumProb','IterationDimension',2);
Объедините хранилища данных, содержащие данные изображения и данные распределения счетов.
dsCombined = combine(imds,probDS);
Предварительный просмотр выхода чтения из комбинированного datastore.
sampleRead = preview(dsCombined)
sampleRead=1×2 cell array
{256×256×3 uint8} {10×1 double}
figure tiledlayout(1,2) nexttile imshow(sampleRead{1}) title("Sample Image from Data Set") nexttile plot(sampleRead{2}) title("Cumulative Score Distribution")
Разделите данные на наборы для обучения, валидации и тестирования. Выделите 70% данных для обучения, 15% для валидации и оставшуюся часть для проверки.
numTrain = floor(0.70 * nImg); numVal = floor(0.15 * nImg); Idx = randperm(nImg); idxTrain = Idx(1:numTrain); idxVal = Idx(numTrain+1:numTrain+numVal); idxTest = Idx(numTrain+numVal+1:nImg); dsTrain = subset(dsCombined,idxTrain); dsVal = subset(dsCombined,idxVal); dsTest = subset(dsCombined,idxTest);
Увеличение обучающих данных с помощью augmentImageTest
вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл. The augmentDataForNIMA
функция выполняет эти операции увеличения для каждого обучающего изображения:
Обрезать изображение до 224 на 244 пикселей, чтобы уменьшить сверхподбор кривой.
Разверните изображение горизонтально с вероятностью 50%.
inputSize = [224 224]; dsTrain = transform(dsTrain,@(x)augmentDataForNIMA(x,inputSize));
Слой входа сети выполняет z-балльную нормализацию обучающих изображений. Вычислите среднее и стандартное отклонение обучающих изображений для использования в нормализации z-балла.
meanImage = zeros([inputSize 3]); meanImageSq = zeros([inputSize 3]); while hasdata(dsTrain) dat = read(dsTrain); img = double(dat{1}); meanImage = meanImage + img; meanImageSq = meanImageSq + img.^2; end meanImage = meanImage/numTrain; meanImageSq = meanImageSq/numTrain; varImage = meanImageSq - meanImage.^2; stdImage = sqrt(varImage);
Установите datastore в начальное состояние.
reset(dsTrain);
Этот пример начинается с MobileNet-v2 [3] CNN, обученного на ImageNet [4]. Пример модифицирует сеть путем замены последнего слоя MobileNet-v2 сети полносвязным слоем с 10 нейронами, каждый из которых представляет дискретный счет от 1 до 10. Сеть предсказывает вероятность каждого счета для каждого изображения. Пример нормализует выходы полностью соединенного слоя с помощью слоя активации softmax.
The mobilenetv2
функция возвращает предварительно обученную MobileNet-v2 сеть. Эта функция требует пакета Deep Learning Toolbox™ Model for MobileNet-v2 Network поддержки. Если этот пакет поддержки не установлен, то функция предоставляет ссылку на загрузку.
net = mobilenetv2;
Преобразуйте сеть в layerGraph
объект.
lgraph = layerGraph(net);
Сеть имеет размер входа изображения 224 на 224 пикселя. Замените слой входа слоем входа изображений, который выполняет нормализацию z-балла по данным изображения с помощью среднего и стандартного отклонения обучающих изображений.
inLayer = imageInputLayer([inputSize 3],'Name','input','Normalization','zscore','Mean',meanImage,'StandardDeviation',stdImage); lgraph = replaceLayer(lgraph,'input_1',inLayer);
Замените исходный конечный классификационный слой полносвязным слоем с 10 нейронами. Добавьте слой softmax, чтобы нормализовать выходы. Установите скорость обучения для полносвязного слоя в 10 раз выше скорости обучения для базовых слоев CNN. Нанесите выпадение 75%.
lgraph = removeLayers(lgraph,{'ClassificationLayer_Logits','Logits_softmax','Logits'}); newFinalLayers = [ dropoutLayer(0.75,'Name','drop') fullyConnectedLayer(newMaxScore,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10) softmaxLayer('Name','prob')]; lgraph = addLayers(lgraph,newFinalLayers); lgraph = connectLayers(lgraph,'global_average_pooling2d_1','drop'); dlnet = dlnetwork(lgraph);
Визуализируйте сеть с помощью приложения Deep Network Designer.
deepNetworkDesigner(lgraph)
The modelGradients
Функция helper вычисляет градиенты и потери для каждой итерации настройки сети. Эта функция определяется в разделе Вспомогательные функции этого примера.
Цель сети NIMA состоит в том, чтобы минимизировать расстояние земного движителя (EMD) между основной истиной и предсказанным распределениями счетом. Потеря EMD учитывает расстояние между классами при наказании за неправильную классификацию. Поэтому потеря EMD работает лучше, чем типичная потеря перекрестной энтропии softmax, используемая в задачах классификации [5]. В этом примере вычисляются потери EMD с помощью earthMoverDistance
вспомогательная функция, которая задана в разделе «Вспомогательные функции» этого примера.
Для функции потерь EMD используйте расстояние r-нормы с r = 2. Это расстояние позволяет легко оптимизировать, когда вы работаете с градиентным спуском.
Задайте опции для оптимизации SGDM. Обучите сеть на 150 эпох с мини-пакетом размером 128.
numEpochs = 150; miniBatchSize = 128; momentum = 0.9; initialLearnRate = 3e-3; decay = 0.95;
Создайте minibatchqueue
объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue
объект также переводит данные в dlarray
объект, который позволяет проводить автоматическую дифференциацию в применениях глубокого обучения.
Задайте мини-формат извлечения данных пакета как 'SSCB'
(пространственный, пространственный, канальный, пакетный). Установите значение 'DispatchInBackground'
аргумент имя-значение логического элемента, возвращенный canUseGPU
. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue
объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.
mbqTrain = minibatchqueue(dsTrain,'MiniBatchSize',miniBatchSize, ... 'PartialMiniBatch','discard','MiniBatchFormat',{'SSCB',''}, ... 'DispatchInBackground',canUseGPU); mbqVal = minibatchqueue(dsVal,'MiniBatchSize',miniBatchSize, ... 'MiniBatchFormat',{'SSCB',''},'DispatchInBackground',canUseGPU);
По умолчанию пример загружает предварительно обученную версию сети NIMA. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде, для true
. Обучите модель в пользовательском цикле обучения. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
функция.
Оцените градиенты модели с помощью dlfeval
функции и modelGradients
вспомогательная функция.
Обновляйте параметры сети с помощью sgdmupdate
функция.
Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
doTraining = false; if doTraining iteration = 0; velocity = []; start = tic; [hFig,lineLossTrain,lineLossVal] = initializeTrainingPlotNIMA; for epoch = 1:numEpochs shuffle (mbqTrain); learnRate = initialLearnRate/(1+decay*floor(epoch/10)); while hasdata(mbqTrain) iteration = iteration + 1; [dlX,cdfY] = next(mbqTrain); [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,cdfY); [dlnet,velocity] = sgdmupdate(dlnet,grad,velocity,learnRate,momentum); updateTrainingPlotNIMA(lineLossTrain,loss,epoch,iteration,start) end % Add validation data to plot [~,lossVal,~] = modelPredictions(dlnet,mbqVal); updateTrainingPlotNIMA(lineLossVal,lossVal,epoch,iteration,start) end % Save the trained network modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(strcat("trainedNIMA-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'dlnet'); else load(fullfile(imageDir,'trainedNIMA.mat')); end
Оцените эффективность модели на наборе тестовых данных с помощью трех метрик: EMD, бинарной точности классификации и коэффициентов корреляции. Производительность сети NIMA на наборе тестовых данных согласуется с эффективностью эталонной модели NIMA, сообщенной Талеби и Миланфаром [1].
Создайте minibatchqueue
объект, который управляет мини-пакетом тестовых данных.
mbqTest = minibatchqueue(dsTest,'MiniBatchSize',miniBatchSize,'MiniBatchFormat',{'SSCB',''});
Вычислите предсказанные вероятности и основную истину совокупные вероятности мини-пакетов тестовых данных с помощью modelPredictions
функция. Эта функция определяется в разделе Вспомогательные функции этого примера.
[YPredTest,~,cdfYTest] = modelPredictions(dlnet,mbqTest);
Вычислите среднее и стандартные значения отклонений основной истины и предсказанных распределений.
meanPred = extractdata(YPredTest)' * (1:10)'; stdPred = sqrt(extractdata(YPredTest)'*((1:10).^2)' - meanPred.^2); origCdf = extractdata(cdfYTest); origPdf = [origCdf(1,:); diff(origCdf)]; meanOrig = origPdf' * (1:10)'; stdOrig = sqrt(origPdf'*((1:10).^2)' - meanOrig.^2);
Вычислите EMD основной истины и предсказанных распределений счета. Для предсказания используйте расстояние r-нормы с r = 1. Значение EMD указывает на близость предсказанного и основная истина рейтинговых распределений.
EMDTest = earthMoverDistance(YPredTest,cdfYTest,1)
EMDTest = 1×1 single gpuArray dlarray 0.1158
Для бинарной точности классификации преобразуйте распределения в две классификации: качественную и низкокачественную. Классифицируйте изображения со средним счетом, большим порога, как высококачественные.
qualityThreshold = 5; binaryPred = meanPred > qualityThreshold; binaryOrig = meanOrig > qualityThreshold;
Вычислите точность двоичной классификации.
binaryAccuracy = 100 * sum(binaryPred==binaryOrig)/length(binaryPred)
binaryAccuracy = 84.6591
Большие значения корреляции указывают на большую положительную корреляцию между основной истиной и предсказанными счетами. Вычислите коэффициент линейной корреляции (LCC) и коэффициент ранговой корреляции Спирмана (SRCC) для средних счетов.
meanLCC = corr(meanOrig,meanPred)
meanLCC = gpuArray single 0.7265
meanSRCC = corr(meanOrig,meanPred,'type','Spearman')
meanSRCC = gpuArray single 0.6451
The modelGradients
функция принимает как вход dlnetwork
dlnet объекта
и мини-пакет входных данных dlX
с соответствующими целевыми совокупными вероятностями cdfY
. Функция возвращает градиенты потерь относительно настраиваемых параметров в dlnet
а также потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,cdfY) dlYPred = forward(dlnet,dlX); loss = earthMoverDistance(dlYPred,cdfY,2); gradients = dlgradient(loss,dlnet.Learnables); end
The earthMoverDistance
функция вычисляет EMD между основной истиной и предсказанным распределениями для заданного значения r-нормы. The earthMoverDistance
использует computeCDF
вспомогательная функция для вычисления совокупных вероятностей предсказанного распределения.
function loss = earthMoverDistance(YPred,cdfY,r) N = size(cdfY,1); cdfYPred = computeCDF(YPred); cdfDiff = (1/N) * (abs(cdfY - cdfYPred).^r); lossArray = sum(cdfDiff,1).^(1/r); loss = mean(lossArray); end function cdfY = computeCDF(Y) % Given a probability mass function Y, compute the cumulative probabilities [N,miniBatchSize] = size(Y); L = repmat(triu(ones(N)),1,1,miniBatchSize); L3d = permute(L,[1 3 2]); prod = Y.*L3d; prodSum = sum(prod,1); cdfY = reshape(prodSum(:)',miniBatchSize,N)'; end
The modelPredictions
функция вычисляет предполагаемые вероятности, потери и основную истину совокупные вероятности мини-пакетов данных.
function [dlYPred,loss,cdfYOrig] = modelPredictions(dlnet,mbq) reset(mbq); loss = 0; numObservations = 0; dlYPred = []; cdfYOrig = []; while hasdata(mbq) [dlX,cdfY] = next(mbq); miniBatchSize = size(dlX,4); dlY = predict(dlnet,dlX); loss = loss + earthMoverDistance(dlY,cdfY,2)*miniBatchSize; dlYPred = [dlYPred dlY]; cdfYOrig = [cdfYOrig cdfY]; numObservations = numObservations + miniBatchSize; end loss = loss / numObservations; end
[1] Талеби, Хоссейн и Пейман Миланфар. NIMA: оценка нейронного изображения. Транзакции IEEE по обработке изображений 27, № 8 (август 2018): 3998-4011. https://doi.org/10.1109/TIP.2018.2831899.
[2] LIVE: Лаборатория инженерии изображений и видео. «LIVE In The Wild Image Quality Challenge Database». https://live.ece.utexas.edu/research/ChallengeDB/index.html.
[3] Сэндлер, Марк, Эндрю Говард, Менглун Чжу, Андрей Жмогинов и Лян-Чих Чен. «MobileNetV2: перевёрнутые невязки и линейные узкие места». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 4510-20. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00474.
[4] ImageNet. http://www.image-net.org.
[5] Hou, Le, Chen-Ping Yu и Димитрис Самарас. Squared Earth Mover's Distance-Based Loss for Training Deep Neural Networks (неопр.) (недоступная ссылка). Препринт, представленный 30 ноября 2016 года. https://arxiv.org/abs/1611.05916.
dlfeval
| dlnetwork
| layerGraph
| minibatchqueue
| mobilenetv2
| predict
| sgdmupdate
| transform