В этом примере показано, как классифицировать паразитарные инфекции на изображения окраски Giemsa с помощью рассеивания вейвлета изображений и глубокого обучения. Набор данных сложен для глубоких сетей, потому что он содержит только 48 изображений. Изображения разделены равномерно в три категории паразитарных инфекций: babesiosis, плазмодий-gametocyte и трипаносомоз.
Получите данные из обмена файлами MATLAB®: Развертывание Глубоких нейронных сетей к Встроенным графическим процессорам и разархивировало файл. Файл находится в той же папке как этот пример.
url = "https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/" + ... "5918495a-0009-419e-8e10-77b06e3fe553/844e43fa-7c50-4f88-a435-f0afe04fc3a3/" + ... "packages/zip"; websave("classifyBloodSmearImages.zip",url); unzip('classifyBloodSmearImages.zip')
Создайте ImageDatastore
чтобы управлять доступом Giemsa окрашивают изображения. Изображения находятся в формате RGB с общим размером 300 300 3.
imagedir = fullfile('classifyBloodSmearImages','BloodSmearImages'); Imds = imageDatastore(imagedir,'IncludeSubFolders',true,'FileExtensions',... '.jpg','LabelSource','foldernames'); summary(Imds.Labels)
babesiosis 16 plasmodium-gametocyte 16 trypanosomiasis 16
Существует 16 изображений для каждого из трех паразитных типов. Разделите данные в обучение и наборы тестов затяжки с 70 процентами изображений в наборе обучающих данных и 30 процентами в наборе тестов. Установите генератор случайных чисел для воспроизводимости.
rng default
[trainImds,testImds] = splitEachLabel(Imds,0.7);
Проверьте, что равные количества каждого паразитного класса содержатся и в наборах обучающих данных и в наборах тестов.
summary(trainImds.Labels)
babesiosis 11 plasmodium-gametocyte 11 trypanosomiasis 11
% Perform the same for the test set.
summary(testImds.Labels)
babesiosis 5 plasmodium-gametocyte 5 trypanosomiasis 5
Поскольку это - маленький набор данных, целые наборы обучающих данных и наборы тестов умещаются в памяти. Считайте все изображения для обоих наборов.
trainImages = readall(trainImds); testImages = readall(testImds);
Постройте некоторые демонстрационные изображения от обучающих данных.
idx = randperm(33,6); figure for ii = 1:length(idx) im = trainImages{idx(ii)}; subplot(3,2,ii) imshow(im,[]) title(string(trainImds.Labels(idx(ii)))); end
В этом примере вы используете рассеивание вейвлета, преобразовывают как экстрактор функции для подходов машинного обучения. Рассеивание вейвлета преобразовывает, помогает уменьшать размерность данных и увеличить несходство межкласса. Создайте сеть рассеивания 2D слоя изображений с 40 40 пиксельная шкала инвариантности. Используйте два вейвлета на октаву в первом слое и один вейвлет на октаву во втором слое. Используйте два вращения вейвлетов на слой.
sn = waveletScattering2('ImageSize',[300 300],'InvarianceScale',40,... 'QualityFactors',[2 1],'NumRotations',[2 2]); [~,npaths] = paths(sn); sum(npaths)
ans = 27
coefficientSize(sn)
ans = 1×2
38 38
Заданная сеть рассеивания вейвлета имеет 27 путей. Изображение на каждом пути к рассеиванию уменьшается до 38 на 38 на 3. Даже без дальнейшего усреднения рассеивающихся коэффициентов, это - сокращение размера памяти каждого изображения больше, чем фактор 2. Однако для классификации мы формируем характеристический вектор, который составляет в среднем рассеивающиеся коэффициенты по пространственным размерностям и размерностям канала. Это приводит к характеристическим векторам только с 27 элементами, скаляром с действительным знаком для каждого пути к рассеиванию. Это представляет сокращение числа элементов на коэффициент 10 000 для каждого изображения.
Следующий код вычисляет вейвлет, рассеивающий характеристические векторы и для наборов обучающих данных и для наборов тестов. Конкатенация характеристических векторов так, чтобы у вас были матрицы N-27, где N является количеством примеров в наборе обучающих данных или наборе тестов и каждой строке, является вейвлетом, рассеивающим характеристический вектор для примера.
trainfeatures = cellfun(@(x)helperScatImages_mean(sn,x),trainImages,'Uni',0); testfeatures = cellfun(@(x)helperScatImages_mean(sn,x),testImages,'Uni',0); trainfeatures = cat(1,trainfeatures{:}); testfeatures = cat(1,testfeatures{:});
Используйте классификатор SVM с рассеивающимися функциями. Выберите ядро кубического полинома. Используйте схему кодирования one-all.
template = templateSVM(... 'KernelFunction', 'polynomial', ... 'PolynomialOrder', 3, ... 'KernelScale', 1, ... 'BoxConstraint', 314, ... 'Standardize', true); classificationSVM = fitcecoc(trainfeatures,trainImds.Labels,... 'Learners', template, 'Coding', 'onevsall');
Оцените точность на наборе обучающих данных с помощью перекрестной проверки с 5 сгибами.
kfoldmodel = crossval(classificationSVM, 'KFold', 5);
loss = kfoldLoss(kfoldmodel)*100;
crossvalAccuracy = 100-loss
crossvalAccuracy = single
81.8182
Точность перекрестной проверки составляет приблизительно 80 процентов. Теперь исследуйте точность на протянутом наборе тестов и постройте график беспорядка.
[predLabels,scores] = predict(classificationSVM,testfeatures);
testAccuracy = ...
sum(categorical(predLabels)== testImds.Labels)/numel(testImds.Labels)*100
testAccuracy = 80
figure cchart = confusionchart(testImds.Labels,predLabels); cchart.Title = ... {'Confusion Chart for Wavelet' ; 'Scattering Features using SVM'}; cchart.RowSummary = 'row-normalized'; cchart.ColumnSummary = 'column-normalized';
Полная тестовая точность составляет 80 процентов с моделью SVM. Отзыв для каждого класса составляет 80%. Точность также хороша для плазмодия-gametocyte и паразитов трипаносомоза, но хуже для babesiosis. Исследуйте музыку F1 к каждому классу.
f1SVM = f1score(cchart.NormalizedValues); disp(f1SVM)
F1 _______ babesiosis 0.72727 plasmodium-gametocyte 0.88889 trypanosomiasis 0.8
Все баллы F1 между приблизительно 0,7 и 0.9.
Машины опорных векторов являются мощными методами для функций, которые не линейно отделимы, но они спроектированы для бинарной классификации и могут быть субоптимальными для проблем мультикласса. Здесь вы дополняете анализ SVM при помощи простого PCA (линейный) классификатор с теми же функциями рассеивания вейвлета. helperPCAModel
функция определяет numcomp
собственные вектора, соответствующие самым большим собственным значениям ковариационной матрицы рассеивания вейвлета, показывают для каждого патогена в наборе обучающих данных наряду со средними значениями класса.
helperPCAClassifier
классифицирует каждую тестовую выборку. Это делает это путем вычитания средних значений класса модели из каждого вейвлета, рассеивающего характеристический вектор в тестовом наборе данных и проецирующего характеристические векторы в центре на собственные вектора ковариационной матрицы для каждого класса в модели. helperPCAClassifier
присвоения каждый тестовый пример патогену с самой маленькой ошибкой или невязка. Это - классификатор анализа основных компонентов (PCA).
Удалите 0-th функции рассеивания порядка из каждого характеристического вектора. Определите номер основных компонентов (собственные вектора) к 6.
numcomp = 6; model = helperPCAModel(trainfeatures(:,2:end)',numcomp,trainImds.Labels); PCALabels = helperPCAClassifier(testfeatures(:,2:end)',model); testPCAacc = sum(PCALabels==testImds.Labels)/numel(testImds.Labels)*100
testPCAacc = 86.6667
Тестовая точность составляет приблизительно 87% с классификатором PCA. Постройте график беспорядка и вычислите музыку F1 к каждому классу.
figure cchart = confusionchart(testImds.Labels,PCALabels); cchart.Title = {'Confusion Chart for Wavelet Scattering Features' ; ... 'using PCA Classifier'}; cchart.RowSummary = 'row-normalized'; cchart.ColumnSummary = 'column-normalized';
f1PCA = f1score(cchart.NormalizedValues); disp(f1PCA)
F1 _______ babesiosis 0.90909 plasmodium-gametocyte 0.88889 trypanosomiasis 0.8
Музыка F1 к классификатору PCA с функциями рассеивания вейвлета довольно сильна со всеми баллами между 0,8 и 1.
В этом разделе вы делаете попытку той же классификации с помощью глубоко сверточные сети. Глубокие сети обеспечивают современные результаты для проблем классификации с большими наборами данных и способны к изучению сложных нелинейных отображений, но их эффективность часто страдает в маленьких наборах данных. Чтобы смягчить эту проблему, используйте увеличение изображений. imageDataAugmenter
тревожит данные в каждую эпоху, в действительности создавая новые учебные примеры.
augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ... 'RandYTranslation',[-5 5]); augimds = augmentedImageDatastore([300 300 3],trainImds,'DataAugmentation',augmenter);
Задайте маленький CNN, состоящий из двух слоев свертки, сопровождаемых слоями нормализации партии. и активациями RELU. Следуйте за итоговой активацией RELU с макс. объединением, полностью соединенными, и softmax слоями.
layers = [ imageInputLayer([300 300 3]) convolution2dLayer(7,16) batchNormalizationLayer reluLayer convolution2dLayer(3,20) batchNormalizationLayer reluLayer maxPooling2dLayer(4) fullyConnectedLayer(3) softmaxLayer classificationLayer];
Используйте стохастический градиентный спуск с мини-пакетным размером 10. Переставьте данные каждая эпоха. Запустите обучение в течение 100 эпох.
opts = trainingOptions('sgdm',... 'InitialLearnRate', 0.0001, ... 'MaxEpochs', 100, ... 'MiniBatchSize',10,... 'Shuffle','every-epoch',... 'Plots', 'training-progress',... 'Verbose',false,... 'ExecutionEnvironment','cpu');
Обучите сеть.
trainedNet = trainNetwork(augimds,layers,opts);
Исследуйте эффективность сети на протянутом наборе тестов.
ypred = trainedNet.classify(testImds); cnnAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
cnnAccuracy = 66.6667
figure cchart = confusionchart(testImds.Labels,ypred); cchart.Title = 'Confusion Chart for Deep CNN'; cchart.RowSummary = 'row-normalized'; cchart.ColumnSummary = 'column-normalized';
f1CNN = f1score(cchart.NormalizedValues); disp(f1CNN)
F1 _______ babesiosis 0.75 plasmodium-gametocyte 0.76923 trypanosomiasis 0.44444
Несмотря на использование увеличенного набора данных для обучения, CNN имеет сверхподгонку, набор обучающих данных и баллы F1 значительно хуже, чем любой модель SVM или PCA с функциями рассеивания вейвлета.
Затем используйте передачу обучения с SqueezeNet. Измените итоговый сверточный слой, чтобы вместить то, что у вас есть три класса патогенов. SqueezeNet был создан, чтобы распознать 1 000 классов.
net = squeezenet; lgraphSQZ = layerGraph(net); numClasses = numel(categories(trainImds.Labels)); oldFinalConv = lgraphSQZ.Layers(end-4); newFinalConv = convolution2dLayer(1,numClasses, ... 'Name','new_conv'); setLearnRateFactor(newFinalConv,'Weights',10); setLearnRateFactor(newFinalConv,'Bias',10)
ans = Convolution2DLayer with properties: Name: 'new_conv' Hyperparameters FilterSize: [1 1] NumChannels: 'auto' NumFilters: 3 Stride: [1 1] DilationFactor: [1 1] PaddingMode: 'manual' PaddingSize: [0 0 0 0] PaddingValue: 0 Learnable Parameters Weights: [] Bias: [] Show all properties
lgraphSQZ = replaceLayer(lgraphSQZ,oldFinalConv.Name,newFinalConv); oldClassLayer= lgraphSQZ.Layers(end); newClassLayer = classificationLayer('Name','new_classoutput'); lgraphSQZ = replaceLayer(lgraphSQZ,oldClassLayer.Name,newClassLayer);
Сбросьте обучение и протестируйте хранилища данных. Измените функцию чтения datastore, чтобы изменить размер изображений, чтобы быть совместимыми с SqueezeNet, который ожидает 227 227 3 изображениями. Настройте увеличение изображений и обучите сеть.
reset(trainImds); reset(testImds); trainImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]); testImds.ReadFcn = @(x)imresize(imread(x),'OutputSize',[227 227]); augmenter = imageDataAugmenter('RandRotation',[0 180],'RandXTranslation', [-5 5], ... 'RandYTranslation',[-5 5]); augimds = augmentedImageDatastore([227 227 3],trainImds,... 'DataAugmentation',augmenter); trainedNet = trainNetwork(augimds,lgraphSQZ,opts);
Получите точность SqueezeNet, постройте график беспорядка и вычислите баллы F1.
ypred = trainedNet.classify(testImds); sqznetAccuracy = sum(ypred == testImds.Labels)/numel(testImds.Labels)*100
sqznetAccuracy = 73.3333
figure cchart = confusionchart(testImds.Labels,ypred); cchart.Title = {'Confusion Chart for Transfer Learning' ; 'with SqueezeNet'}; cchart.RowSummary = 'row-normalized'; cchart.ColumnSummary = 'column-normalized';
f1SqueezeNet = f1score(cchart.NormalizedValues); disp(f1SqueezeNet)
F1 _______ babesiosis 0.72727 plasmodium-gametocyte 0.8 trypanosomiasis 0.66667
SqueezeNet выполняет лучше, чем более простой CNN, особенно в терминах счета F1 к трипаносомозу, но эффективность не совпадает с точностью более простого классификатора PCA с функциями рассеивания вейвлета.
В этом примере рассеивание вейвлета преобразовывает, и среды глубокого обучения использовались, чтобы классифицировать патогены на изображения окраски Giemsa. Ограниченный размер набора данных обеспечивает проблемы для обучения классификатор глубокого обучения, даже когда увеличение данных используется. Пример проиллюстрировал, что рассеивание вейвлета преобразовывает, может обеспечить полезную альтернативу глубоким сетям в таких случаях. В формировании характеристических векторов от рассеивания вейвлета преобразовывают, мы уменьшали, каждый преобразовывает выход от 27 38 38 3 тензорами к вектору с 27 элементами. Соответственно, мы использовали глобальное объединение рассеивающихся коэффициентов. Возможно использовать другие схемы объединения, которые могли привести к лучшим результатам.
function features = helperScatImages_mean(sn,x) smat = featureMatrix(sn,x); features = mean(smat,2:4); features = features'; end function F1scores = f1score(cchartVal) N = sum(cchartVal,'all'); probT = sum(cchartVal)./N; classProbEst = diag(cchartVal)./N; Prec = classProbEst'./probT; probC = [5/15 5/15 5/15]; Recall = classProbEst'./probC; F1scores = harmmean([Prec ; Recall]); F1scores = F1scores'; F1scores = table(F1scores,'VariableNames',{'F1'},... 'RowNames', {'babesiosis','plasmodium-gametocyte', 'trypanosomiasis'}); end function labels = helperPCAClassifier(features,model) % This function is only to support wavelet image scattering examples in % Wavelet Toolbox. It may change or be removed in a future release. % model is a structure array with fields, M, mu, v, and Labels % features is the matrix of test data which is Ns-by-L, Ns is the number of % scattering paths and L is the number of test examples. Each column of % features is a test example. % Copyright 2018-2021 MathWorks labelIdx = determineClass(features,model); labels = model.Labels(labelIdx); % Returns as column vector to agree with imageDatastore Labels labels = labels(:); %-------------------------------------------------------------------------- function labelIdx = determineClass(features,model) % Determine number of classes Nclasses = numel(model.Labels); % Initialize error matrix errMatrix = Inf(Nclasses,size(features,2)); for nc = 1:Nclasses % class centroid mu = model.mu{nc}; u = model.U{nc}; % 1-by-L errMatrix(nc,:) = projectionError(features,mu,u); end % Determine minimum along class dimension [~,labelIdx] = min(errMatrix,[],1); %-------------------------------------------------------------------------- function totalerr = projectionError(features,mu,u) % Npc = size(u,2); L = size(features,2); % Subtract class mean: Ns-by-L minus Ns-by-1 s = features-mu; % 1-by-L normSqX = sum(abs(s).^2,1)'; err = Inf(Npc+1,L); err(1,:) = normSqX; err(2:end,:) = -abs(u'*s).^2; % 1-by-L totalerr = sqrt(sum(err,1)); end end end function model = helperPCAModel(features,M,Labels) % This function is only to support wavelet image scattering examples in % Wavelet Toolbox. It may change or be removed in a future release. % model = helperPCAModel(features,M,Labels) % Copyright 2018-2021 MathWorks % Initialize structure array to hold the affine model model = struct('Dim',[],'mu',[],'U',[],'Labels',categorical([]),'S',[]); model.Dim = M; % Obtain the number of classes LabelCategories = categories(Labels); Nclasses = numel(categories(Labels)); for kk = 1:Nclasses Class = LabelCategories{kk}; % Find indices corresponding to each class idxClass = Labels == Class; % Extract feature vectors for each class tmpFeatures = features(:,idxClass); % Determine the mean for each class model.mu{kk} = mean(tmpFeatures,2); [model.U{kk},model.S{kk}] = scatPCA(tmpFeatures); if size(model.U{kk},2) > M model.U{kk} = model.U{kk}(:,1:M); model.S{kk} = model.S{kk}(1:M); end model.Labels(kk) = Class; end function [u,s,v] = scatPCA(x) % Calculate the principal components of x along the second dimension. [u,d] = eig(cov(x')); % Sort eigenvalues of covariance matrix in descending order [s,ind] = sort(diag(d),'descend'); % sort eigenvector matrix accordingly u = u(:,ind); end end
waveletScattering2
(Wavelet Toolbox)