exponenta event banner

Классификация цифр с вейвлет-рассеянием

В этом примере показано, как использовать вейвлет-рассеяние для классификации изображений. В этом примере требуются Wavelet Toolbox™, Deep Learning Toolbox™ и Parallel Computing Toolbox™.

Для проблем классификации часто полезно отобразить данные в некоторое альтернативное представление, которое отбрасывает неактуальную информацию, сохраняя при этом дискриминационные свойства каждого класса. Вейвлет-рассеяние изображений создает малодисперсионные представления изображений, которые нечувствительны к трансляциям и малым деформациям. Так как трансляции и небольшие деформации в изображении не влияют на принадлежность к классу, коэффициенты преобразования рассеяния обеспечивают функции, из которых можно построить надежные модели классификации.

Вейвлет-рассеяние работает путем каскадирования изображения посредством серии вейвлет-преобразований, нелинейностей и усреднения [1] [3] [4]. Результатом этого глубокого извлечения признаков является то, что изображения в одном классе перемещаются ближе друг к другу в представлении преобразования рассеяния, в то время как изображения, принадлежащие различным классам, перемещаются дальше друг от друга. В то время как вейвлет-рассеивающее преобразование имеет ряд архитектурных сходств с глубокими сверточными нейронными сетями, включая операторы свертки, нелинейности и усреднения, фильтры в рассеивающем преобразовании заранее определены и фиксированы.

Цифровые изображения

Набор данных, используемый в этом примере, содержит 10000 синтетических изображений цифр от 0 до 9. Изображения генерируются путем применения случайных преобразований к изображениям цифр, созданных различными шрифтами. Каждая цифра изображения составляет 28 на 28 пикселей. Набор данных содержит равное количество изображений для каждой категории. Используйте imageDataStore для чтения изображений.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
Imds = imageDatastore(digitDatasetPath,'IncludeSubfolders',true, 'LabelSource','foldernames');

Случайный выбор и печать 20 изображений из набора данных.

figure
numImages = 10000;
rng(100);
perm = randperm(numImages,20);
for np = 1:20
    subplot(4,5,np);
    imshow(Imds.Files{perm(np)});
end

Вы можете видеть, что 8 проявляют значительную изменчивость, в то время как все они идентифицируются как 8. То же самое относится и к другим повторяющимся цифрам в выборке. Это согласуется с естественным почерком, где любая цифра отличается нетривиально между индивидуумами и даже в пределах почерка одного и того же индивидуума в отношении перевода, вращения и других небольших деформаций. Используя вейвлет-рассеяние, мы надеемся построить представления этих цифр, которые скрывают эту неактуальную изменчивость.

Извлечение признаков рассеяния вейвлет-изображений

Синтетические изображения 28 на 28. Создайте структуру рассеяния вейвлет-изображения и установите шкалу инвариантности равной размеру изображения. Установите число поворотов равным 8 в каждом из двух блоков фильтров вейвлет-рассеяния. Построение структуры вейвлет-рассеяния требует, чтобы мы установили только два гиперпараметра: InvarianceScale и NumRotations.

sf = waveletScattering2('ImageSize',[28 28],'InvarianceScale',28, ...
    'NumRotations',[8 8]);

В этом примере используется MATLAB™'s параллельная обработка через tall интерфейс массива. Если параллельный пул в данный момент не запущен, его можно запустить, выполнив следующий код. Кроме того, при первом создании tall создается параллельный пул.

if isempty(gcp)
    parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Для воспроизводимости установите генератор случайных чисел. Перетасовка файлов imageDatastore и разбить 10 000 изображений на два набора: один для обучения и один удерживаемый набор для тестирования. Выделите 80% данных, или 8000 изображений, обучающему набору и удерживайте оставшиеся 2000 изображений для тестирования. Создать tall массивы из обучающих и тестовых наборов данных. Использовать функцию помощника helperScatImages для создания векторов признаков из коэффициентов преобразования рассеяния. helperScatImages получает логарифм матрицы признаков преобразования рассеяния, а также среднее значение по размерам строк и столбцов каждого изображения. Код для helperScatImages находится в конце этого примера. Для каждого изображения в этом примере вспомогательная функция создает вектор элемента 217 на 1.

rng(10);
Imds = shuffle(Imds);
[trainImds,testImds] = splitEachLabel(Imds,0.8);
Ttrain = tall(trainImds);
Ttest = tall(testImds);
trainfeatures = cellfun(@(x)helperScatImages(sf,x),Ttrain,'UniformOutput',false);
testfeatures = cellfun(@(x)helperScatImages(sf,x),Ttest,'UniformOutput',false);

Использовать tall gather возможность объединения всех функций обучения и тестирования.

Trainf = gather(trainfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 3 min 51 sec
Evaluation completed in 3 min 51 sec
trainfeatures = cat(2,Trainf{:});
Testf = gather(testfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 49 sec
Evaluation completed in 49 sec
testfeatures = cat(2,Testf{:});

Предыдущий код дает две матрицы с размерами 217 строки и размером столбца, равными количеству изображений в обучающем и тестовом наборах соответственно. Соответственно, каждый столбец является вектором признаков для соответствующего изображения. Оригинальные изображения содержали 784 элемента. Коэффициенты рассеяния представляют собой приблизительно 4-кратное уменьшение размера каждого изображения.

Модель и прогноз PCA

В этом примере создается простой классификатор на основе основных компонентов векторов элементов рассеяния для каждого класса. Классификатор реализован в функциях helperPCAModel и helperPCAClassifier. helperPCAModel определяет главные компоненты для каждого класса цифр на основе элементов рассеяния. Код для helperPCAModel находится в конце этого примера. helperPCAClassifier классифицирует задержанные тестовые данные путем нахождения ближайшего совпадения (наилучшая проекция) между основными компонентами каждого вектора признаков теста с обучающим набором и соответствующего назначения класса. Код для helperPCAClassifier находится в конце этого примера.

model = helperPCAModel(trainfeatures,30,trainImds.Labels);
predlabels = helperPCAClassifier(testfeatures,model);

После построения модели и классификации тестового набора определите точность классификации тестового набора.

accuracy = sum(testImds.Labels == predlabels)./numel(testImds.Labels)*100
accuracy = 99.6000

Мы достигли 99,6% правильной классификации данных теста. Чтобы увидеть, как были классифицированы 2000 тестовых изображений, постройте график матрицы путаницы. В наборе тестов имеется 200 примеров для каждого из 10 классов.

figure;
confusionchart(testImds.Labels,predlabels)
title('Test-Set Confusion Matrix -- Wavelet Scattering')

CNN

В этом разделе мы обучаем простую сверточную нейронную сеть (CNN) распознавать цифры. Создайте CNN, состоящий из слоя свертки с 20 фильтрами 5 на 5 с шагом 1 на 1. Следуйте за уровнем свертки с уровнем активации и максимального объединения RELU. Используйте полностью подключенный уровень, а затем уровень softmax, чтобы нормализовать выход полностью подключенного уровня до вероятностей. Используйте функцию потери перекрестной энтропии для обучения.

imageSize = [28 28 1];
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Используйте стохастический градиентный спуск с импульсом и скоростью обучения 0,0001 для обучения. Установите максимальное количество эпох равным 20. Для воспроизводимости установите значение ExecutionEnvironment кому 'cpu'.

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','cpu');

Обучение сети. Для обучения и тестирования мы используем те же наборы данных, что и при преобразовании рассеяния.

reset(trainImds);
reset(testImds);
net = trainNetwork(trainImds,layers,options);

К концу обучения CNN выполняет почти 100% на тренировочном наборе. Используйте обученную сеть, чтобы сделать прогнозы на задержанном тестовом наборе.

YPred = classify(net,testImds,'ExecutionEnvironment','cpu');
DCNNaccuracy = sum(YPred == testImds.Labels)/numel(YPred)*100
DCNNaccuracy = 95.5000

Простой CNN достиг 95,5% корректной классификации в наборе тестов с удержанием. Постройте график путаницы для CNN.

figure;
confusionchart(testImds.Labels,YPred)
title('Test-Set Confusion Chart -- CNN')

Резюме

В этом примере используется вейвлет-рассеяние изображений для создания малорассеянных представлений цифровых изображений для классификации. Используя преобразование рассеяния с фиксированными весами фильтра и простой классификатор основных компонентов, мы достигли 99,6% корректной классификации в задержанном тестовом наборе. С помощью простого CNN, в котором изучаются фильтры, мы достигли 95,5% правильных. Этот пример не предназначен для прямого сравнения преобразования рассеяния и CNN. В каждом случае можно внести несколько гиперпараметров и архитектурных изменений, которые существенно влияют на результаты. Целью этого примера было просто продемонстрировать потенциал глубоких экстракторов признаков, таких как преобразование вейвлет-рассеяния, для получения надежных представлений данных для обучения.

Ссылки

[1] Бруна, J. и С. Маллэт. «Инвариантные сети свертки рассеяния». Транзакции IEEE по анализу шаблонов и машинному интеллекту. т. 35, № 8, 2013, с. 1872-1886.

[2] Маллат, С. «Инвариантное рассеяние группы». Коммуникации в чистой и прикладной математике. Том 65, номер 10, 2012, стр. 1331-1398.

[3] Сифре, Л. и С. Маллат. Конференция IEEE 2013 по компьютерному зрению и распознаванию образов. 2013, стр. 1233-1240. 10.1109/CVPR.2013.163.

Приложение - Вспомогательные функции

helperScatImages

function features = helperScatImages(sf,x)
% This function is only to support examples in the Wavelet Toolbox.
% It may change or be removed in a future release.

% Copyright 2018 MathWorks

smat = featureMatrix(sf,x,'transform','log');
features = mean(mean(smat,2),3);
end

helperPCAModel

function model = helperPCAModel(features,M,Labels)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperPCAModel(features,M,Labels)

% Copyright 2018 MathWorks

% Initialize structure array to hold the affine model
model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'s',[]);
model.Dim = M;
% Obtain the number of classes
LabelCategories = categories(Labels);
Nclasses = numel(categories(Labels));
for kk = 1:Nclasses
    Class = LabelCategories{kk};
    % Find indices corresponding to each class
    idxClass = Labels == Class;
    % Extract feature vectors for each class
    tmpFeatures = features(:,idxClass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpFeatures,2);
    [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures);
    if size(model.U{kk},2) > M
        model.U{kk} = model.U{kk}(:,1:M);
        model.S{kk} = model.S{kk}(1:M);
        
    end
    model.Labels(kk) = Class;
end
    
function [u,s,v] = scatPCA(x,M)
	% Calculate the principal components of x along the second dimension.

	if nargin > 1 && M > 0
		% If M is non-zero, calculate the first M principal components.
	    [u,s,v] = svds(x-sig_mean(x),M);
	    s = abs(diag(s)/sqrt(size(x,2)-1)).^2;
	else
		% Otherwise, calculate all the principal components.
        % Each row is an observation, i.e. the number of scattering paths
        % Each column is a class observation
		[u,d] = eig(cov(x'));
		[s,ind] = sort(diag(d),'descend');
		u = u(:,ind);
	end
end
end

helperPCAClassifier

function labels = helperPCAClassifier(features,model)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure array with fields, M, mu, v, and Labels
% features is the matrix of test data which is Ns-by-L, Ns is the number of
% scattering paths and L is the number of test examples. Each column of
% features is a test example.

% Copyright 2018 MathWorks

labelIdx = determineClass(features,model); 
labels = model.Labels(labelIdx); 
% Returns as column vector to agree with imageDatastore Labels
labels = labels(:);


%--------------------------------------------------------------------------
function labelIdx = determineClass(features,model)
% Determine number of classes
Nclasses = numel(model.Labels);
% Initialize error matrix
errMatrix = Inf(Nclasses,size(features,2));
for nc = 1:Nclasses
    % class centroid
    mu = model.mu{nc};
    u = model.U{nc};
    % 1-by-L
    errMatrix(nc,:) = projectionError(features,mu,u);
end
% Determine minimum along class dimension
[~,labelIdx] = min(errMatrix,[],1);   


%--------------------------------------------------------------------------
function totalerr = projectionError(features,mu,u)
    %
    Npc = size(u,2);
    L = size(features,2);
    % Subtract class mean: Ns-by-L minus Ns-by-1
    s = features-mu;
    % 1-by-L
    normSqX = sum(abs(s).^2,1)';
    err = Inf(Npc+1,L);
	err(1,:) = normSqX;
    err(2:end,:) = -abs(u'*s).^2;
    % 1-by-L
    totalerr = sqrt(sum(err,1));
end
end
end
	

См. также

Связанные темы