Классификация цифр с рассеиванием вейвлета

В этом примере показано, как использовать вейвлет, рассеивающийся в классификации изображений. Этот пример требует Wavelet Toolbox™, Deep Learning Toolbox™ и Parallel Computing Toolbox™.

Для проблем классификации часто полезно сопоставить данные в некоторое альтернативное представление, которое отбрасывает несоответствующую информацию при сохранении отличительных свойств каждого класса. Изображение вейвлета, рассеивающее представления низкого отклонения построений изображений, которые нечувствительны к переводам и маленьким деформациям. Поскольку переводы и маленькие деформации в изображении не влияют на членство в классе, рассеивающийся преобразовывают коэффициенты, обеспечивают функции, от которых можно создать устойчивые модели классификации.

Рассеивание вейвлета работает путем расположения каскадом изображения через серию вейвлета, преобразовывает, нелинейность и усреднение [1] [3] [4]. Результат этого глубокого извлечения признаков состоит в том, что изображения в том же классе подвинулись поближе друг к другу в рассеивающемся, преобразовывают представление, в то время как изображения, принадлежащие различным классам, перемещены дальше независимо. В то время как рассеивание вейвлета преобразовывает, имеет много архитектурных общих черт с глубокими сверточными нейронными сетями, включая операторы свертки, нелинейность и усреднение, фильтры в рассеивающемся преобразовании предопределены и зафиксированы.

Изображения цифры

Набор данных, используемый в этом примере, содержит 10 000 синтетических изображений цифр от 0 до 9. Изображения сгенерированы путем применения случайных преобразований к изображениям тех цифр, созданных с различными шрифтами. Каждое изображение цифры является 28 28 пикселями. Набор данных содержит равное количество изображений на категорию. Используйте imageDataStore считать изображения.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
Imds = imageDatastore(digitDatasetPath,'IncludeSubfolders',true, 'LabelSource','foldernames');

Случайным образом выберите и постройте 20 изображений от набора данных.

figure
numImages = 10000;
rng(100);
perm = randperm(numImages,20);
for np = 1:20
    subplot(4,5,np);
    imshow(Imds.Files{perm(np)});
end

Вы видите, что 8's в строках 1, 2, 3, и 4 показывают значительную изменчивость в то время как весь являющийся идентифицирующимся как 8. То же самое верно для других повторных цифр в выборке. Это сопоставимо с естественным почерком, где любая цифра отличается нетривиально между индивидуумами и даже в почерке того же индивидуума относительно перевода, вращения и других маленьких деформаций. Используя рассеивание вейвлета, мы надеемся создать представления этих цифр, которые затеняют эту несоответствующую изменчивость.

Извлечение признаков рассеивания вейвлета изображений

Синтетические изображения 28 28. Создайте среду рассеивания вейвлета изображений и установите шкалу инвариантности равняться размеру изображения. Определите номер вращений к 8 в каждом двух вейвлетах, рассеивающих наборы фильтров. Конструкция среды рассеивания вейвлета требует, чтобы мы установили только два гиперпараметра: InvarianceScale и NumRotations.

sf = waveletScattering2('ImageSize',[28 28],'InvarianceScale',28, ...
    'NumRotations',[8 8]);

Этот пример использует возможность параллельной обработки MATLAB™ через tall массив interface. Можно запустить параллельный пул, если вы в настоящее время не запускаетесь со следующим кодом. В качестве альтернативы в первый раз вы создаете tall массив, параллельный пул создается.

if isempty(gcp)
    parpool;
end

Для воспроизводимости, набор генератор случайных чисел. Переставьте файлы imageDatastore и разделение 10 000 изображений в два набора, один для обучения и один протянуло набор для тестирования. Выделите 80% данных или 8 000 изображений, к набору обучающих данных и протяните остающиеся 2 000 изображений для тестирования. Создайте tall массивы от обучения и тестовых наборов данных. Используйте функцию помощника, helperScatImages, чтобы создать характеристические векторы из рассеивающегося преобразовывают коэффициенты. helperScatImages получает журнал рассеивающегося, преобразовывают матрицу функции, а также среднее значение вдоль обоих размерности строки и столбца каждого изображения. Для каждого изображения в этом примере это приводит к 217 1 характеристическому вектору.

rng(10);
Imds = shuffle(Imds);
[trainImds,testImds] = splitEachLabel(Imds,0.8);
Ttrain = tall(trainImds);
Ttest = tall(testImds);
trainfeatures = cellfun(@(x)helperScatImages(sf,x),Ttrain,'UniformOutput',false);
testfeatures = cellfun(@(x)helperScatImages(sf,x),Ttest,'UniformOutput',false);

Используйте tallgather возможность конкатенировать все обучение и тестовые функции.

Trainf = gather(trainfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 5 min 5 sec
Evaluation completed in 5 min 5 sec
trainfeatures = cat(2,Trainf{:});
Testf = gather(testfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 1 min 13 sec
Evaluation completed in 1 min 13 sec
testfeatures = cat(2,Testf{:});

Предыдущий код приводит к двум матрицам с размерностями строки 217 и размерностью столбца, равной количеству изображений в наборах обучающих данных и наборах тестов соответственно. Соответственно, каждый столбец является характеристическим вектором для своего соответствующего изображения. Обратите внимание на то, что оригинальные изображения содержали 784 элемента. Соответственно, рассеивающиеся коэффициенты представляют аппроксимированное 4-кратное сокращение размера каждого изображения.

Модель PCA и прогноз

Этот пример создает простой классификатор на основе основных компонентов рассеивающихся характеристических векторов для каждого класса. Классификатор реализован в функциях helperPCAModel и helperPCAClassifier. helperPCAModel определяет основные компоненты для каждого класса цифры на основе рассеивающихся функций. helperPCAClassifier классифицирует протянутые тестовые данные путем нахождения самого близкого соответствия (лучшая проекция) между основными компонентами каждого тестового характеристического вектора с набором обучающих данных и присвоения класса соответственно.

model = helperPCAModel(trainfeatures,30,trainImds.Labels);
predlabels = helperPCAClassifier(testfeatures,model);

После построения модели и классификации набора тестов, определите точность классификации наборов тестов.

accuracy = sum(testImds.Labels == predlabels)./numel(testImds.Labels)*100
accuracy = 99.6000

Мы достигли правильной классификации на 99,6% тестовых данных. Постройте матрицу беспорядка, чтобы видеть, как 2 000 тестовых изображений были классифицированы. Существует 200 примеров в наборе тестов для каждого из этих 10 классов.

figure;
confusionchart(testImds.Labels,predlabels);
title('Test-Set Confusion Matrix -- Wavelet Scattering')

CNN

В этом разделе мы обучаем простую сверточную нейронную сеть (CNN) распознавать цифры. Создайте CNN, чтобы состоять из слоя свертки с 20 фильтрами 5 на 5 шагами 1 на 1. Следуйте за слоем свертки с активацией RELU и макс. объединением слоя. Используйте полносвязный слой, сопровождаемый softmax слоем, чтобы нормировать выход полносвязного слоя к вероятностям. Используйте перекрестную энтропийную функцию потерь в изучении.

imageSize = [28 28 1];
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Используйте стохастический градиентный спуск с импульсом и темпом обучения 0,0001 для обучения. Определите максимальный номер эпох к 20. Для воспроизводимости, набор ExecutionEnvironment к 'cpu'.

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','cpu');

Обучите сеть. Для обучения и тестирования мы используем те же наборы данных, используемые в рассеивающемся преобразовании.

reset(trainImds);
reset(testImds);
net = trainNetwork(trainImds,layers,options);

К концу обучения CNN выполняет близкие 100% на наборе обучающих данных. Используйте обучивший сеть, чтобы сделать прогнозы на протянутом наборе тестов.

YPred = classify(net,testImds,'ExecutionEnvironment','cpu');
DCNNaccuracy = sum(YPred == testImds.Labels)/numel(YPred)*100
DCNNaccuracy = 95.1500

Простой CNN достиг правильной классификации на 95,15% на протянутом наборе тестов. Постройте график беспорядка для CNN.

figure;
confusionchart(testImds.Labels,YPred);
title('Test-Set Confusion Chart -- CNN')

Сводные данные

Этот пример использовал изображение вейвлета, рассеивающееся, чтобы создать представления низкого отклонения изображений цифры для классификации. Используя рассеивающееся преобразование с фиксированными весами фильтра и простым классификатором основных компонентов, мы достигли правильной классификации на 99,6% на протянутом наборе тестов. С простым CNN, в котором изучены фильтры, мы достигли правильных 95,15%. Этот пример не предназначается как прямое сравнение рассеивающегося преобразования и CNNs. Существуют несколько гиперпараметр и изменения в архитектуре, которые можно сделать в каждом случае, которые значительно влияют на результаты. Цель этого примера состояла в том, чтобы просто продемонстрировать, что потенциал глубоких экстракторов функции как рассеивание вейвлета преобразовывает, чтобы произвести устойчивые представления данных для изучения.

Ссылки

[1] Бруна, J. и Mallat, S. (2013) "Инвариантные рассеивающиеся сверточные сети", Транзакции IEEE согласно Анализу Шаблона и Искусственному интеллекту, 35 (8), стр 1872-1886.

[2] Mallat, S. (2012) "Рассеивание инварианта группы", Коммуникации на Чистой и Прикладной математике, 65, стр 1331-1398.

[3] Sifre, L. и Mallat, S. (2013). "Вращение, масштабирование и инвариант деформации, рассеивающийся для дискриминации структуры", Продолжения / CVPR, Конференция Общества эпохи компьютеризации IEEE по Компьютерному зрению и Распознаванию образов. Конференция Общества эпохи компьютеризации IEEE по Компьютерному зрению и Распознаванию образов. 1233-1240. 10.1109/CVPR.2013.163.

Смотрите также

Похожие темы