В этом примере показано, как использовать локально поддающиеся толкованию объяснения модели агностические (LIME), чтобы исследовать робастность глубокой сверточной нейронной сети, обученной классифицировать спектрограммы. LIME является методом для визуализации, которую части наблюдения вносят в решение классификации о сети. Этот пример использует imageLIME
функция, чтобы понять, какие функции в данных о спектрограмме являются самыми важными для классификации.
В этом примере вы создаете и обучаете нейронную сеть, чтобы классифицировать четыре вида симулированных данных временных рядов:
Sine wave одной частоты
Суперпозиция трех синусоид
Широкий Гауссов peaks во временных рядах
Гауссовы импульсы во временных рядах
Чтобы сделать эту проблему более реалистичной, временные ряды включают добавленные сигналы соединения: постоянная низкочастотная фоновая синусоида и большая сумма высокочастотного шума. Шумные данные временных рядов являются сложной проблемой классификации последовательностей. Можно приблизиться к проблеме первым преобразованием данных временных рядов в спектрограмму частоты времени, чтобы показать базовые функции в данных временных рядов. Можно затем ввести спектрограммы к сети классификации изображений.
Сгенерируйте данные временных рядов для этих четырех классов. Этот пример использует функцию помощника generateSpectrogramData
сгенерировать временные ряды и соответствующие данные о спектрограмме. Функции помощника, используемые в этом примере, присоединяются как вспомогательные файлы.
numObsPerClass = 500; classes = categorical(["SingleFrequency","ThreeFrequency","Gaussian","Pulse"]); numClasses = length(classes); [noisyTimeSeries,spectrograms,labels] = generateSpectrogramData(numObsPerClass,classes);
Вычислите размер изображений спектрограммы и количество наблюдений.
inputSize = size(spectrograms, [1 2]); numObs = size(spectrograms,4);
Постройте подмножество данных временных рядов с добавленным шумом. Поскольку шум имеет сопоставимую амплитуду к сигналу, данные кажутся шумными во временном интервале. Эта функция делает классификацию сложной проблемой.
figure numPlots = 12; for i=1:numPlots subplot(3,4,i) plot(noisyTimeSeries(i,:)) title(labels(i)) end
Постройте спектрограммы частоты времени зашумленных данных в том же порядке как графики временных рядов. Горизонтальная ось время, и вертикальная ось является частотой.
figure for i=1:12 subplot(3,4,i) imshow(spectrograms(:,:,1,i)) hold on colormap parula title(labels(i)) hold off end
Функции от каждого класса ясно отображаются, демонстрируя, почему преобразование от временного интервала до изображений спектрограммы может быть выгодным для этого типа проблемы. Например, SingleFrequency
класс имеет один пик на основной частоте, видимой как горизонтальная планка в спектрограмме. Для ThreeFrequency
класс, эти три частоты отображаются.
Все классы отображают слабую полосу в низкой частоте (около верхней части изображения), соответствуя фоновой синусоиде.
Используйте splitlabels
функционируйте, чтобы разделить данные на данные об обучении и валидации. Используйте 80% данных для обучения и 20% для валидации.
splitIndices = splitlabels(labels,0.8); trainLabels = labels(splitIndices{1}); trainSpectrograms = spectrograms(:,:,:,splitIndices{1}); valLabels = labels(splitIndices{2}); valSpectrograms = spectrograms(:,:,:,splitIndices{2});
Создайте сверточную нейронную сеть с блоками свертки, нормализации партии. и слоев ReLU.
dropoutProb = 0.2; numFilters = 8; layers = [ imageInputLayer(inputSize) convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numFilters,'Padding','same') batchNormalizationLayer convolution2dLayer(3,4*numFilters,'Padding','same') batchNormalizationLayer reluLayer globalMaxPooling2dLayer dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции для обучения с помощью оптимизатора SGDM. Переставьте данные каждая эпоха путем установки 'Shuffle'
опция к 'every-epoch'
. Контролируйте процесс обучения путем установки 'Plots'
опция к 'training-progress'
. Чтобы подавить многословный выход, установите 'Verbose'
к false
.
options = trainingOptions('sgdm', ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{valSpectrograms,valLabels});
Обучите сеть, чтобы классифицировать изображения спектрограммы.
net = trainNetwork(trainSpectrograms,trainLabels,layers,options);
Классифицируйте наблюдения валидации с помощью обучившего сеть.
predLabels = classify(net,valSpectrograms);
Исследуйте производительность сети путем графического вывода матрицы беспорядка с confusionchart
.
figure confusionchart(valLabels,predLabels,'Normalization','row-normalized')
Сеть точно классифицирует спектрограммы валидации, с близко к 100%-й точности для большинства классов.
Используйте imageLIME
функция, чтобы понять, какие функции в данных изображения являются самыми важными для классификации.
Метод LIME сегментирует изображение на несколько функций и генерирует синтетические наблюдения случайным образом включая или, исключая функции. Каждый пиксель в исключенной функции заменяется значением среднего пикселя изображения. Сеть классифицирует эти синтетические наблюдения и использует получившуюся музыку к предсказанному классу, наряду с присутствием или отсутствием функции, как ответы и предикторы, чтобы обучить проблему регрессии с более простой моделью — в этом примере, дереве регрессии. Дерево регрессии пытается аппроксимировать поведение сети на одном наблюдении. Это учится, какие функции важны и значительно влияют на счет класса.
По умолчанию, imageLIME
суперпиксельная сегментация использования, чтобы разделить изображение на функции. Эта опция работает хорошо на естественные изображения, но является менее эффективной для данных о спектрограмме. Можно задать пользовательскую карту сегментации путем установки 'Segmentation'
аргумент значения имени к числовому массиву тот же размер как изображение, где каждым элементом является целое число, соответствующее индексу функции, в которой находится пиксель.
Для данных о спектрограмме изображения спектрограммы имеют намного более прекрасные функции в y-размерности (частота), чем x-размерность (время). Сгенерируйте карту сегментации с 240 сегментами, в 40 6 сетка, чтобы обеспечить более высокое разрешение частоты. Сверхдискретизируйте сетку к размеру изображения при помощи imresize
функция, задавая метод повышающей дискретизации как 'nearest'
.
featureIdx = 1:240;
segmentationMap = reshape(featureIdx,6,40)';
segmentationMap = imresize(segmentationMap,inputSize,'nearest');
Постройте спектрограмму и вычислите карту LIME для двух наблюдений от каждого класса.
obsToShowPerClass = 2; for j=1:obsToShowPerClass figure for i=1:length(classes) idx = find(valLabels == classes(i),obsToShowPerClass); % Read the test image and label. testSpectrogram = valSpectrograms(:,:,:,idx(j)); testLabel = valLabels(idx(j)); % Compute the LIME importance map. map = imageLIME(net,testSpectrogram,testLabel, ... 'NumSamples',4096, ... 'Segmentation',segmentationMap); % Rescale the map to the size of the image. mapRescale = uint8(255*rescale(map)); % Plot the spectrogram image next to the LIME map. subplot(2,2,i) imshow(imtile({testSpectrogram,mapRescale})) title(string(testLabel)) colormap parula end end
Карты LIME демонстрируют, что для большинства классов, сеть фокусируется на соответствующих функциях классификации. Например, для SingleFrequency
класс, сеть фокусируется на частоте, соответствующей спектру мощности синусоиды а не на побочных фоновых деталях или шуме.
Для SingleFrequency
класс, сеть использует частоту, чтобы классифицировать. Для Pulse
и Gaussian
классы, сеть дополнительно фокусируется на правильной части частоты спектрограммы. Для этих трех классов сеть не перепутана фоновой частотой, видимой около верхней части всех спектрограмм. Эта информация не полезна для различения этих классов (когда это присутствует во всех классах), таким образом, сеть игнорирует его. В отличие от этого для ThreeFrequency
класс, постоянная фоновая частота относится к решению классификации о сети. Для этого класса сеть не игнорирует эту частоту, но обрабатывает его с подобной важностью для трех фактических частот.
imageLIME
результаты демонстрируют, что сеть правильно использует peaks в спектрограммах частоты времени и не перепутана побочной фоновой синусоидой для всех классов за исключением ThreeFrequency
класс, где сеть не различает эти три частоты в сигнале и низкочастотном фоне.
imageLIME
| trainNetwork
| pspectrum
(Signal Processing Toolbox)