Просмотрите сетевое поведение Используя tsne

В этом примере показано, как использовать tsne функционируйте, чтобы просмотреть активации в обучившем сеть. Это представление может помочь вам изучить, как работает сеть.

tsne (Statistics and Machine Learning Toolbox) функция в Statistics and Machine Learning Toolbox™ реализует t-distributed стохастического соседа, встраивающего (t-SNE) [1]. Этот метод сопоставляет высоко-размерные данные (такие как сетевые активации в слое) к двум размерностям. Метод использует нелинейную карту, которая пытается сохранить расстояния. При помощи t-SNE, чтобы визуализировать сетевые активации, можно получить понимание того, как сеть отвечает.

Можно использовать t-SNE, чтобы визуализировать, как нейронные сети для глубокого обучения изменяют представление входных данных, когда это проходит через слоя сети. Можно также использовать t-SNE, чтобы найти проблемы с входными данными и понять, какие наблюдения сеть классифицирует неправильно.

Например, t-SNE может уменьшать многомерные активации softmax слоя к 2D представлению с подобной структурой. Высокие кластеры в получившемся графике t-SNE соответствуют классам, которые сеть обычно классифицирует правильно. Визуализация позволяет вам находить точки, которые появляются в неправильном кластере, указывая на наблюдение, что сеть классифицирует неправильно. Наблюдение может быть помечено неправильно, или сетевая сила предсказывает, что наблюдение является экземпляром различного класса, потому что это кажется похожим на другие наблюдения от того класса. Обратите внимание на то, что t-SNE сокращение softmax активаций использует только те активации, не базовые наблюдения.

Загрузите набор данных

Этот пример использует Продовольственный набор данных Изображений В качестве примера, который содержит 978 фотографий еды в девяти классах и составляет приблизительно 77 Мбайт в размере. Загрузите набор данных в свою временную директорию путем вызова downloadExampleFoodImagesData функция помощника; код для этой функции помощника появляется в конце этого примера.

dataDir = fullfile(tempdir, "ExampleFoodImageDataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir, "dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.

Обучите сеть, чтобы классифицировать продовольственные изображения

Измените предварительно обученную сеть SqueezeNet, чтобы классифицировать изображения еды от набора данных. Замените итоговый сверточный слой, который имеет 1 000 фильтров для 1 000 классов ImageNet с новым сверточным слоем, который имеет только девять фильтров. Каждый фильтр соответствует одному типу еды.

lgraph = layerGraph(squeezenet());
lgraph = lgraph.replaceLayer("ClassificationLayer_predictions",...
    classificationLayer("Name", "ClassificationLayer_predictions"));

newConv =  convolution2dLayer([14 14], 9, "Name", "conv", "Padding", "same");
lgraph = lgraph.replaceLayer("conv10", newConv);

Создайте imageDatastore содержа пути к данным изображения. Разделите datastore в наборы обучения и валидации, с помощью 65% данных для обучения и остальных для валидации. Поскольку набор данных справедливо мал, сверхподбор кривой является значительной проблемой. Чтобы минимизировать сверхподбор кривой, увеличьте набор обучающих данных со случайными щелчками и масштабированием.

imds = imageDatastore(dataDir, ...
    "IncludeSubfolders", true, "LabelSource", "foldernames");

aug = imageDataAugmenter("RandXReflection", true, ...
    "RandYReflection", true, ...
    "RandXScale", [0.8 1.2], ...
    "RandYScale", [0.8 1.2]);

trainingFraction = 0.65;
[trainImds,valImds] = splitEachLabel(imds, trainingFraction);

augImdsTrain = augmentedImageDatastore([227 227], trainImds, ...
    'DataAugmentation', aug);
augImdsVal = augmentedImageDatastore([227 227], valImds);

Создайте опции обучения и обучите сеть. SqueezeNet является маленькой сетью, которая быстра, чтобы обучаться. Можно обучаться на графическом процессоре или центральном процессоре; этот пример обучается на центральном процессоре.

opts = trainingOptions("adam", ...
    "InitialLearnRate", 1e-4, ...
    "MaxEpochs", 30, ...
    "ValidationData", augImdsVal, ...
    "Verbose", false,...
    "Plots", "training-progress", ...
    "ExecutionEnvironment","cpu",...
    "MiniBatchSize",128);
rng default
net = trainNetwork(augImdsTrain, lgraph, opts);

Классифицируйте данные о валидации

Используйте сеть, чтобы классифицировать изображения на набор валидации. Чтобы проверить, что сеть довольно точна при классификации новых данных, постройте матрицу беспорядка истины и предсказанных меток.

figure();
YPred = classify(net,augImdsVal);
confusionchart(valImds.Labels,YPred,'ColumnSummary',"column-normalized")

Сеть классифицирует несколько изображений хорошо. Сеть, кажется, испытывает затруднения из-за изображений суши, классифицируя многих как суши, но некоторых как пицца или гамбургер. Сеть не классифицирует изображений в класс хот-дога.

Вычислите активации для нескольких слоев

Чтобы продолжить анализировать производительность сети, вычислите активации для каждого наблюдения в наборе данных на раннем макс. слое объединения, итоговом сверточном слое и финале softmax слой. Выведите активации как матрицу NxM, где N является количеством наблюдений, и M является количеством размерностей активации. M является продуктом размерностей канала и пространственных. Каждая строка является наблюдением, и каждый столбец является размерностью. На softmax слое M = 9, потому что продовольственный набор данных имеет девять классов. Каждая строка в матрице содержит девять элементов, соответствуя вероятностям, что наблюдение принадлежит каждому из девяти классов еды.

earlyLayerName = "pool1";
finalConvLayerName = "conv";
softmaxLayerName = "prob";
pool1Activations = activations(net,...
    augImdsVal,earlyLayerName,"OutputAs","rows");
finalConvActivations = activations(net,...
    augImdsVal,finalConvLayerName,"OutputAs","rows");
softmaxActivations = activations(net,...
    augImdsVal,softmaxLayerName,"OutputAs","rows");

Неоднозначность классификаций

Можно использовать softmax активации, чтобы вычислить классификации изображений, которые, скорее всего, будут неправильными. Задайте неоднозначность классификации как отношение второй по величине вероятности к самой большой вероятности. Неоднозначность классификации между нулем (почти определенная классификация) и 1 (почти настолько же, вероятно, чтобы быть классифицированной к наиболее вероятному классу как второй класс). Неоднозначность приблизительно 1 среднего значения, сеть не уверена в классе, которого принадлежит конкретное изображение. Эта неопределенность может быть вызвана двумя классами, наблюдения которых кажутся столь похожими на сеть, что это не может изучить различия между ними. Или, высокая неоднозначность может произойти, потому что конкретное наблюдение содержит элементы больше чем одного класса, таким образом, сеть не может решить, какая классификация правильна. Обратите внимание на то, что низкая неоднозначность не обязательно подразумевает правильную классификацию; даже если сеть имеет высокую вероятность для класса, классификация может все еще быть неправильной.

[R,RI] = maxk(softmaxActivations,2,2);
ambiguity = R(:,2)./R(:,1);

Найдите самые неоднозначные изображения.

[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");

Просмотрите самые вероятные классы неоднозначных изображений и истинные классы.

classList = unique(valImds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10),valImds.Labels(ambiguityIdx(1:10)),...
    'VariableNames',["Image #","Ambiguity","Likeliest","Second","True Class"])
ans=10×5 table
    Image #    Ambiguity    Likeliest       Second        True Class 
    _______    _________    _________    ____________    ____________

       94        0.9879     hamburger    pizza           hamburger   
      175       0.96311     hamburger    french_fries    hot_dog     
      179       0.94939     pizza        hamburger       hot_dog     
      337       0.93426     sushi        sashimi         sushi       
      256       0.92972     sushi        pizza           pizza       
      297       0.91776     sushi        sashimi         sashimi     
      283       0.80407     pizza        sushi           pizza       
       27       0.80278     hamburger    pizza           french_fries
      302       0.79283     sashimi      sushi           sushi       
      201       0.76034     pizza        greek_salad     pizza       

Сеть предсказывает, что отображают 27, наиболее вероятный гамбургер или пицца. Однако это изображение является на самом деле картофелем фри. Просмотрите изображение, чтобы видеть, почему это misclassification сила происходит.

v = 27;
figure();
imshow(valImds.Files{v});
title(sprintf("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", v, ...
    string(valImds.Labels(v)), string(YPred(v))), ...
    'Interpreter', 'none');

Изображение содержит несколько отличных областей, некоторые из которых могут перепутать сеть.

Вычислите 2D Представления Данных Используя t-SNE

Вычислите низко-размерное представление сетевых данных для раннего макс. слоя объединения, итогового сверточного слоя и финала softmax слой. Используйте tsne функция, чтобы уменьшать размерность данных об активации от M до 2. Чем больше размерность активаций, тем дольше t-SNE расчет берет. Поэтому расчет для раннего макс. слоя объединения, где активации имеют 200 704 размерности, занимает больше времени, чем у финала softmax слой. Установите случайный seed для воспроизводимости результата t-SNE.

rng default
pool1tsne = tsne(pool1Activations);
finalConvtsne = tsne(finalConvActivations);
softmaxtsne = tsne(softmaxActivations);

Сравните сетевое поведение для ранних и более поздних слоев

t-SNE метод пытается сохранить расстояния так, чтобы точки друг около друга в высоко-размерном представлении также нашлись друг около друга в низко-размерном представлении. Как показано в матрице беспорядка, сеть является эффективной при классификации в различные классы. Поэтому изображения, которые семантически подобны (или того же типа), таковы как салат Цезарь и caprese салат, находятся друг около друга на softmax пробеле активаций. t-SNE получает эту близость в 2D представлении, которое легче изучить и построить, чем девятимерные softmax баллы.

Ранние слои имеют тенденцию работать с низкоуровневыми функциями, такими как ребра и цвета. Более глубокие слои изучили высокоуровневые функции с большим количеством семантического значения, такие как различие между пиццей и хот-догом. Поэтому активации от ранних слоев не показывают кластеризации классом. Два изображения, которые являются подобным pixelwise (например, они оба содержат много зеленых пикселей) находятся друг около друга в высоком мерном пространстве активаций, независимо от их семантического содержимого. Активации от более поздних слоев стремятся к точкам накопления от того же класса вместе. Это поведение является самым явным на softmax слое и сохраняется в двумерном t-SNE представлении.

Отобразите t-SNE данные на графике для раннего макс. слоя объединения, итогового сверточного слоя и финала softmax слой с помощью gscatter функция. Заметьте, что ранние макс. активации объединения не показывают кластеризации между изображениями того же класса. Активации итогового сверточного слоя кластеризируются классом в некоторой степени, но меньше, чем softmax активации. Различные цвета соответствуют наблюдениям за различными классами.

doLegend = 'off';
markerSize = 7;
figure;

subplot(1,3,1);
gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Max pooling activations");

subplot(1,3,2);
gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Final conv activations");

subplot(1,3,3);
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Softmax activations");

Исследуйте Наблюдения в Графике t-SNE

Создайте больший график softmax активаций, включая легенду, маркирующую каждый класс. Из графика t-SNE можно понять больше о структуре распределения апостериорной вероятности.

Например, график показывает отличный, отдельный кластер наблюдений картофеля фри, тогда как кластеры сашими и суши не разрешены очень хорошо. Подобно матрице беспорядка график предполагает, что сеть более точна при предсказании в класс картофеля фри.

numClasses = length(classList);
colors = lines(numClasses);
h = figure;
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels,colors);

l = legend;
l.Interpreter = "none";
l.Location = "bestoutside";

Можно также использовать t-SNE, чтобы определить, какие изображения неправильно классифицируются сетью и почему. Неправильные наблюдения часто являются изолируемыми точками неправильного цвета для их окружающего кластера. Например, неправильно классифицированное изображение гамбургера очень около области картофеля фри (зеленая точка, самая близкая центр оранжевого кластера). Эта точка является наблюдением 99. Окружите это наблюдение относительно графика t-SNE и отобразите изображение с imshow.

obs = 99;
фигура (h)
содержание on;
hs = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs, 2), ...
    'black','LineWidth',1.5);
l.String {конец} = 'Hamburger';
содержание off;
фигура;
imshow (valImds.Files {obs});
заголовок (sprintf ("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", obs, ...
    строка (valImds.Labels (obs)), строка (YPred (obs))), ...
    'Interpreter', 'none');

Если изображение содержит несколько типов еды, сеть может запутаться. В этом случае сеть классифицирует изображение как картофель фри даже при том, что еда на переднем плане является гамбургером. Картофель фри, видимый в ребре изображения, вызывает беспорядок.

Точно так же неоднозначное изображение 27 (показанный ранее в примере) имеет несколько областей. Исследуйте график t-SNE, подсвечивающий неоднозначный аспект этого изображения картофеля фри.

obs = 27;
фигура (h)
содержание on;
h = рассеяние (softmaxtsne (obs, 1), softmaxtsne (obs, 2), ...
    'k','d','LineWidth',1.5);
l.String {конец} = 'French Fries';
содержание off;

Изображение не находится в четко определенном кластере в графике, который указывает, что классификация является, вероятно, неправильной. Изображение далеко от кластера картофеля фри, и близко к кластеру гамбургера.

Почему из misclassification должен быть обеспечен другой информацией, обычно гипотеза на основе содержимого изображения. Можно затем протестировать гипотезу с помощью других данных, или с помощью инструментов, которые указывают, какие пространственные области изображения важны для сетевой классификации. Для примеров смотрите occlusionSensitivity и CAM градиента показывает почему позади решений глубокого обучения.

Ссылки

[1] ван дер Маатен, Лоренс и Джеффри Хинтон. "Визуализируя Данные с помощью t-SNE". Журнал Исследования Машинного обучения 9, 2008, стр 2579–2605.

Функция помощника

function downloadExampleFoodImagesData(url, dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir, fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath, "file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath, url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir, "pizza");
if ~exist(exampleFolderFullPath, "dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath, dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

Смотрите также

| | | | | | | (Statistics and Machine Learning Toolbox)

Похожие темы