exponenta event banner

Регрессия между изображениями в конструкторе глубоких сетей

В этом примере показано, как использовать Deep Network Designer для построения и обучения сети регрессии «изображение-изображение» для суперразрешения.

Пространственное разрешение - это количество пикселей, используемых для построения цифрового изображения. Изображение с высоким пространственным разрешением состоит из большего числа пикселей, и в результате изображение содержит большую детализацию. Суперразрешение - это процесс ввода изображения с низким разрешением и преобразования его в изображение с более высоким разрешением. При работе с данными изображения можно уменьшить пространственное разрешение для уменьшения размера данных за счет потери информации. Чтобы восстановить эту потерянную информацию, вы можете обучить сеть глубокого обучения предсказывать недостающие детали изображения. В этом примере выполняется восстановление изображений размером 28 на 28 пикселей из изображений, которые были сжаты до 7 на 7 пикселей.

Загрузить данные

В этом примере используется набор данных цифр, состоящий из 10 000 синтетических изображений рукописных цифр в градациях серого. Каждое изображение составляет 28 на 28 на 1 пиксель.

Загрузите данные и создайте хранилище данных образа.

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');

imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ....
    'LabelSource','foldernames');

Используйте shuffle функция для перетасовки данных перед обучением.

imds = shuffle(imds);

Используйте splitEachLabel функция для разделения хранилища данных образа на три хранилища данных образа, содержащие изображения для обучения, проверки и тестирования.

[imdsTrain,imdsVal,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,'randomized');

Нормализуйте данные в каждом изображении в диапазоне [0,1]. Нормализация помогает стабилизировать и ускорить обучение сети с использованием градиентного спуска. Если ваши данные плохо масштабированы, то потеря может стать NaN и сетевые параметры могут расходиться во время обучения.

imdsTrain = transform(imdsTrain,@(x) rescale(x));
imdsVal = transform(imdsVal,@(x) rescale(x));
imdsTest = transform(imdsTest,@(x) rescale(x));

Создание данных обучения

Создайте набор обучающих данных путем генерации пар изображений, состоящих из изображений с пониженным разрешением и соответствующих изображений с высоким разрешением.

Чтобы обучить сеть выполнять регрессию изображения к изображению, изображения должны быть парами, состоящими из ввода и отклика, где оба изображения имеют одинаковый размер. Создайте обучающие данные, уменьшив выборку каждого изображения до 7 на 7 пикселей, а затем увеличив выборку до 28 на 28 пикселей. Используя пары преобразованных и оригинальных изображений, сеть может научиться сопоставлять два разных разрешения.

Создание входных данных с помощью вспомогательной функции upsampLowRes, которая использует imresize для создания изображений с более низким разрешением.

imdsInputTrain = transform(imdsTrain,@upsampLowRes);
imdsInputVal= transform(imdsVal,@upsampLowRes);
imdsInputTest = transform(imdsTest,@upsampLowRes);

Используйте combine функция для объединения изображений с низким и высоким разрешением в одном хранилище данных. Выходные данные combine функция является CombinedDatastore объект.

dsTrain = combine(imdsInputTrain,imdsTrain);
dsVal = combine(imdsInputVal,imdsVal);
dsTest = combine(imdsInputTest,imdsTest);

Создание сетевой архитектуры

Создайте сетевую архитектуру с помощью unetLayers функция от Computer Vision Toolbox™. Эта функция обеспечивает сеть, подходящую для семантической сегментации, которая может быть легко адаптирована для регрессии между изображениями.

Создайте сеть с входным размером 28 на 28 на 1 пиксель.

layers = unetLayers([28,28,1],2,'encoderDepth',2);

Отредактируйте сеть для регрессии между изображениями с помощью Deep Network Designer.

deepNetworkDesigner(layers);

На панели «Конструктор» замените слои «softmax» и «pixel classification» на слой регрессии из библиотеки слоев.

Выберите окончательный сверточный слой и установите NumFilters свойство для 1.

Сеть сейчас готова к обучению.

Импорт данных

Импортируйте данные обучения и проверки в Deep Network Designer.

На вкладке Данные щелкните Импорт данных > Импорт хранилища данных и выберите dsTrain в качестве данных обучения и dsVal в качестве данных проверки. Импортируйте оба хранилища данных, щелкнув Импорт.

Deep Network Designer отображает пары изображений в объединенном хранилище данных. Входные изображения с более высоким разрешением находятся слева, а исходные ответные изображения с высоким разрешением - справа. Сеть учится сопоставлять входные и ответные изображения.

Железнодорожная сеть

Выберите параметры обучения и обучите сеть.

На вкладке Обучение выберите Параметры обучения. В списке Решатель выберите adam. Задать для MaxEpochs значение 10. Подтвердите параметры обучения, нажав Закрыть.

Выполните обучение сети в объединенном хранилище данных, нажав кнопку Train.

По мере того, как сеть учится сопоставлять два изображения, среднеквадратичная ошибка (RMSE) уменьшается.

После завершения обучения щелкните Экспорт (Export), чтобы экспортировать обученную сеть в рабочую область. Обученная сеть хранится в переменной trainedNetwork_1.

Тестовая сеть

Оцените производительность сети с помощью тестовых данных.

Используя predict, можно проверить, может ли сеть создавать изображение с высоким разрешением из входного изображения с низким разрешением, которое не было включено в обучающий набор.

ypred = predict(trainedNetwork_1,dsTest);

for i = 1:8
    I(1:2,i) = read(dsTest);
    I(3,i) = {ypred(:,:,:,i)};
end

Сравните входные, прогнозируемые и ответные изображения.

subplot(1,3,1)
imshow(imtile(I(1,:),'GridSize',[8,1]))
title('Input')
subplot(1,3,2)
imshow(imtile(I(3,:),'GridSize',[8,1]))
title('Predict')
subplot(1,3,3)
imshow(imtile(I(2,:),'GridSize',[8,1]))
title('Response')

Сеть успешно создает изображения с высоким разрешением на входах с низким разрешением.

Сеть в этом примере очень проста и сильно адаптирована к набору данных цифр. Пример создания более сложной сети регрессии изображения к изображению для повседневных изображений см. в разделе Суперразрешение одного изображения с помощью глубокого обучения.

Вспомогательные функции

function dataOut = upsampLowRes(dataIn)
        temp = dataIn;
        temp = imresize(temp,[7,7],'method','bilinear');
        dataOut = {imresize(temp,[28,28],'method','bilinear')};
end

См. также

|

Связанные темы