tsne
В этом примере показано, как использовать tsne
функция для просмотра активаций в обученной сети. Это представление может помочь вам понять, как работает сеть.
The tsne
(Statistics and Machine Learning Toolbox) функция в Statistics and Machine Learning Toolbox™ реализует t-распределенное стохастическое соседнее встраивание (t-SNE) [1]. Этот метод сопоставляет высоко-размерные данные (такие как активация сети в слое) с двумя размерностями. Метод использует нелинейную карту, которая пытается сохранить расстояния. Используя t-SNE, чтобы визуализировать активацию сети, можно получить понимание того, как сеть реагирует.
Можно использовать t-SNE, чтобы визуализировать, как нейронным сетям для глубокого обучения изменяете представление входных данных, когда оно проходит через слои сети. Можно также использовать t-SNE, чтобы найти проблемы с входными данными и понять, какие наблюдения сеть классифицирует неправильно.
Для примера t-SNE может уменьшить многомерные активации слоя softmax до 2-D представления с подобной структурой. Плотные кластеры на полученном графике t-SNE соответствуют классам, которые сеть обычно классифицирует правильно. Визуализация позволяет найти точки, которые появляются в неправильном кластере, указывая на наблюдение, которое сеть классифицирует неправильно. Наблюдение может быть неправильно помечено, или сеть может предсказать, что наблюдение является образцом другого класса, потому что оно похоже на другие наблюдения этого класса. Обратите внимание, что уменьшение t-SNE активаций softmax использует только эти активации, а не базовые наблюдения.
Этот пример использует набор данных Example Food Images, который содержит 978 фотографий пищи в девяти классах и имеет размер приблизительно 77 МБ. Загрузите набор данных во временную директорию путем вызова downloadExampleFoodImagesData
вспомогательная функция; код для этой вспомогательной функции появится в конце этого примера.
dataDir = fullfile(tempdir, "ExampleFoodImageDataset"); url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip"; if ~exist(dataDir, "dir") mkdir(dataDir); end downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset... This can take several minutes to download... Download finished... Unzipping file... Unzipping finished... Done.
Измените предварительно обученную сеть SqueezeNet, чтобы классифицировать изображения пищи из набора данных. Замените конечный сверточный слой, который имеет 1000 фильтров для 1000 классов ImageNet, новым сверточным слоем, который имеет только девять фильтров. Каждый фильтр соответствует одному типу пищи.
lgraph = layerGraph(squeezenet()); lgraph = lgraph.replaceLayer("ClassificationLayer_predictions",... classificationLayer("Name", "ClassificationLayer_predictions")); newConv = convolution2dLayer([14 14], 9, "Name", "conv", "Padding", "same"); lgraph = lgraph.replaceLayer("conv10", newConv);
Создайте imageDatastore
содержащие пути к данным изображения. Разделите datastore на наборы обучения и валидации, используя 65% данных для обучения и остальное для валидации. Поскольку набор данных довольно мал, сверхподбор кривой является значительной проблемой. Чтобы минимизировать сверхподбор кривой, увеличьте набор обучающих данных с помощью случайных щелчков и масштабирования.
imds = imageDatastore(dataDir, ... "IncludeSubfolders", true, "LabelSource", "foldernames"); aug = imageDataAugmenter("RandXReflection", true, ... "RandYReflection", true, ... "RandXScale", [0.8 1.2], ... "RandYScale", [0.8 1.2]); trainingFraction = 0.65; [trainImds,valImds] = splitEachLabel(imds, trainingFraction); augImdsTrain = augmentedImageDatastore([227 227], trainImds, ... 'DataAugmentation', aug); augImdsVal = augmentedImageDatastore([227 227], valImds);
Создайте опции обучения и обучите сеть. SqueezeNet - это небольшая сеть, которую быстро обучать. Можно обучаться на графическом процессоре или центральном процессоре; этот пример обучает на центральном процессоре.
opts = trainingOptions("adam", ... "InitialLearnRate", 1e-4, ... "MaxEpochs", 30, ... "ValidationData", augImdsVal, ... "Verbose", false,... "Plots", "training-progress", ... "ExecutionEnvironment","cpu",... "MiniBatchSize",128); rng default net = trainNetwork(augImdsTrain, lgraph, opts);
Используйте сеть для классификации изображений в наборе валидации. Чтобы проверить, что сеть достаточно точна при классификации новых данных, постройте матрицу неточностей истинных и предсказанных меток.
figure(); YPred = classify(net,augImdsVal); confusionchart(valImds.Labels,YPred,'ColumnSummary',"column-normalized")
Сеть хорошо классифицирует несколько изображений. Сеть, по-видимому, имеет проблемы с изображениями суши, классифицируя многих как суши, но некоторые как пицца или гамбургер. Сеть не классифицирует какие-либо изображения в класс hot dog.
Чтобы продолжить анализ эффективности сети, вычислите активации для каждого наблюдения в наборе данных на уровне раннего максимального объединения, конечном слое свертки и конечном слое softmax. Вывод активаций как матрицы NxM, где N - количество наблюдений, а M - количество размерностей активации. M является продуктом пространственных и канальных размерностей. Каждая строка является наблюдением, а каждый столбец - размерностью. На слое softmax M = 9, потому что набор данных о пищевых продуктах имеет девять классов. Каждая строка в матрице содержит девять элементов, соответствующих вероятностям того, что наблюдение принадлежит каждому из девяти классов пищи.
earlyLayerName = "pool1"; finalConvLayerName = "conv"; softmaxLayerName = "prob"; pool1Activations = activations(net,... augImdsVal,earlyLayerName,"OutputAs","rows"); finalConvActivations = activations(net,... augImdsVal,finalConvLayerName,"OutputAs","rows"); softmaxActivations = activations(net,... augImdsVal,softmaxLayerName,"OutputAs","rows");
Можно использовать активации softmax, чтобы вычислить классификации изображений, которые, скорее всего, будут неправильными. Задайте неоднозначность классификации как отношение второй по величине вероятности к наибольшей вероятности. Неоднозначность классификации находится между нулем (почти определенная классификация) и 1 (почти так же вероятно, что и второй класс). Неоднозначность около 1 означает, что сеть не уверена в классе, в котором принадлежит конкретное изображение. Эта неопределенность может быть вызвана двумя классами, наблюдения которых кажутся настолько похожими на сеть, что она не может узнать различия между ними. Или может возникнуть высокая неоднозначность, потому что конкретное наблюдение содержит элементы более чем одного класса, поэтому сеть не может решить, какая классификация является правильной. Обратите внимание, что низкая неоднозначность не обязательно подразумевает правильную классификацию; даже если сеть имеет высокую вероятность для класса, классификация все равно может быть неправильной.
[R,RI] = maxk(softmaxActivations,2,2); ambiguity = R(:,2)./R(:,1);
Найдите самые неоднозначные изображения.
[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");
Просмотр наиболее вероятных классов неоднозначных изображений и истинных классов.
classList = unique(valImds.Labels); top10Idx = ambiguityIdx(1:10); top10Ambiguity = ambiguity(1:10); mostLikely = classList(RI(ambiguityIdx,1)); secondLikely = classList(RI(ambiguityIdx,2)); table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10),valImds.Labels(ambiguityIdx(1:10)),... 'VariableNames',["Image #","Ambiguity","Likeliest","Second","True Class"])
ans=10×5 table
Image # Ambiguity Likeliest Second True Class
_______ _________ _________ ____________ ____________
94 0.9879 hamburger pizza hamburger
175 0.96311 hamburger french_fries hot_dog
179 0.94939 pizza hamburger hot_dog
337 0.93426 sushi sashimi sushi
256 0.92972 sushi pizza pizza
297 0.91776 sushi sashimi sashimi
283 0.80407 pizza sushi pizza
27 0.80278 hamburger pizza french_fries
302 0.79283 sashimi sushi sushi
201 0.76034 pizza greek_salad pizza
Сеть предсказывает, что изображение 27, скорее всего, является гамбургером или пиццей. Однако это изображение на самом деле фри. Просмотрите изображение, чтобы увидеть, почему может произойти эта неправильная классификация.
v = 27; figure(); imshow(valImds.Files{v}); title(sprintf("Observation: %i\n" + ... "Actual: %s. Predicted: %s", v, ... string(valImds.Labels(v)), string(YPred(v))), ... 'Interpreter', 'none');
Изображение содержит несколько различных областей, некоторые из которых могут запутать сеть.
Вычислите низкомерное представление данных сети для слоя раннего максимального объединения, последнего сверточного слоя и последнего слоя softmax. Используйте tsne
функция для уменьшения размерности данных активации от M до 2. Чем больше размерность активаций, тем больше требуется расчетов t-SNE. Поэтому расчет для слоя раннего максимального объединения, где активации имеют 200 704 размерности, занимает больше времени, чем для конечного слоя softmax. Установите случайный seed для воспроизводимости результата t-SNE.
rng default
pool1tsne = tsne(pool1Activations);
finalConvtsne = tsne(finalConvActivations);
softmaxtsne = tsne(softmaxActivations);
Метод t-SNE пытается сохранить расстояния так, чтобы точки рядом друг с другом в высокомерном представлении также находились рядом друг с другом в низкомерном представлении. Как показано в матрице неточностей, сеть эффективна при классификации в различные классы. Поэтому семантически похожие (или однотипные) изображения, такие как салат из цезаря и салат из капрезы, находятся рядом друг с другом в пространстве активаций softmax. t-SNE захватывает эту близость в 2-D представлении, которое легче понять и построить, чем девятимерные счета softmax.
Ранние слои, как правило, работают с низкоуровневыми функциями, такие как ребра и цвета. Более глубокие слои научились высокоуровневым функциям с более семантическим смыслом, такое как различие между пиццей и хот-догом. Поэтому активация из ранних слоев не показывает никакой кластеризации по классам. Два изображения, похожие по пикселям (для примера они оба содержат много зеленых пикселей), находятся рядом друг с другом в высоко-размерном пространстве активаций, независимо от их семантического содержимого. Активация из более поздних слоев имеет тенденцию к кластеризации точек из одного и того же класса вместе. Это поведение наиболее выражено в слое softmax и сохраняется в двумерном представлении t-SNE.
Постройте график данных t-SNE для слоя раннего максимального объединения, конечного сверточного слоя и конечного слоя softmax с помощью gscatter
функция. Обратите внимание, что активация раннего максимального объединения не показывает никакой кластеризации между изображениями одного и того же класса. Активации конечного сверточного слоя в некоторой степени кластеризуются по классам, но меньше, чем активации softmax. Различные цвета соответствуют наблюдениям разных классов.
doLegend = 'off'; markerSize = 7; figure; subplot(1,3,1); gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Max pooling activations"); subplot(1,3,2); gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Final conv activations"); subplot(1,3,3); gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Softmax activations");
Создайте больший график активаций softmax, включая легенду, маркирующую каждый класс. Из графика t-SNE можно понять больше о структуре апостериорного распределения вероятностей.
Например, график показывает отдельное, отдельное скопление фри наблюдений, в то время как сасими и суши кластеры разрешены не очень хорошо. Подобно матрице неточностей, график предполагает, что сеть более точна при предсказании во французский класс фри.
numClasses = length(classList); colors = lines(numClasses); h = figure; gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels,colors); l = legend; l.Interpreter = "none"; l.Location = "bestoutside";
Можно также использовать t-SNE, чтобы определить, какие изображения неправильно классифицируются сетью и почему. Неправильные наблюдения часто являются изолированными точками неправильного цвета для их окружающего кластера. Например, неправильно классифицированное изображение гамбургера очень близко к области картофеля фри (зеленая точка, ближайшая к центру оранжевого кластера). Эта точка является наблюдением 99. Округлите это наблюдение на графике t-SNE и отобразите изображение с imshow
.
obs = 99; рисунок (h) держаться on; hs = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs, 2),... 'black','LineWidth',1.5); l.String {end} ='Hamburger'; держаться off; рисунок (); imshow (valImds.Files {obs}); заголовок (sprintf ("Observation: %i\n" + ... "Actual: %s. Predicted: %s", obs, ... string Меток (obs)), string (YPred (obs)), ... 'Interpreter', 'none');
Если изображение содержит несколько видов пищи, сеть может запутаться. В этом случае сеть классифицирует изображение как картофель фри, хотя еда на переднем плане является гамбургером. Фри-фри, видимые на краю изображения, вызывают путаницу.
Аналогично, неоднозначное изображение 27 (показанное ранее в этом примере) имеет несколько областей. Рассмотрим график t-SNE, подсвечивающий неоднозначный аспект этого изображения фри.
obs = 27; рисунок (h) держаться on; h = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs, 2),... 'k','d','LineWidth',1.5); l.String {end} ='French Fries'; держаться off;
Изображение находится не в четко определенном кластере на графике, что указывает на то, что классификация, вероятно, неправильна. Изображение далеко от кластера фри, и близко к кластеру гамбургеров.
Причины неправильной классификации должны быть обеспечены другой информацией, обычно гипотезой, основанной на содержимом изображения. Затем можно протестировать гипотезу с помощью других данных или с помощью инструментов, которые указывают, какие пространственные области изображения важны для классификации сети. Для примеров смотрите occlusionSensitivity
и Grad-CAM раскрывает причину решений по глубокому обучению.
[1] ван дер Маатен, Лоренс и Джеффри Хинтон. «Визуализация данных с использованием t-SNE». Journal of Машинное Обучение Research 9, 2008, pp. 2579-2605.
function downloadExampleFoodImagesData(url, dataDir) % Download the Example Food Image data set, containing 978 images of % different types of food split into 9 classes. % Copyright 2019 The MathWorks, Inc. fileName = "ExampleFoodImageDataset.zip"; fileFullPath = fullfile(dataDir, fileName); % Download the .zip file into a temporary directory. if ~exist(fileFullPath, "file") fprintf("Downloading MathWorks Example Food Image dataset...\n"); fprintf("This can take several minutes to download...\n"); websave(fileFullPath, url); fprintf("Download finished...\n"); else fprintf("Skipping download, file already exists...\n"); end % Unzip the file. % % Check if the file has already been unzipped by checking for the presence % of one of the class directories. exampleFolderFullPath = fullfile(dataDir, "pizza"); if ~exist(exampleFolderFullPath, "dir") fprintf("Unzipping file...\n"); unzip(fileFullPath, dataDir); fprintf("Unzipping finished...\n"); else fprintf("Skipping unzipping, file already unzipped...\n"); end fprintf("Done.\n"); end
activations
| classify
| layerGraph
| occlusionSensitivity
| squeezenet
| trainingOptions
| trainNetwork
| tsne
(Statistics and Machine Learning Toolbox)