Классификация цифр с рассеиванием вейвлета

В этом примере показано, как использовать вейвлет, рассеивающийся в классификации изображений. Этот пример требует Wavelet Toolbox™, Deep Learning Toolbox™ и Parallel Computing Toolbox™.

Для проблем классификации часто полезно сопоставить данные в некоторое альтернативное представление, которое отбрасывает несоответствующую информацию при сохранении отличительных свойств каждого класса. Изображение вейвлета, рассеивающее представления низкого отклонения построений изображений, которые нечувствительны к переводам и маленьким деформациям. Поскольку переводы и маленькие деформации в изображении не влияют на членство в классе, рассеивающийся преобразовывают коэффициенты, обеспечивают функции, от которых можно создать устойчивые модели классификации.

Рассеивание вейвлета работает путем расположения каскадом изображения через серию вейвлета, преобразовывает, нелинейность и усреднение [1] [3] [4]. Результат этого глубокого извлечения признаков состоит в том, что изображения в том же классе подвинулись поближе друг к другу в рассеивающемся, преобразовывают представление, в то время как изображения, принадлежащие различным классам, перемещены дальше независимо. В то время как рассеивание вейвлета преобразовывает, имеет много архитектурных общих черт с глубокими сверточными нейронными сетями, включая операторы свертки, нелинейность и усреднение, фильтры в рассеивающемся преобразовании предопределены и зафиксированы.

Изображения цифры

Набор данных, используемый в этом примере, содержит 10 000 синтетических изображений цифр от 0 до 9. Изображения сгенерированы путем применения случайных преобразований к изображениям тех цифр, созданных с различными шрифтами. Каждое изображение цифры является 28 28 пикселями. Набор данных содержит равное количество изображений на категорию. Используйте imageDataStore считать изображения.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
Imds = imageDatastore(digitDatasetPath,'IncludeSubfolders',true, 'LabelSource','foldernames');

Случайным образом выберите и постройте 20 изображений от набора данных.

figure
numImages = 10000;
rng(100);
perm = randperm(numImages,20);
for np = 1:20
    subplot(4,5,np);
    imshow(Imds.Files{perm(np)});
end

Вы видите, что 8 с показывают значительную изменчивость в то время как весь являющийся идентифицирующимся как 8. То же самое верно для других повторных цифр в выборке. Это сопоставимо с естественным почерком, где любая цифра отличается нетривиально между индивидуумами и даже в почерке того же индивидуума относительно перевода, вращения и других маленьких деформаций. Используя рассеивание вейвлета, мы надеемся создать представления этих цифр, которые затеняют эту несоответствующую изменчивость.

Извлечение признаков рассеивания вейвлета изображений

Синтетические изображения 28 28. Создайте среду рассеивания вейвлета изображений и установите шкалу инвариантности равняться размеру изображения. Определите номер вращений к 8 в каждом двух вейвлетах, рассеивающих наборы фильтров. Конструкция среды рассеивания вейвлета требует, чтобы мы установили только два гиперпараметра: InvarianceScale и NumRotations.

sf = waveletScattering2('ImageSize',[28 28],'InvarianceScale',28, ...
    'NumRotations',[8 8]);

Этот пример использует возможность параллельной обработки MATLAB™ через tall интерфейс массивов. Если параллельный пул в настоящее время не запускается, можно запустить один путем выполнения следующего кода. В качестве альтернативы в первый раз вы создаете tall массив, параллельный пул создается.

if isempty(gcp)
    parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Для воспроизводимости установите генератор случайных чисел. Переставьте файлы imageDatastore и разделение 10 000 изображений в два набора, один для обучения и один протянуло набор для тестирования. Выделите 80% данных или 8 000 изображений, к набору обучающих данных и протяните остающиеся 2 000 изображений для тестирования. Создайте tall массивы от обучения и тестовых наборов данных. Используйте функцию помощника helperScatImages чтобы создать характеристические векторы из рассеивающегося преобразовывают коэффициенты. helperScatImages получает журнал рассеивающегося, преобразовывают матрицу функции, а также среднее значение вдоль обоих размерности строки и столбца каждого изображения. Код для helperScatImages в конце этого примера. Для каждого изображения в этом примере функция помощника создает 217 1 характеристический вектор.

rng(10);
Imds = shuffle(Imds);
[trainImds,testImds] = splitEachLabel(Imds,0.8);
Ttrain = tall(trainImds);
Ttest = tall(testImds);
trainfeatures = cellfun(@(x)helperScatImages(sf,x),Ttrain,'UniformOutput',false);
testfeatures = cellfun(@(x)helperScatImages(sf,x),Ttest,'UniformOutput',false);

Используйте tallgather возможность конкатенировать все обучение и тестовые функции.

Trainf = gather(trainfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 3 min 51 sec
Evaluation completed in 3 min 51 sec
trainfeatures = cat(2,Trainf{:});
Testf = gather(testfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 49 sec
Evaluation completed in 49 sec
testfeatures = cat(2,Testf{:});

Предыдущий код приводит к двум матрицам с размерностями строки 217 и размерностью столбца, равной количеству изображений в наборах обучающих данных и наборах тестов соответственно. Соответственно, каждый столбец является характеристическим вектором для своего соответствующего изображения. Оригинальные изображения содержали 784 элемента. Рассеивающиеся коэффициенты представляют аппроксимированное 4-кратное сокращение размера каждого изображения.

Модель PCA и предсказание

Этот пример создает простой классификатор на основе основных компонентов рассеивающихся характеристических векторов для каждого класса. Классификатор реализован в функциях helperPCAModel и helperPCAClassifier. helperPCAModel определяет основные компоненты для каждого класса цифры на основе рассеивающихся функций. Код для helperPCAModel в конце этого примера. helperPCAClassifier классифицирует протянутые тестовые данные путем нахождения самого близкого соответствия (лучшая проекция) между основными компонентами каждого тестового характеристического вектора с набором обучающих данных и присвоения класса соответственно. Код для helperPCAClassifier в конце этого примера.

model = helperPCAModel(trainfeatures,30,trainImds.Labels);
predlabels = helperPCAClassifier(testfeatures,model);

После построения модели и классификации набора тестов, определите точность классификации наборов тестов.

accuracy = sum(testImds.Labels == predlabels)./numel(testImds.Labels)*100
accuracy = 99.6000

Мы достигли правильной классификации на 99,6% тестовых данных. Чтобы видеть, как 2 000 тестовых изображений были классифицированы, постройте матрицу беспорядка. Существует 200 примеров в наборе тестов для каждого из этих 10 классов.

figure;
confusionchart(testImds.Labels,predlabels)
title('Test-Set Confusion Matrix -- Wavelet Scattering')

CNN

В этом разделе мы обучаем простую сверточную нейронную сеть (CNN) распознавать цифры. Создайте CNN, чтобы состоять из слоя свертки с 20 фильтрами 5 на 5 шагами 1 на 1. Следуйте за слоем свертки с активацией RELU и макс. объединением слоя. Используйте полносвязный слой, сопровождаемый softmax слоем, чтобы нормировать выход полносвязного слоя к вероятностям. Используйте функцию потери перекрестной энтропии в изучении.

imageSize = [28 28 1];
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Используйте стохастический градиентный спуск с импульсом и скоростью обучения 0,0001 для обучения. Определите максимальный номер эпох к 20. Для воспроизводимости установите ExecutionEnvironment к 'cpu'.

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','cpu');

Обучите сеть. Для обучения и тестирования мы используем те же наборы данных, используемые в рассеивающемся преобразовании.

reset(trainImds);
reset(testImds);
net = trainNetwork(trainImds,layers,options);

К концу обучения CNN выполняет близкие 100% на наборе обучающих данных. Используйте обучивший сеть, чтобы сделать предсказания на протянутом наборе тестов.

YPred = classify(net,testImds,'ExecutionEnvironment','cpu');
DCNNaccuracy = sum(YPred == testImds.Labels)/numel(YPred)*100
DCNNaccuracy = 95.5000

Простой CNN достиг правильной классификации на 95,5% на протянутом наборе тестов. Постройте график беспорядка для CNN.

figure;
confusionchart(testImds.Labels,YPred)
title('Test-Set Confusion Chart -- CNN')

Сводные данные

Этот пример использовал изображение вейвлета, рассеивающееся, чтобы создать представления низкого отклонения изображений цифры для классификации. Используя рассеивающееся преобразование с фиксированными весами фильтра и простым классификатором основных компонентов, мы достигли правильной классификации на 99,6% на протянутом наборе тестов. С простым CNN, в котором изучены фильтры, мы достигли правильных 95,5%. Этот пример не предназначается как прямое сравнение рассеивающегося преобразования и CNNs. Существуют несколько гиперпараметр и изменения в архитектуре, которые можно сделать в каждом случае, которые значительно влияют на результаты. Цель этого примера состояла в том, чтобы просто продемонстрировать, что потенциал глубоких экстракторов функции как рассеивание вейвлета преобразовывает, чтобы произвести устойчивые представления данных для изучения.

Ссылки

[1] Бруна, J. и С. Маллэт. "Инвариантные Сети Свертки Рассеивания". Транзакции IEEE согласно Анализу Шаблона и Искусственному интеллекту. Издание 35, Номер 8, 2013, стр 1872–1886.

[2] Mallat, S. "Рассеивание Инварианта группы". Коммуникации в Чистой и Прикладной математике. Издание 65, Номер 10, 2012, стр 1331–1398.

[3] Sifre, L. и С. Маллэт. "Вращение, масштабирование и инвариант деформации, рассеивающийся для дискриминации структуры". 2 013 Конференций по IEEE по Компьютерному зрению и Распознаванию образов. 2013, стр 1233–1240. 10.1109/CVPR.2013.163.

Приложение — поддерживание функций

helperScatImages

function features = helperScatImages(sf,x)
% This function is only to support examples in the Wavelet Toolbox.
% It may change or be removed in a future release.

% Copyright 2018 MathWorks

smat = featureMatrix(sf,x,'transform','log');
features = mean(mean(smat,2),3);
end

helperPCAModel

function model = helperPCAModel(features,M,Labels)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperPCAModel(features,M,Labels)

% Copyright 2018 MathWorks

% Initialize structure array to hold the affine model
model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'s',[]);
model.Dim = M;
% Obtain the number of classes
LabelCategories = categories(Labels);
Nclasses = numel(categories(Labels));
for kk = 1:Nclasses
    Class = LabelCategories{kk};
    % Find indices corresponding to each class
    idxClass = Labels == Class;
    % Extract feature vectors for each class
    tmpFeatures = features(:,idxClass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpFeatures,2);
    [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures);
    if size(model.U{kk},2) > M
        model.U{kk} = model.U{kk}(:,1:M);
        model.S{kk} = model.S{kk}(1:M);
        
    end
    model.Labels(kk) = Class;
end
    
function [u,s,v] = scatPCA(x,M)
	% Calculate the principal components of x along the second dimension.

	if nargin > 1 && M > 0
		% If M is non-zero, calculate the first M principal components.
	    [u,s,v] = svds(x-sig_mean(x),M);
	    s = abs(diag(s)/sqrt(size(x,2)-1)).^2;
	else
		% Otherwise, calculate all the principal components.
        % Each row is an observation, i.e. the number of scattering paths
        % Each column is a class observation
		[u,d] = eig(cov(x'));
		[s,ind] = sort(diag(d),'descend');
		u = u(:,ind);
	end
end
end

helperPCAClassifier

function labels = helperPCAClassifier(features,model)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure array with fields, M, mu, v, and Labels
% features is the matrix of test data which is Ns-by-L, Ns is the number of
% scattering paths and L is the number of test examples. Each column of
% features is a test example.

% Copyright 2018 MathWorks

labelIdx = determineClass(features,model); 
labels = model.Labels(labelIdx); 
% Returns as column vector to agree with imageDatastore Labels
labels = labels(:);


%--------------------------------------------------------------------------
function labelIdx = determineClass(features,model)
% Determine number of classes
Nclasses = numel(model.Labels);
% Initialize error matrix
errMatrix = Inf(Nclasses,size(features,2));
for nc = 1:Nclasses
    % class centroid
    mu = model.mu{nc};
    u = model.U{nc};
    % 1-by-L
    errMatrix(nc,:) = projectionError(features,mu,u);
end
% Determine minimum along class dimension
[~,labelIdx] = min(errMatrix,[],1);   


%--------------------------------------------------------------------------
function totalerr = projectionError(features,mu,u)
    %
    Npc = size(u,2);
    L = size(features,2);
    % Subtract class mean: Ns-by-L minus Ns-by-1
    s = features-mu;
    % 1-by-L
    normSqX = sum(abs(s).^2,1)';
    err = Inf(Npc+1,L);
	err(1,:) = normSqX;
    err(2:end,:) = -abs(u'*s).^2;
    % 1-by-L
    totalerr = sqrt(sum(err,1));
end
end
end
	

Смотрите также

Похожие темы