Количественная оценка качества изображения с помощью оценки нейронного изображения

Этот пример показывает, как анализировать эстетическое качество изображений с помощью сверточной нейронной сети Neural Image Assessment (NIMA) (CNN).

Метрики качества изображений обеспечивают объективную меру качества изображений. Эффективная метрика обеспечивает количественные счета, которые хорошо коррелируют с субъективным восприятием качества человеческим наблюдателем. Метрики качества позволяют сравнить алгоритмы обработки изображений.

NIMA [1] является методом без ссылки, который предсказывает качество изображения, не полагаясь на первозданное эталонное изображение, которое часто недоступно. NIMA использует CNN, чтобы предсказать распределение счетов качества для каждого изображения.

Оцените качество изображения с помощью обученной модели NIMA

Загрузите предварительно обученную нейронную сеть NIMA с помощью функции helper downloadTrainedNIMANet. Функция helper присоединена к примеру как вспомогательный файл. Эта модель предсказывает распределение счетов качества для каждого изображения в области значений [1, 10], где 1 и 10 являются самыми низкими и самыми высокими возможными значениями для счета, соответственно. Высокий счет указывает на хорошее качество изображения.

imageDir = fullfile(tempdir,"LIVEInTheWild");
if ~exist(imageDir,'dir')
    mkdir(imageDir);
end
trainedNIMA_url = 'https://ssd.mathworks.com/supportfiles/image/data/trainedNIMA.zip';
downloadTrainedNIMANet(trainedNIMA_url,imageDir);
load(fullfile(imageDir,'trainedNIMA.mat'));

Можно оценить эффективность модели NIMA, сравнив предсказанные счета для качественного и более низкого качества изображения.

Прочтите качественное изображение в рабочую область.

imOriginal = imread('kobi.png'); 

Уменьшите эстетическое качество изображения путем применения Гауссова размытия. Отобразите оригинальное изображение и размытое изображение в монтаже. Субъективно эстетическое качество размытого изображения хуже, чем качество оригинального изображения.

imBlur = imgaussfilt(imOriginal,5);
montage({imOriginal,imBlur})

Спрогнозируйте распределение счетов качества NIMA для двух изображений, используя predictNIMAScore вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл.

The predictNIMAScore функция возвращает среднее и стандартное отклонение распределения счетов NIMA для изображения. Прогнозируемый средний счет является мерой качества изображения. Стандартное отклонение счетов может быть рассмотрено как мера доверительного уровня предсказанного среднего счета.

[meanOriginal,stdOriginal] = predictNIMAScore(dlnet,imOriginal);
[meanBlur,stdBlur] = predictNIMAScore(dlnet,imBlur);

Отображение изображений вместе со средним и стандартным отклонением распределений счетов, предсказанных моделью NIMA. The Модель NIMA правильно предсказывает счета для этих изображений, которые согласуются с субъективной визуальной оценкой.

figure
t = tiledlayout(1,2);
displayImageAndScoresForNIMA(t,imOriginal,meanOriginal,stdOriginal,"Original Image")
displayImageAndScoresForNIMA(t,imBlur,meanBlur,stdBlur,"Blurred Image")

Остальная часть этого примера показывает, как обучить и оценить модель NIMA.

Загрузить LIVE в диком наборе данных

Этот пример использует набор данных LIVE In the Wild [2], который является базой данных субъективного качества изображений в открытом доступе. Набор данных содержит 1162 фотографии, захваченные мобильными устройствами, с 7 дополнительными изображениями, предоставленными для обучения людей-бомбардиров. Каждое изображение оценивается в среднем 175 индивидуумы по шкале [1, 100]. Набор данных обеспечивает среднее и стандартное отклонение субъективных счетов для каждого изображения.

Загрузите набор данных, следуя инструкциям, описанным в LIVE In the Wild Image Quality Challenge Database. Извлеките данные в директорию, заданную imageDir переменная. Когда экстракция успешна, imageDir содержит две директории: Data и Images.

Загрузка LIVE в дикие данные

Получите файл пути к изображениям.

imageData = load(fullfile(imageDir,'Data','AllImages_release.mat'));
imageData = imageData.AllImages_release;
nImg = length(imageData);
imageList(1:7) = fullfile(imageDir,'Images','trainingImages',imageData(1:7));
imageList(8:nImg) = fullfile(imageDir,'Images',imageData(8:end));

Создайте datastore, которое управляет данными изображений.

imds = imageDatastore(imageList);

Загрузите средние и стандартные данные об отклонениях, соответствующие изображениям.

meanData = load(fullfile(imageDir,'Data','AllMOS_release.mat'));
meanData = meanData.AllMOS_release;
stdData = load(fullfile(imageDir,'Data','AllStdDev_release.mat'));
stdData = stdData.AllStdDev_release;

При необходимости отобразите несколько выборочных изображений из набора данных с соответствующими средним и стандартными значениями отклонения.

figure
t = tiledlayout(1,3);
idx1 = 785;
displayImageAndScoresForNIMA(t,readimage(imds,idx1), ...
    meanData(idx1),stdData(idx1),"Image "+imageData(idx1))
idx2 = 203;
displayImageAndScoresForNIMA(t,readimage(imds,idx2), ...
    meanData(idx2),stdData(idx2),"Image "+imageData(idx2))
idx3 = 777;
displayImageAndScoresForNIMA(t,readimage(imds,idx3), ...
    meanData(idx3),stdData(idx3),"Image "+imageData(idx3))

Предварительная обработка и увеличение данных

Предварительно обработайте изображения, изменив размеры их к 256 на 256 пикселей.

rescaleSize = [256 256];
imds = transform(imds,@(x)imresize(x,rescaleSize));

Модель NIMA требует распределения человеческих счетов, но набор данных LIVE предоставляет только среднее и стандартное отклонение распределения. Приблизите базовое распределение для каждого изображения в наборе данных LIVE с помощью createNIMAScoreDistribution вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл.

The createNIMAScoreDistribution пересчитывает счета в область значений [1, 10], затем генерирует максимальное энтропийное распределение счетов из среднего и стандартного значений отклонения.

newMaxScore = 10;
prob = createNIMAScoreDistribution(meanData,stdData);
cumProb = cumsum(prob,2);

Создайте arrayDatastore который управляет распределениями счетов.

probDS = arrayDatastore(cumProb','IterationDimension',2); 

Объедините хранилища данных, содержащие данные изображения и данные распределения счетов.

dsCombined = combine(imds,probDS);

Предварительный просмотр выхода чтения из комбинированного datastore.

sampleRead = preview(dsCombined)
sampleRead=1×2 cell array
    {256×256×3 uint8}    {10×1 double}

figure
tiledlayout(1,2)
nexttile
imshow(sampleRead{1})
title("Sample Image from Data Set")
nexttile
plot(sampleRead{2})
title("Cumulative Score Distribution")

Разделение данных для обучения, валидации и проверки

Разделите данные на наборы для обучения, валидации и тестирования. Выделите 70% данных для обучения, 15% для валидации и оставшуюся часть для проверки.

numTrain = floor(0.70 * nImg);
numVal = floor(0.15 * nImg);

Idx = randperm(nImg);
idxTrain = Idx(1:numTrain);
idxVal = Idx(numTrain+1:numTrain+numVal);
idxTest = Idx(numTrain+numVal+1:nImg);

dsTrain = subset(dsCombined,idxTrain);
dsVal = subset(dsCombined,idxVal);
dsTest = subset(dsCombined,idxTest);

Увеличение обучающих данных

Увеличение обучающих данных с помощью augmentImageTest вспомогательная функция. Эта функция присоединена к примеру как вспомогательный файл. The augmentDataForNIMA функция выполняет эти операции увеличения для каждого обучающего изображения:

  • Обрезать изображение до 224 на 244 пикселей, чтобы уменьшить сверхподбор кривой.

  • Разверните изображение горизонтально с вероятностью 50%.

inputSize = [224 224];
dsTrain = transform(dsTrain,@(x)augmentDataForNIMA(x,inputSize));

Вычислите Набор обучающих данных статистику для нормализации Входа

Слой входа сети выполняет z-балльную нормализацию обучающих изображений. Вычислите среднее и стандартное отклонение обучающих изображений для использования в нормализации z-балла.

meanImage = zeros([inputSize 3]);
meanImageSq = zeros([inputSize 3]);
while hasdata(dsTrain)
    dat = read(dsTrain);
    img = double(dat{1});
    meanImage = meanImage + img;
    meanImageSq = meanImageSq + img.^2;
end
meanImage = meanImage/numTrain;
meanImageSq = meanImageSq/numTrain;
varImage = meanImageSq - meanImage.^2;
stdImage = sqrt(varImage);

Установите datastore в начальное состояние.

reset(dsTrain);

Загрузка и изменение MobileNet-v2 сети

Этот пример начинается с MobileNet-v2 [3] CNN, обученного на ImageNet [4]. Пример модифицирует сеть путем замены последнего слоя MobileNet-v2 сети полносвязным слоем с 10 нейронами, каждый из которых представляет дискретный счет от 1 до 10. Сеть предсказывает вероятность каждого счета для каждого изображения. Пример нормализует выходы полностью соединенного слоя с помощью слоя активации softmax.

The mobilenetv2 (Deep Learning Toolbox) функция возвращает предварительно обученную MobileNet-v2 сеть. Эта функция требует пакета Deep Learning Toolbox™ Model for MobileNet-v2 Network поддержки. Если этот пакет поддержки не установлен, то функция предоставляет ссылку на загрузку.

net = mobilenetv2;

Преобразуйте сеть в layerGraph (Deep Learning Toolbox) объект.

lgraph = layerGraph(net);

Сеть имеет размер входа изображения 224 на 224 пикселя. Замените слой входа слоем входа изображений, который выполняет нормализацию z-балла по данным изображения с помощью среднего и стандартного отклонения обучающих изображений.

inLayer = imageInputLayer([inputSize 3],'Name','input','Normalization','zscore','Mean',meanImage,'StandardDeviation',stdImage);
lgraph = replaceLayer(lgraph,'input_1',inLayer);

Замените исходный конечный классификационный слой полносвязным слоем с 10 нейронами. Добавьте слой softmax, чтобы нормализовать выходы. Установите скорость обучения для полносвязного слоя в 10 раз выше скорости обучения для базовых слоев CNN. Нанесите выпадение 75%.

lgraph = removeLayers(lgraph,{'ClassificationLayer_Logits','Logits_softmax','Logits'});
newFinalLayers = [
    dropoutLayer(0.75,'Name','drop')
    fullyConnectedLayer(newMaxScore,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
    softmaxLayer('Name','prob')];    
lgraph = addLayers(lgraph,newFinalLayers);
lgraph = connectLayers(lgraph,'global_average_pooling2d_1','drop');
dlnet = dlnetwork(lgraph);

Визуализируйте сеть с помощью приложения Deep Network Designer (Deep Learning Toolbox).

deepNetworkDesigner(lgraph)

Задайте градиенты модели и функции потерь

The modelGradients Функция helper вычисляет градиенты и потери для каждой итерации настройки сети. Эта функция определяется в разделе Вспомогательные функции этого примера.

Цель сети NIMA состоит в том, чтобы минимизировать расстояние земного движителя (EMD) между основной истиной и предсказанным распределениями счетом. Потеря EMD учитывает расстояние между классами при наказании за неправильную классификацию. Поэтому потеря EMD работает лучше, чем типичная потеря перекрестной энтропии softmax, используемая в задачах классификации [5]. В этом примере вычисляются потери EMD с помощью earthMoverDistance вспомогательная функция, которая задана в разделе «Вспомогательные функции» этого примера.

Для функции потерь EMD используйте расстояние r-нормы с r = 2. Это расстояние позволяет легко оптимизировать, когда вы работаете с градиентным спуском.

Настройка опций обучения

Задайте опции для оптимизации SGDM. Обучите сеть на 150 эпох с мини-пакетом размером 128.

numEpochs = 150;
miniBatchSize = 128;
momentum = 0.9;
initialLearnRate = 3e-3;
decay = 0.95;

Пакетные обучающие данные

Создайте minibatchqueue (Deep Learning Toolbox) объект, который управляет мини-пакетированием наблюдений в пользовательском цикле обучения. The minibatchqueue объект также переводит данные в dlarray (Deep Learning Toolbox) объект, который включает автоматическую дифференциацию в применениях глубокого обучения.

Задайте мини-формат извлечения данных пакета как 'SSCB' (пространственный, пространственный, канальный, пакетный). Установите значение 'DispatchInBackground' аргумент имя-значение логического элемента, возвращенный canUseGPU. Если поддерживаемый графический процессор доступен для расчетов, то minibatchqueue объект обрабатывает мини-пакеты в фоновом режиме в параллельном пуле во время обучения.

mbqTrain = minibatchqueue(dsTrain,'MiniBatchSize',miniBatchSize, ...
    'PartialMiniBatch','discard','MiniBatchFormat',{'SSCB',''}, ...
    'DispatchInBackground',canUseGPU);
mbqVal = minibatchqueue(dsVal,'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFormat',{'SSCB',''},'DispatchInBackground',canUseGPU);

Обучите сеть

По умолчанию пример загружает предварительно обученную версию сети NIMA. Предварительно обученная сеть позволяет запускать весь пример, не дожидаясь завершения обучения.

Чтобы обучить сеть, установите doTraining переменная в следующем коде, для true. Обучите модель в пользовательском цикле обучения. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next (Deep Learning Toolbox) функция.

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) функцию и modelGradients вспомогательная функция.

  • Обновляйте параметры сети с помощью sgdmupdate (Deep Learning Toolbox) функция.

Обучите на графическом процессоре, если он доступен. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения дополнительной информации смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

doTraining = false;
if doTraining
    iteration = 0;
    velocity = [];
    start = tic;
    
    [hFig,lineLossTrain,lineLossVal] = initializeTrainingPlotNIMA;
    
    for epoch = 1:numEpochs
        
        shuffle (mbqTrain);    
        learnRate = initialLearnRate/(1+decay*floor(epoch/10));
        
        while hasdata(mbqTrain)
            iteration = iteration + 1;        
            [dlX,cdfY] = next(mbqTrain);
            [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,cdfY);        
            [dlnet,velocity] = sgdmupdate(dlnet,grad,velocity,learnRate,momentum);
            
            updateTrainingPlotNIMA(lineLossTrain,loss,epoch,iteration,start)      
            
        end
        
        % Add validation data to plot
        [~,lossVal,~] = modelPredictions(dlnet,mbqVal);
        updateTrainingPlotNIMA(lineLossVal,lossVal,epoch,iteration,start) 
        
    end
    
    % Save the trained network
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(strcat("trainedNIMA-",modelDateTime,"-Epoch-",num2str(numEpochs),".mat"),'dlnet');

else
    load(fullfile(imageDir,'trainedNIMA.mat'));    
end

Оценка модели NIMA

Оцените эффективность модели на наборе тестовых данных с помощью трех метрик: EMD, бинарной точности классификации и коэффициентов корреляции. Производительность сети NIMA на наборе тестовых данных согласуется с эффективностью эталонной модели NIMA, сообщенной Талеби и Миланфаром [1].

Создайте minibatchqueue (Deep Learning Toolbox) объект, который управляет мини-пакетированием тестовых данных.

mbqTest = minibatchqueue(dsTest,'MiniBatchSize',miniBatchSize,'MiniBatchFormat',{'SSCB',''});

Вычислите предсказанные вероятности и основную истину совокупные вероятности мини-пакетов тестовых данных с помощью modelPredictions функция. Эта функция определяется в разделе Вспомогательные функции этого примера.

[YPredTest,~,cdfYTest] = modelPredictions(dlnet,mbqTest);

Вычислите среднее и стандартные значения отклонений основной истины и предсказанных распределений.

meanPred = extractdata(YPredTest)' * (1:10)';
stdPred = sqrt(extractdata(YPredTest)'*((1:10).^2)' - meanPred.^2);
origCdf = extractdata(cdfYTest);
origPdf = [origCdf(1,:); diff(origCdf)];
meanOrig = origPdf' * (1:10)';
stdOrig = sqrt(origPdf'*((1:10).^2)' - meanOrig.^2);

Вычисление EMD

Вычислите EMD основной истины и предсказанных распределений счета. Для предсказания используйте расстояние r-нормы с r = 1. Значение EMD указывает на близость предсказанного и основная истина рейтинговых распределений.

EMDTest = earthMoverDistance(YPredTest,cdfYTest,1)
EMDTest = 
  1×1 single gpuArray dlarray

    0.1158

Вычисление точности двоичной классификации

Для бинарной точности классификации преобразуйте распределения в две классификации: качественную и низкокачественную. Классифицируйте изображения со средним счетом, большим порога, как высококачественные.

qualityThreshold = 5;
binaryPred = meanPred > qualityThreshold;    
binaryOrig = meanOrig > qualityThreshold;

Вычислите точность двоичной классификации.

binaryAccuracy = 100 * sum(binaryPred==binaryOrig)/length(binaryPred)
binaryAccuracy =

   84.6591

Вычислите коэффициенты корреляции

Большие значения корреляции указывают на большую положительную корреляцию между основной истиной и предсказанными счетами. Вычислите коэффициент линейной корреляции (LCC) и коэффициент ранговой корреляции Спирмана (SRCC) для средних счетов.

meanLCC = corr(meanOrig,meanPred)
meanLCC =

  gpuArray single

    0.7265
meanSRCC = corr(meanOrig,meanPred,'type','Spearman')
meanSRCC =

  gpuArray single

    0.6451

Вспомогательные функции

Функция градиентов модели

The modelGradients функция принимает как вход dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими целевыми совокупными вероятностями cdfY. Функция возвращает градиенты потерь относительно настраиваемых параметров в dlnet а также потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,cdfY)
    dlYPred = forward(dlnet,dlX);    
    loss = earthMoverDistance(dlYPred,cdfY,2);    
    gradients = dlgradient(loss,dlnet.Learnables);    
end

Функция потерь

The earthMoverDistance функция вычисляет EMD между основной истиной и предсказанным распределениями для заданного значения r-нормы. The earthMoverDistance использует computeCDF вспомогательная функция для вычисления совокупных вероятностей предсказанного распределения.

function loss = earthMoverDistance(YPred,cdfY,r)
    N = size(cdfY,1);
    cdfYPred = computeCDF(YPred);
    cdfDiff = (1/N) * (abs(cdfY - cdfYPred).^r);
    lossArray = sum(cdfDiff,1).^(1/r);
    loss = mean(lossArray);
    
end
function cdfY = computeCDF(Y)
% Given a probability mass function Y, compute the cumulative probabilities
    [N,miniBatchSize] = size(Y);
    L = repmat(triu(ones(N)),1,1,miniBatchSize);
    L3d = permute(L,[1 3 2]);
    prod = Y.*L3d;
    prodSum = sum(prod,1);
    cdfY = reshape(prodSum(:)',miniBatchSize,N)';
end

Функция предсказаний модели

The modelPredictions функция вычисляет предполагаемые вероятности, потери и основную истину совокупные вероятности мини-пакетов данных.

function [dlYPred,loss,cdfYOrig] = modelPredictions(dlnet,mbq)
    reset(mbq);
    loss = 0;
    numObservations = 0;
    dlYPred = [];    
    cdfYOrig = [];
    
    while hasdata(mbq)     
        [dlX,cdfY] = next(mbq);
        miniBatchSize = size(dlX,4);
        
        dlY = predict(dlnet,dlX);
        loss = loss + earthMoverDistance(dlY,cdfY,2)*miniBatchSize;
        dlYPred = [dlYPred dlY];
        cdfYOrig = [cdfYOrig cdfY];
        
        numObservations = numObservations + miniBatchSize;
        
    end
    loss = loss / numObservations;
end

Ссылки

[1] Талеби, Хоссейн и Пейман Миланфар. NIMA: оценка нейронного изображения. Транзакции IEEE по обработке изображений 27, № 8 (август 2018): 3998-4011. https://doi.org/10.1109/TIP.2018.2831899.

[2] LIVE: Лаборатория инженерии изображений и видео. «LIVE In The Wild Image Quality Challenge Database». https://live.ece.utexas.edu/research/ChallengeDB/index.html.

[3] Сэндлер, Марк, Эндрю Говард, Менглун Чжу, Андрей Жмогинов и Лян-Чих Чен. «MobileNetV2: перевёрнутые невязки и линейные узкие места». В 2018 году IEEE/CVF Conference on Компьютерное Зрение and Pattern Recognition, 4510-20. Солт-Лейк-Сити, UT: IEEE, 2018. https://doi.org/10.1109/CVPR.2018.00474.

[4] ImageNet. http://www.image-net.org.

[5] Hou, Le, Chen-Ping Yu и Димитрис Самарас. Squared Earth Mover's Distance-Based Loss for Training Deep Neural Networks (неопр.) (недоступная ссылка). Препринт, представленный 30 ноября 2016 года. https://arxiv.org/abs/1611.05916.

См. также

| (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте