Этот пример показывает, как использовать локально интерпретируемые модели-агностические объяснения (LIME) для исследования робастности глубокой сверточной нейронной сети, обученной классифицировать спектрограммы. LIME является методом визуализации, какие части наблюдения способствуют классификационному решению сети. Этот пример использует imageLIME
функция, чтобы понять, какие функции в данных спектрограммы наиболее важны для классификации.
В этом примере вы создаете и обучаете нейронную сеть для классификации четырех видов моделируемых данных временных рядов:
Синусоиды одной частоты
Суперпозиция трёх синусоид
Широкие Гауссовы peaks во временных рядах
Гауссовы импульсы во временных рядах
Чтобы сделать эту задачу более реалистичной, временные ряды включают добавленные смешивающие сигналы: постоянную низкочастотную фоновую синусоиду и большое количество высокочастотного шума. Зашумленные данные временных рядов являются сложной задачей классификации последовательностей. Можно подойти к проблеме, сначала преобразовав данные временных рядов в частотно-временную спектрограмму, чтобы выявить базовые функции в данных временных рядов. Затем можно ввести спектрограммы в сеть классификации изображений.
Сгенерируйте данные временных рядов для четырех классов. Этот пример использует функцию helper generateSpectrogramData
для генерации временных рядов и соответствующих данных спектрограммы. Вспомогательные функции, используемые в этом примере, присоединены как вспомогательные файлы.
numObsPerClass = 500; classes = categorical(["SingleFrequency","ThreeFrequency","Gaussian","Pulse"]); numClasses = length(classes); [noisyTimeSeries,spectrograms,labels] = generateSpectrogramData(numObsPerClass,classes);
Вычислите размер спектрограммы изображений и количество наблюдений.
inputSize = size(spectrograms, [1 2]); numObs = size(spectrograms,4);
Постройте подмножество данных временных рядов с добавлением шума. Поскольку шум имеет сопоставимую амплитуду с сигналом, данные кажутся шумными во временном интервале. Эта функция делает классификацию сложной задачей.
figure numPlots = 12; for i=1:numPlots subplot(3,4,i) plot(noisyTimeSeries(i,:)) title(labels(i)) end
Постройте график частотно-временных спектрограмм зашумленных данных в том же порядке, что и временные ряды графиков. Горизонтальная ось является временем, а вертикальная - частотой.
figure for i=1:12 subplot(3,4,i) imshow(spectrograms(:,:,1,i)) hold on colormap parula title(labels(i)) hold off end
Функции от каждого класса хорошо видны, демонстрируя, почему преобразование из временного интервала в спектрограммные изображения может быть выгодно для этого типа задачи. Для примера, SingleFrequency
класс имеет один пик на основной частоте, видимый как горизонтальная полоса в спектрограмме. Для ThreeFrequency
класс, три частоты видны.
Все классы отображают слабую полосу на низкой частоте (около верхней части изображения), соответствующую синусоиде фона.
Используйте splitlabels
функция для деления данных на обучающие и валидационные данные. Используйте 80% данных для обучения и 20% для валидации.
splitIndices = splitlabels(labels,0.8); trainLabels = labels(splitIndices{1}); trainSpectrograms = spectrograms(:,:,:,splitIndices{1}); valLabels = labels(splitIndices{2}); valSpectrograms = spectrograms(:,:,:,splitIndices{2});
Создайте сверточную нейронную сеть с блоками свертки, нормализации партии . и слоями ReLU.
dropoutProb = 0.2; numFilters = 8; layers = [ imageInputLayer(inputSize) convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(3,'Stride',2,'Padding','same') convolution2dLayer(3,2*numFilters,'Padding','same') batchNormalizationLayer convolution2dLayer(3,4*numFilters,'Padding','same') batchNormalizationLayer reluLayer globalMaxPooling2dLayer dropoutLayer(dropoutProb) fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
Задайте опции для обучения с помощью оптимизатора SGDM. Перетасовывайте данные каждую эпоху, задавая 'Shuffle'
опция для 'every-epoch'
. Отслеживайте процесс обучения путем установки 'Plots'
опция для 'training-progress'
. Чтобы подавить подробный выход, установите 'Verbose'
на false
.
options = trainingOptions('sgdm', ... 'Shuffle','every-epoch', ... 'Plots','training-progress', ... 'Verbose',false, ... 'ValidationData',{valSpectrograms,valLabels});
Обучите сеть классифицировать спектрограммные изображения.
net = trainNetwork(trainSpectrograms,trainLabels,layers,options);
Классифицируйте наблюдения валидации с помощью обученной сети.
predLabels = classify(net,valSpectrograms);
Исследуйте эффективность сети, построив график матрицы неточностей с confusionchart
.
figure confusionchart(valLabels,predLabels,'Normalization','row-normalized')
Сеть точно классифицирует спектрограммы валидации с почти 100% точностью для большинства классов.
Используйте imageLIME
функция, чтобы понять, какие функции в данных изображения являются наиболее важными для классификации.
Метод LIME сегментирует изображение на несколько функции и генерирует синтетические наблюдения путем случайного включения или исключения функций. Каждый пиксель исключённой функции заменяется значением среднего пикселя изображения. Сеть классифицирует эти синтетические наблюдения и использует результирующие счета для предсказанного класса, наряду с наличием или отсутствием функции, как ответы и предикторы, чтобы обучить регрессионую задачу с более простой моделью - в этом примере - регрессионым деревом. Дерево регрессии пытается аппроксимировать поведение сети при одном наблюдении. Он узнает, какие функции важны и значительно влияют на счет класса.
По умолчанию imageLIME
использует сегментацию суперпикселей, чтобы разделить изображение на функции. Эта опция хорошо работает для естественных изображений, но менее эффективна для данных спектрограммы. Можно задать пользовательскую карту сегментации путем установки 'Segmentation'
аргумент имя-значение в числовой массив того же размера, что и изображение, где каждый элемент является целым числом, соответствующим индексу функции, в которой находится пиксель.
Для данных спектрограммы изображения спектрограммы имеют гораздо более мелкие функции в измерении y (частота), чем измерение x (время). Сгенерируйте карту сегментации с 240 сегментами в сетке 40 на 6, чтобы обеспечить более высокое разрешение частоты. Увеличьте сетку до размера изображения при помощи imresize
функция, задающая метод повышающей дискретизации следующим 'nearest'
.
featureIdx = 1:240;
segmentationMap = reshape(featureIdx,6,40)';
segmentationMap = imresize(segmentationMap,inputSize,'nearest');
Постройте график спектрограммы и вычислите карту LIME для двух наблюдений от каждого класса.
obsToShowPerClass = 2; for j=1:obsToShowPerClass figure for i=1:length(classes) idx = find(valLabels == classes(i),obsToShowPerClass); % Read the test image and label. testSpectrogram = valSpectrograms(:,:,:,idx(j)); testLabel = valLabels(idx(j)); % Compute the LIME importance map. map = imageLIME(net,testSpectrogram,testLabel, ... 'NumSamples',4096, ... 'Segmentation',segmentationMap); % Rescale the map to the size of the image. mapRescale = uint8(255*rescale(map)); % Plot the spectrogram image next to the LIME map. subplot(2,2,i) imshow(imtile({testSpectrogram,mapRescale})) title(string(testLabel)) colormap parula end end
Карты LIME показывают, что для большинства классов сеть ориентирована на соответствующие функции для классификации. Для примера, для SingleFrequency
Класс сеть особого внимания на частоте, соответствующей степени спектру синусоиды, а не на ложных фоновых деталях или шуме.
Для SingleFrequency
класс, сеть использует частоту для классификации. Для Pulse
и Gaussian
классы, сеть дополнительно фокусируется на правильной частотной части спектрограммы. Для этих трех классов сеть не путается с частотой фона, видимой около верхней части всех спектрограмм. Эта информация не полезна для различения этих классов (так как она присутствует во всех классах), поэтому сеть игнорирует ее. Напротив, для ThreeFrequency
класс, постоянная фоновая частота релевантна классификационному решению сети. Для этого класса сеть не игнорирует эту частоту, но относится к ней с той же важностью, что и к трем фактическим частотам.
The imageLIME
результаты показывают, что сеть правильно использует peaks в частотно-временных спектрограммах и не путается с ложной фоновой синусоидой для всех классов, кроме ThreeFrequency
класс, где сеть не различает три частоты в сигнале и низкочастотный фон.
imageLIME
| trainNetwork
| pspectrum
(Signal Processing Toolbox)