Просмотр поведения сети с помощью tsne

В этом примере показано, как использовать tsne функция для просмотра активаций в обученной сети. Это представление может помочь вам понять, как работает сеть.

The tsne (Statistics and Machine Learning Toolbox) функция в Statistics and Machine Learning Toolbox™ реализует t-распределенное стохастическое соседнее встраивание (t-SNE) [1]. Этот метод сопоставляет высоко-размерные данные (такие как активация сети в слое) с двумя размерностями. Метод использует нелинейную карту, которая пытается сохранить расстояния. Используя t-SNE, чтобы визуализировать активацию сети, можно получить понимание того, как сеть реагирует.

Можно использовать t-SNE, чтобы визуализировать, как нейронным сетям для глубокого обучения изменяете представление входных данных, когда оно проходит через слои сети. Можно также использовать t-SNE, чтобы найти проблемы с входными данными и понять, какие наблюдения сеть классифицирует неправильно.

Для примера t-SNE может уменьшить многомерные активации слоя softmax до 2-D представления с подобной структурой. Плотные кластеры на полученном графике t-SNE соответствуют классам, которые сеть обычно классифицирует правильно. Визуализация позволяет найти точки, которые появляются в неправильном кластере, указывая на наблюдение, которое сеть классифицирует неправильно. Наблюдение может быть неправильно помечено, или сеть может предсказать, что наблюдение является образцом другого класса, потому что оно похоже на другие наблюдения этого класса. Обратите внимание, что уменьшение t-SNE активаций softmax использует только эти активации, а не базовые наблюдения.

Загрузка набора данных

Этот пример использует набор данных Example Food Images, который содержит 978 фотографий пищи в девяти классах и имеет размер приблизительно 77 МБ. Загрузите набор данных во временную директорию путем вызова downloadExampleFoodImagesData вспомогательная функция; код для этой вспомогательной функции появится в конце этого примера.

dataDir = fullfile(tempdir, "ExampleFoodImageDataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir, "dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.

Обучите сеть классифицировать изображения продуктов питания

Измените предварительно обученную сеть SqueezeNet, чтобы классифицировать изображения пищи из набора данных. Замените конечный сверточный слой, который имеет 1000 фильтров для 1000 классов ImageNet, новым сверточным слоем, который имеет только девять фильтров. Каждый фильтр соответствует одному типу пищи.

lgraph = layerGraph(squeezenet());
lgraph = lgraph.replaceLayer("ClassificationLayer_predictions",...
    classificationLayer("Name", "ClassificationLayer_predictions"));

newConv =  convolution2dLayer([14 14], 9, "Name", "conv", "Padding", "same");
lgraph = lgraph.replaceLayer("conv10", newConv);

Создайте imageDatastore содержащие пути к данным изображения. Разделите datastore на наборы обучения и валидации, используя 65% данных для обучения и остальное для валидации. Поскольку набор данных довольно мал, сверхподбор кривой является значительной проблемой. Чтобы минимизировать сверхподбор кривой, увеличьте набор обучающих данных с помощью случайных щелчков и масштабирования.

imds = imageDatastore(dataDir, ...
    "IncludeSubfolders", true, "LabelSource", "foldernames");

aug = imageDataAugmenter("RandXReflection", true, ...
    "RandYReflection", true, ...
    "RandXScale", [0.8 1.2], ...
    "RandYScale", [0.8 1.2]);

trainingFraction = 0.65;
[trainImds,valImds] = splitEachLabel(imds, trainingFraction);

augImdsTrain = augmentedImageDatastore([227 227], trainImds, ...
    'DataAugmentation', aug);
augImdsVal = augmentedImageDatastore([227 227], valImds);

Создайте опции обучения и обучите сеть. SqueezeNet - это небольшая сеть, которую быстро обучать. Можно обучаться на графическом процессоре или центральном процессоре; этот пример обучает на центральном процессоре.

opts = trainingOptions("adam", ...
    "InitialLearnRate", 1e-4, ...
    "MaxEpochs", 30, ...
    "ValidationData", augImdsVal, ...
    "Verbose", false,...
    "Plots", "training-progress", ...
    "ExecutionEnvironment","cpu",...
    "MiniBatchSize",128);
rng default
net = trainNetwork(augImdsTrain, lgraph, opts);

Классификация данных валидации

Используйте сеть для классификации изображений в наборе валидации. Чтобы проверить, что сеть достаточно точна при классификации новых данных, постройте матрицу неточностей истинных и предсказанных меток.

figure();
YPred = classify(net,augImdsVal);
confusionchart(valImds.Labels,YPred,'ColumnSummary',"column-normalized")

Сеть хорошо классифицирует несколько изображений. Сеть, по-видимому, имеет проблемы с изображениями суши, классифицируя многих как суши, но некоторые как пицца или гамбургер. Сеть не классифицирует какие-либо изображения в класс hot dog.

Вычисление активаций для нескольких слоев

Чтобы продолжить анализ эффективности сети, вычислите активации для каждого наблюдения в наборе данных на уровне раннего максимального объединения, конечном слое свертки и конечном слое softmax. Вывод активаций как матрицы NxM, где N - количество наблюдений, а M - количество размерностей активации. M является продуктом пространственных и канальных размерностей. Каждая строка является наблюдением, а каждый столбец - размерностью. На слое softmax M = 9, потому что набор данных о пищевых продуктах имеет девять классов. Каждая строка в матрице содержит девять элементов, соответствующих вероятностям того, что наблюдение принадлежит каждому из девяти классов пищи.

earlyLayerName = "pool1";
finalConvLayerName = "conv";
softmaxLayerName = "prob";
pool1Activations = activations(net,...
    augImdsVal,earlyLayerName,"OutputAs","rows");
finalConvActivations = activations(net,...
    augImdsVal,finalConvLayerName,"OutputAs","rows");
softmaxActivations = activations(net,...
    augImdsVal,softmaxLayerName,"OutputAs","rows");

Неоднозначность классификаций

Можно использовать активации softmax, чтобы вычислить классификации изображений, которые, скорее всего, будут неправильными. Задайте неоднозначность классификации как отношение второй по величине вероятности к наибольшей вероятности. Неоднозначность классификации находится между нулем (почти определенная классификация) и 1 (почти так же вероятно, что и второй класс). Неоднозначность около 1 означает, что сеть не уверена в классе, в котором принадлежит конкретное изображение. Эта неопределенность может быть вызвана двумя классами, наблюдения которых кажутся настолько похожими на сеть, что она не может узнать различия между ними. Или может возникнуть высокая неоднозначность, потому что конкретное наблюдение содержит элементы более чем одного класса, поэтому сеть не может решить, какая классификация является правильной. Обратите внимание, что низкая неоднозначность не обязательно подразумевает правильную классификацию; даже если сеть имеет высокую вероятность для класса, классификация все равно может быть неправильной.

[R,RI] = maxk(softmaxActivations,2,2);
ambiguity = R(:,2)./R(:,1);

Найдите самые неоднозначные изображения.

[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");

Просмотр наиболее вероятных классов неоднозначных изображений и истинных классов.

classList = unique(valImds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10),valImds.Labels(ambiguityIdx(1:10)),...
    'VariableNames',["Image #","Ambiguity","Likeliest","Second","True Class"])
ans=10×5 table
    Image #    Ambiguity    Likeliest       Second        True Class 
    _______    _________    _________    ____________    ____________

       94        0.9879     hamburger    pizza           hamburger   
      175       0.96311     hamburger    french_fries    hot_dog     
      179       0.94939     pizza        hamburger       hot_dog     
      337       0.93426     sushi        sashimi         sushi       
      256       0.92972     sushi        pizza           pizza       
      297       0.91776     sushi        sashimi         sashimi     
      283       0.80407     pizza        sushi           pizza       
       27       0.80278     hamburger    pizza           french_fries
      302       0.79283     sashimi      sushi           sushi       
      201       0.76034     pizza        greek_salad     pizza       

Сеть предсказывает, что изображение 27, скорее всего, является гамбургером или пиццей. Однако это изображение на самом деле фри. Просмотрите изображение, чтобы увидеть, почему может произойти эта неправильная классификация.

v = 27;
figure();
imshow(valImds.Files{v});
title(sprintf("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", v, ...
    string(valImds.Labels(v)), string(YPred(v))), ...
    'Interpreter', 'none');

Изображение содержит несколько различных областей, некоторые из которых могут запутать сеть.

Вычисление 2-D представлений данных с использованием t-SNE

Вычислите низкомерное представление данных сети для слоя раннего максимального объединения, последнего сверточного слоя и последнего слоя softmax. Используйте tsne функция для уменьшения размерности данных активации от M до 2. Чем больше размерность активаций, тем больше требуется расчетов t-SNE. Поэтому расчет для слоя раннего максимального объединения, где активации имеют 200 704 размерности, занимает больше времени, чем для конечного слоя softmax. Установите случайный seed для воспроизводимости результата t-SNE.

rng default
pool1tsne = tsne(pool1Activations);
finalConvtsne = tsne(finalConvActivations);
softmaxtsne = tsne(softmaxActivations);

Сравнение поведения сети для ранних и более поздних слоев

Метод t-SNE пытается сохранить расстояния так, чтобы точки рядом друг с другом в высокомерном представлении также находились рядом друг с другом в низкомерном представлении. Как показано в матрице неточностей, сеть эффективна при классификации в различные классы. Поэтому семантически похожие (или однотипные) изображения, такие как салат из цезаря и салат из капрезы, находятся рядом друг с другом в пространстве активаций softmax. t-SNE захватывает эту близость в 2-D представлении, которое легче понять и построить, чем девятимерные счета softmax.

Ранние слои, как правило, работают с низкоуровневыми функциями, такие как ребра и цвета. Более глубокие слои научились высокоуровневым функциям с более семантическим смыслом, такое как различие между пиццей и хот-догом. Поэтому активация из ранних слоев не показывает никакой кластеризации по классам. Два изображения, похожие по пикселям (для примера они оба содержат много зеленых пикселей), находятся рядом друг с другом в высоко-размерном пространстве активаций, независимо от их семантического содержимого. Активация из более поздних слоев имеет тенденцию к кластеризации точек из одного и того же класса вместе. Это поведение наиболее выражено в слое softmax и сохраняется в двумерном представлении t-SNE.

Постройте график данных t-SNE для слоя раннего максимального объединения, конечного сверточного слоя и конечного слоя softmax с помощью gscatter функция. Обратите внимание, что активация раннего максимального объединения не показывает никакой кластеризации между изображениями одного и того же класса. Активации конечного сверточного слоя в некоторой степени кластеризуются по классам, но меньше, чем активации softmax. Различные цвета соответствуют наблюдениям разных классов.

doLegend = 'off';
markerSize = 7;
figure;

subplot(1,3,1);
gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Max pooling activations");

subplot(1,3,2);
gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Final conv activations");

subplot(1,3,3);
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Softmax activations");

Исследование наблюдений на графике t-SNE

Создайте больший график активаций softmax, включая легенду, маркирующую каждый класс. Из графика t-SNE можно понять больше о структуре апостериорного распределения вероятностей.

Например, график показывает отдельное, отдельное скопление фри наблюдений, в то время как сасими и суши кластеры разрешены не очень хорошо. Подобно матрице неточностей, график предполагает, что сеть более точна при предсказании во французский класс фри.

numClasses = length(classList);
colors = lines(numClasses);
h = figure;
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels,colors);

l = legend;
l.Interpreter = "none";
l.Location = "bestoutside";

Можно также использовать t-SNE, чтобы определить, какие изображения неправильно классифицируются сетью и почему. Неправильные наблюдения часто являются изолированными точками неправильного цвета для их окружающего кластера. Например, неправильно классифицированное изображение гамбургера очень близко к области картофеля фри (зеленая точка, ближайшая к центру оранжевого кластера). Эта точка является наблюдением 99. Округлите это наблюдение на графике t-SNE и отобразите изображение с imshow.

obs = 99;
рисунок (h)
держаться on;
hs = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs,  2),...
    'black','LineWidth',1.5);
l.String {end}  ='Hamburger';
держаться off;
рисунок ();
imshow (valImds.Files {obs});
заголовок (sprintf ("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", obs, ...
    string Меток (obs)), string (YPred (obs)), ...
    'Interpreter', 'none');

Если изображение содержит несколько видов пищи, сеть может запутаться. В этом случае сеть классифицирует изображение как картофель фри, хотя еда на переднем плане является гамбургером. Фри-фри, видимые на краю изображения, вызывают путаницу.

Аналогично, неоднозначное изображение 27 (показанное ранее в этом примере) имеет несколько областей. Рассмотрим график t-SNE, подсвечивающий неоднозначный аспект этого изображения фри.

obs = 27;
рисунок (h)
держаться on;
h = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs,  2),...
    'k','d','LineWidth',1.5);
l.String {end}  ='French Fries';
держаться off;

Изображение находится не в четко определенном кластере на графике, что указывает на то, что классификация, вероятно, неправильна. Изображение далеко от кластера фри, и близко к кластеру гамбургеров.

Причины неправильной классификации должны быть обеспечены другой информацией, обычно гипотезой, основанной на содержимом изображения. Затем можно протестировать гипотезу с помощью других данных или с помощью инструментов, которые указывают, какие пространственные области изображения важны для классификации сети. Для примеров смотрите occlusionSensitivity и Grad-CAM раскрывает причину решений по глубокому обучению.

Ссылки

[1] ван дер Маатен, Лоренс и Джеффри Хинтон. «Визуализация данных с использованием t-SNE». Journal of Машинное Обучение Research 9, 2008, pp. 2579-2605.

Функция помощника

function downloadExampleFoodImagesData(url, dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir, fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath, "file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath, url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir, "pizza");
if ~exist(exampleFolderFullPath, "dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath, dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

См. также

| | | | | | | (Statistics and Machine Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте