Паразитная классификация Используя рассеивание вейвлета и глубокое обучение

В этом примере показано, как классифицировать паразитарные инфекции на изображения окраски Giemsa с помощью рассеивания вейвлета изображений и глубокого обучения. Набор данных сложен для глубоких сетей, потому что он содержит только 48 изображений. Изображения разделены равномерно в три категории паразитарных инфекций: babesiosis, плазмодий-gametocyte и трипаносомоз.

Данные

Получите данные из обмена файлами MATLAB®: Развертывание Глубоких нейронных сетей к Встроенным графическим процессорам и разархивировало файл. Файл находится в той же папке как этот пример.

url = "https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/" + ...
    "5918495a-0009-419e-8e10-77b06e3fe553/844e43fa-7c50-4f88-a435-f0afe04fc3a3/" + ...
    "packages/zip";
websave("classifyBloodSmearImages.zip",url);
unzip('classifyBloodSmearImages.zip')

Создайте ImageDatastore чтобы управлять доступом Giemsa окрашивают изображения. Изображения находятся в формате RGB с общим размером 300 300 3.

imagedir = fullfile('classifyBloodSmearImages','BloodSmearImages');
Imds = imageDatastore(imagedir,'IncludeSubFolders',true,'FileExtensions',...
    '.jpg','LabelSource','foldernames');
summary(Imds.Labels)
     babesiosis                 16 
     plasmodium-gametocyte      16 
     trypanosomiasis            16 

Существует 16 изображений для каждого из трех паразитных типов. Разделите данные в обучение и наборы тестов затяжки с 70 процентами изображений в наборе обучающих данных и 30 процентами в наборе тестов. Установите генератор случайных чисел для воспроизводимости.

rng default
[trainImds,testImds] = splitEachLabel(Imds,0.7);

Проверьте, что равные количества каждого паразитного класса содержатся и в наборах обучающих данных и в наборах тестов.

summary(trainImds.Labels)
     babesiosis                 11 
     plasmodium-gametocyte      11 
     trypanosomiasis            11 
% Perform the same for the test set.
summary(testImds.Labels)
     babesiosis                 5 
     plasmodium-gametocyte      5 
     trypanosomiasis            5 

Поскольку это - маленький набор данных, целые наборы обучающих данных и наборы тестов умещаются в памяти. Считайте все изображения для обоих наборов.

trainImages = readall(trainImds);
testImages = readall(testImds);

Постройте некоторые демонстрационные изображения от обучающих данных.

idx = randperm(33,6);
figure
for ii = 1:length(idx)
    im = trainImages{idx(ii)};
    subplot(3,2,ii)
    imshow(im,[])
    title(string(trainImds.Labels(idx(ii))));
end

Сеть рассеивания вейвлета

В этом примере вы используете рассеивание вейвлета, преобразовывают как экстрактор функции для подходов машинного обучения. Рассеивание вейвлета преобразовывает, помогает уменьшать размерность данных и увеличить несходство межкласса. Создайте сеть рассеивания 2D слоя изображений с 40 40 пиксельная шкала инвариантности. Используйте два вейвлета на октаву в первом слое и один вейвлет на октаву во втором слое. Используйте два вращения вейвлетов на слой.

sn = waveletScattering2('ImageSize',[300 300],'InvarianceScale',40,...
    'QualityFactors',[2 1],'NumRotations',[2 2]);
[~,npaths] = paths(sn);
sum(npaths)
ans = 27
coefficientSize(sn)
ans = 1×2

    38    38

Заданная сеть рассеивания вейвлета имеет 27 путей. Изображение на каждом пути к рассеиванию уменьшается до 38 на 38 на 3. Даже без дальнейшего усреднения рассеивающихся коэффициентов, это - сокращение размера памяти каждого изображения больше, чем фактор 2. Однако для классификации мы формируем характеристический вектор, который составляет в среднем рассеивающиеся коэффициенты по пространственным размерностям и размерностям канала. Это приводит к характеристическим векторам только с 27 элементами, скаляром с действительным знаком для каждого пути к рассеиванию. Это представляет сокращение числа элементов на коэффициент 10 000 для каждого изображения.

Следующий код вычисляет вейвлет, рассеивающий характеристические векторы и для наборов обучающих данных и для наборов тестов. Конкатенация характеристических векторов так, чтобы у вас были матрицы N-27, где N является количеством примеров в наборе обучающих данных или наборе тестов и каждой строке, является вейвлетом, рассеивающим характеристический вектор для примера.

trainfeatures = cellfun(@(x)helperScatImages_mean(sn,x),trainImages,'Uni',0);
testfeatures = cellfun(@(x)helperScatImages_mean(sn,x),testImages,'Uni',0);
trainfeatures = cat(1,trainfeatures{:});
testfeatures = cat(1,testfeatures{:});

Классификация SVM

Используйте классификатор SVM с рассеивающимися функциями. Выберите ядро кубического полинома. Используйте схему кодирования one-all.

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 3, ...
    'KernelScale', 1, ...
    'BoxConstraint', 314, ...
    'Standardize', true);
classificationSVM = fitcecoc(trainfeatures,trainImds.Labels,...
    'Learners', template, 'Coding', 'onevsall');

Оцените точность на наборе обучающих данных с помощью перекрестной проверки с 5 сгибами.

kfoldmodel = crossval(classificationSVM, 'KFold', 5);
loss = kfoldLoss(kfoldmodel)*100;
crossvalAccuracy = 100-loss
crossvalAccuracy = single
    81.8182

Точность перекрестной проверки составляет приблизительно 80 процентов. Теперь исследуйте точность на протянутом наборе тестов и постройте график беспорядка.

[predLabels,scores] = predict(classificationSVM,testfeatures);
testAccuracy = ...
    sum(categorical(predLabels)== testImds.Labels)/numel(testImds.Labels)*100
testAccuracy = 80
figure
cchart = confusionchart(testImds.Labels,predLabels);
cchart.Title = ...
    {'Confusion Chart for Wavelet' ; 'Scattering Features using SVM'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

Полная тестовая точность составляет 80 процентов с моделью SVM. Отзыв для каждого класса составляет 80%. Точность также хороша для плазмодия-gametocyte и паразитов трипаносомоза, но хуже для babesiosis. Исследуйте музыку F1 к каждому классу.

f1SVM = f1score(cchart.NormalizedValues);
disp(f1SVM)
                               F1   
                             _______

    babesiosis               0.72727
    plasmodium-gametocyte    0.88889
    trypanosomiasis              0.8

Все баллы F1 между приблизительно 0,7 и 0.9.

Классификатор PCA с рассеиванием функций

Машины опорных векторов являются мощными методами для функций, которые не линейно отделимы, но они спроектированы для бинарной классификации и могут быть субоптимальными для проблем мультикласса. Здесь вы дополняете анализ SVM при помощи простого PCA (линейный) классификатор с теми же функциями рассеивания вейвлета. helperPCAModel функция определяет numcomp собственные вектора, соответствующие самым большим собственным значениям ковариационной матрицы рассеивания вейвлета, показывают для каждого патогена в наборе обучающих данных наряду со средними значениями класса.

helperPCAClassifier классифицирует каждую тестовую выборку. Это делает это путем вычитания средних значений класса модели из каждого вейвлета, рассеивающего характеристический вектор в тестовом наборе данных и проецирующего характеристические векторы в центре на собственные вектора ковариационной матрицы для каждого класса в модели. helperPCAClassifier присвоения каждый тестовый пример патогену с самой маленькой ошибкой или невязка. Это - классификатор анализа основных компонентов (PCA).

Удалите 0-th функции рассеивания порядка из каждого характеристического вектора. Определите номер основных компонентов (собственные вектора) к 6.

numcomp = 6;
model = helperPCAModel(trainfeatures(:,2:end)',numcomp,trainImds.Labels);
PCALabels = helperPCAClassifier(testfeatures(:,2:end)',model);
testPCAacc = sum(PCALabels==testImds.Labels)/numel(testImds.Labels)*100
testPCAacc = 86.6667

Тестовая точность составляет приблизительно 87% с классификатором PCA. Постройте график беспорядка и вычислите музыку F1 к каждому классу.

figure
cchart = confusionchart(testImds.Labels,PCALabels);
cchart.Title = {'Confusion Chart for Wavelet Scattering Features' ; ...
    'using PCA Classifier'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1PCA = f1score(cchart.NormalizedValues);
disp(f1PCA)
                               F1   
                             _______

    babesiosis               0.90909
    plasmodium-gametocyte    0.88889
    trypanosomiasis              0.8

Музыка F1 к классификатору PCA с функциями рассеивания вейвлета довольно сильна со всеми баллами между 0,8 и 1.

Сверточная глубокая сеть

В этом разделе вы делаете попытку той же классификации с помощью глубоко сверточные сети. Глубокие сети обеспечивают современные результаты для проблем классификации с большими наборами данных и способны к изучению сложных нелинейных отображений, но их эффективность часто страдает в маленьких наборах данных. Чтобы смягчить эту проблему, используйте увеличение изображений. imageDataAugmenter тревожит данные в каждую эпоху, в действительности создавая новые учебные примеры.

augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ...
    'RandYTranslation',[-5 5]);
augimds = augmentedImageDatastore([300 300 3],trainImds,'DataAugmentation',augmenter);

Задайте маленький CNN, состоящий из двух слоев свертки, сопровождаемых слоями нормализации партии. и активациями RELU. Следуйте за итоговой активацией RELU с макс. объединением, полностью соединенными, и softmax слоями.

layers = [
    imageInputLayer([300 300 3])
    convolution2dLayer(7,16)
    batchNormalizationLayer
    reluLayer    
    convolution2dLayer(3,20)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(4)
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];

Используйте стохастический градиентный спуск с мини-пакетным размером 10. Переставьте данные каждая эпоха. Запустите обучение в течение 100 эпох.

opts = trainingOptions('sgdm',...
    'InitialLearnRate',  0.0001, ...
    'MaxEpochs', 100, ...
    'MiniBatchSize',10,...
    'Shuffle','every-epoch',...
    'Plots', 'training-progress',...
    'Verbose',false,...
    'ExecutionEnvironment','cpu');

Обучите сеть.

trainedNet = trainNetwork(augimds,layers,opts);

Исследуйте эффективность сети на протянутом наборе тестов.

ypred = trainedNet.classify(testImds);
cnnAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
cnnAccuracy = 66.6667
figure
cchart = confusionchart(testImds.Labels,ypred);
cchart.Title = 'Confusion Chart for Deep CNN';
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1CNN = f1score(cchart.NormalizedValues);
disp(f1CNN)
                               F1   
                             _______

    babesiosis                  0.75
    plasmodium-gametocyte    0.76923
    trypanosomiasis          0.44444

Несмотря на использование увеличенного набора данных для обучения, CNN имеет сверхподгонку, набор обучающих данных и баллы F1 значительно хуже, чем любой модель SVM или PCA с функциями рассеивания вейвлета.

Затем используйте передачу обучения с SqueezeNet. Измените итоговый сверточный слой, чтобы вместить то, что у вас есть три класса патогенов. SqueezeNet был создан, чтобы распознать 1 000 классов.

net = squeezenet;
lgraphSQZ = layerGraph(net);
numClasses = numel(categories(trainImds.Labels));
oldFinalConv = lgraphSQZ.Layers(end-4);
newFinalConv = convolution2dLayer(1,numClasses, ...
        'Name','new_conv');
setLearnRateFactor(newFinalConv,'Weights',10);
setLearnRateFactor(newFinalConv,'Bias',10)
ans = 
  Convolution2DLayer with properties:

              Name: 'new_conv'

   Hyperparameters
        FilterSize: [1 1]
       NumChannels: 'auto'
        NumFilters: 3
            Stride: [1 1]
    DilationFactor: [1 1]
       PaddingMode: 'manual'
       PaddingSize: [0 0 0 0]
      PaddingValue: 0

   Learnable Parameters
           Weights: []
              Bias: []

  Show all properties

lgraphSQZ = replaceLayer(lgraphSQZ,oldFinalConv.Name,newFinalConv);
oldClassLayer= lgraphSQZ.Layers(end);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraphSQZ = replaceLayer(lgraphSQZ,oldClassLayer.Name,newClassLayer);

Сбросьте обучение и протестируйте хранилища данных. Измените функцию чтения datastore, чтобы изменить размер изображений, чтобы быть совместимыми с SqueezeNet, который ожидает 227 227 3 изображениями. Настройте увеличение изображений и обучите сеть.

reset(trainImds);
reset(testImds);
trainImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]);
testImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]);
augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ...
    'RandYTranslation',[-5 5]);
augimds = augmentedImageDatastore([227 227 3],trainImds,...
    'DataAugmentation',augmenter);
trainedNet = trainNetwork(augimds,lgraphSQZ,opts);

Получите точность SqueezeNet, постройте график беспорядка и вычислите баллы F1.

ypred = trainedNet.classify(testImds);
sqznetAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
sqznetAccuracy = 73.3333
figure
cchart = confusionchart(testImds.Labels,ypred);
cchart.Title = {'Confusion Chart for Transfer Learning' ; 'with SqueezeNet'};
cchart.RowSummary = 'row-normalized';
cchart.ColumnSummary = 'column-normalized';

f1SqueezeNet = f1score(cchart.NormalizedValues);
disp(f1SqueezeNet)
                               F1   
                             _______

    babesiosis               0.72727
    plasmodium-gametocyte        0.8
    trypanosomiasis          0.66667

SqueezeNet выполняет лучше, чем более простой CNN, особенно в терминах счета F1 к трипаносомозу, но эффективность не совпадает с точностью более простого классификатора PCA с функциями рассеивания вейвлета.

Сводные данные

В этом примере рассеивание вейвлета преобразовывает, и среды глубокого обучения использовались, чтобы классифицировать патогены на изображения окраски Giemsa. Ограниченный размер набора данных обеспечивает проблемы для обучения классификатор глубокого обучения, даже когда увеличение данных используется. Пример проиллюстрировал, что рассеивание вейвлета преобразовывает, может обеспечить полезную альтернативу глубоким сетям в таких случаях. В формировании характеристических векторов от рассеивания вейвлета преобразовывают, мы уменьшали, каждый преобразовывает выход от 27 38 38 3 тензорами к вектору с 27 элементами. Соответственно, мы использовали глобальное объединение рассеивающихся коэффициентов. Возможно использовать другие схемы объединения, которые могли привести к лучшим результатам.

Приложение — поддерживание функций

function features = helperScatImages_mean(sn,x)
smat = featureMatrix(sn,x);
features = mean(smat,2:4);
features = features';
end
function F1scores = f1score(cchartVal)
N = sum(cchartVal,'all');
probT = sum(cchartVal)./N;
classProbEst = diag(cchartVal)./N;
Prec = classProbEst'./probT;
probC = [5/15 5/15 5/15];
Recall = classProbEst'./probC;
F1scores = harmmean([Prec ; Recall]);
F1scores = F1scores';
F1scores = table(F1scores,'VariableNames',{'F1'},...
    'RowNames', {'babesiosis','plasmodium-gametocyte', 'trypanosomiasis'});
end

function labels = helperPCAClassifier(features,model)
% This function is only to support wavelet image scattering examples in 
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure array with fields, M, mu, v, and Labels
% features is the matrix of test data which is Ns-by-L, Ns is the number of
% scattering paths and L is the number of test examples. Each column of
% features is a test example.

% Copyright 2018-2021 MathWorks

labelIdx = determineClass(features,model); 
labels = model.Labels(labelIdx); 
% Returns as column vector to agree with imageDatastore Labels
labels = labels(:);


%--------------------------------------------------------------------------
function labelIdx = determineClass(features,model)
% Determine number of classes
Nclasses = numel(model.Labels);
% Initialize error matrix
errMatrix = Inf(Nclasses,size(features,2));
for nc = 1:Nclasses
    % class centroid
    mu = model.mu{nc};
    u = model.U{nc};
    % 1-by-L
    errMatrix(nc,:) = projectionError(features,mu,u);
end
% Determine minimum along class dimension
[~,labelIdx] = min(errMatrix,[],1);   


%--------------------------------------------------------------------------
function totalerr = projectionError(features,mu,u)
    %
    Npc = size(u,2);
    L = size(features,2);
    % Subtract class mean: Ns-by-L minus Ns-by-1
    s = features-mu;
    % 1-by-L
    normSqX = sum(abs(s).^2,1)';
    err = Inf(Npc+1,L);
	err(1,:) = normSqX;
    err(2:end,:) = -abs(u'*s).^2;
    % 1-by-L
    totalerr = sqrt(sum(err,1));
end
end
end

function model = helperPCAModel(features,M,Labels)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperPCAModel(features,M,Labels)

% Copyright 2018-2021 MathWorks

% Initialize structure array to hold the affine model
model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'S',[]);
model.Dim = M;
% Obtain the number of classes
LabelCategories = categories(Labels);
Nclasses = numel(categories(Labels));
for kk = 1:Nclasses
    Class = LabelCategories{kk};
    % Find indices corresponding to each class
    idxClass = Labels == Class;
    % Extract feature vectors for each class
    tmpFeatures = features(:,idxClass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpFeatures,2);
    [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures);
    if size(model.U{kk},2) > M
        model.U{kk} = model.U{kk}(:,1:M);
        model.S{kk} = model.S{kk}(1:M);
        
    end
    model.Labels(kk) = Class;
end

    function [u,s,v] = scatPCA(x)
        % Calculate the principal components of x along the second dimension.
        [u,d] = eig(cov(x'));
        % Sort eigenvalues of covariance matrix in descending order
        [s,ind] = sort(diag(d),'descend');
        % sort eigenvector matrix accordingly
        u = u(:,ind);
    end
end

Смотрите также

(Wavelet Toolbox)

Связанные примеры

Больше о