Исследование классификаций спектрограмм с использованием LIME

Этот пример показывает, как использовать локально интерпретируемые модели-агностические объяснения (LIME) для исследования робастности глубокой сверточной нейронной сети, обученной классифицировать спектрограммы. LIME является методом визуализации, какие части наблюдения способствуют классификационному решению сети. Этот пример использует imageLIME функция, чтобы понять, какие функции в данных спектрограммы наиболее важны для классификации.

В этом примере вы создаете и обучаете нейронную сеть для классификации четырех видов моделируемых данных временных рядов:

  • Синусоиды одной частоты

  • Суперпозиция трёх синусоид

  • Широкие Гауссовы peaks во временных рядах

  • Гауссовы импульсы во временных рядах

Чтобы сделать эту задачу более реалистичной, временные ряды включают добавленные смешивающие сигналы: постоянную низкочастотную фоновую синусоиду и большое количество высокочастотного шума. Зашумленные данные временных рядов являются сложной задачей классификации последовательностей. Можно подойти к проблеме, сначала преобразовав данные временных рядов в частотно-временную спектрограмму, чтобы выявить базовые функции в данных временных рядов. Затем можно ввести спектрограммы в сеть классификации изображений.

Сгенерируйте формы волны и спектрограммы

Сгенерируйте данные временных рядов для четырех классов. Этот пример использует функцию helper generateSpectrogramData для генерации временных рядов и соответствующих данных спектрограммы. Вспомогательные функции, используемые в этом примере, присоединены как вспомогательные файлы.

numObsPerClass = 500;

classes = categorical(["SingleFrequency","ThreeFrequency","Gaussian","Pulse"]);
numClasses = length(classes);

[noisyTimeSeries,spectrograms,labels] = generateSpectrogramData(numObsPerClass,classes);

Вычислите размер спектрограммы изображений и количество наблюдений.

inputSize = size(spectrograms, [1 2]);
numObs = size(spectrograms,4);

Построение сгенерированных данных

Постройте подмножество данных временных рядов с добавлением шума. Поскольку шум имеет сопоставимую амплитуду с сигналом, данные кажутся шумными во временном интервале. Эта функция делает классификацию сложной задачей.

figure
numPlots = 12;

for i=1:numPlots
    subplot(3,4,i)
    plot(noisyTimeSeries(i,:))
    title(labels(i))
end

Постройте график частотно-временных спектрограмм зашумленных данных в том же порядке, что и временные ряды графиков. Горизонтальная ось является временем, а вертикальная - частотой.

figure
for i=1:12
    subplot(3,4,i)
    imshow(spectrograms(:,:,1,i))
    hold on
    colormap parula
    title(labels(i))
    hold off
end

Функции от каждого класса хорошо видны, демонстрируя, почему преобразование из временного интервала в спектрограммные изображения может быть выгодно для этого типа задачи. Для примера, SingleFrequency класс имеет один пик на основной частоте, видимый как горизонтальная полоса в спектрограмме. Для ThreeFrequency класс, три частоты видны.

Все классы отображают слабую полосу на низкой частоте (около верхней части изображения), соответствующую синусоиде фона.

Разделение данных

Используйте splitlabels функция для деления данных на обучающие и валидационные данные. Используйте 80% данных для обучения и 20% для валидации.

splitIndices = splitlabels(labels,0.8);

trainLabels = labels(splitIndices{1});
trainSpectrograms = spectrograms(:,:,:,splitIndices{1});

valLabels = labels(splitIndices{2});
valSpectrograms = spectrograms(:,:,:,splitIndices{2});

Определите архитектуру нейронной сети

Создайте сверточную нейронную сеть с блоками свертки, нормализации партии . и слоями ReLU.

dropoutProb = 0.2;
numFilters = 8;

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(3,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(3,'Stride',2,'Padding','same')
    
    convolution2dLayer(3,2*numFilters,'Padding','same')
    batchNormalizationLayer
    
    convolution2dLayer(3,4*numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    globalMaxPooling2dLayer
    
    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения

Задайте опции для обучения с помощью оптимизатора SGDM. Перетасовывайте данные каждую эпоху, задавая 'Shuffle' опция для 'every-epoch'. Отслеживайте процесс обучения путем установки 'Plots' опция для 'training-progress'. Чтобы подавить подробный выход, установите 'Verbose' на false.

options = trainingOptions('sgdm', ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{valSpectrograms,valLabels});

Обучите сеть

Обучите сеть классифицировать спектрограммные изображения.

net = trainNetwork(trainSpectrograms,trainLabels,layers,options);

Точность

Классифицируйте наблюдения валидации с помощью обученной сети.

predLabels = classify(net,valSpectrograms);

Исследуйте эффективность сети, построив график матрицы неточностей с confusionchart.

figure
confusionchart(valLabels,predLabels,'Normalization','row-normalized')

Сеть точно классифицирует спектрограммы валидации с почти 100% точностью для большинства классов.

Исследуйте сетевые предсказания

Используйте imageLIME функция, чтобы понять, какие функции в данных изображения являются наиболее важными для классификации.

Метод LIME сегментирует изображение на несколько функции и генерирует синтетические наблюдения путем случайного включения или исключения функций. Каждый пиксель исключённой функции заменяется значением среднего пикселя изображения. Сеть классифицирует эти синтетические наблюдения и использует результирующие счета для предсказанного класса, наряду с наличием или отсутствием функции, как ответы и предикторы, чтобы обучить регрессионую задачу с более простой моделью - в этом примере - регрессионым деревом. Дерево регрессии пытается аппроксимировать поведение сети при одном наблюдении. Он узнает, какие функции важны и значительно влияют на счет класса.

Задайте пользовательскую карту сегментации

По умолчанию imageLIME использует сегментацию суперпикселей, чтобы разделить изображение на функции. Эта опция хорошо работает для естественных изображений, но менее эффективна для данных спектрограммы. Можно задать пользовательскую карту сегментации путем установки 'Segmentation' аргумент имя-значение в числовой массив того же размера, что и изображение, где каждый элемент является целым числом, соответствующим индексу функции, в которой находится пиксель.

Для данных спектрограммы изображения спектрограммы имеют гораздо более мелкие функции в измерении y (частота), чем измерение x (время). Сгенерируйте карту сегментации с 240 сегментами в сетке 40 на 6, чтобы обеспечить более высокое разрешение частоты. Увеличьте сетку до размера изображения при помощи imresize функция, задающая метод повышающей дискретизации следующим 'nearest'.

featureIdx = 1:240;
segmentationMap = reshape(featureIdx,6,40)';
segmentationMap = imresize(segmentationMap,inputSize,'nearest');

Вычислительная карта LIME

Постройте график спектрограммы и вычислите карту LIME для двух наблюдений от каждого класса.

obsToShowPerClass = 2;

for j=1:obsToShowPerClass
    figure

    for i=1:length(classes)

        idx = find(valLabels == classes(i),obsToShowPerClass);

        % Read the test image and label.
        testSpectrogram = valSpectrograms(:,:,:,idx(j));
        testLabel = valLabels(idx(j));

        % Compute the LIME importance map.
        map = imageLIME(net,testSpectrogram,testLabel, ...
            'NumSamples',4096, ...
            'Segmentation',segmentationMap);

        % Rescale the map to the size of the image.
        mapRescale = uint8(255*rescale(map));

        % Plot the spectrogram image next to the LIME map.
        subplot(2,2,i)
        imshow(imtile({testSpectrogram,mapRescale}))
        title(string(testLabel))
        colormap parula
    end
end

Карты LIME показывают, что для большинства классов сеть ориентирована на соответствующие функции для классификации. Для примера, для SingleFrequency Класс сеть особого внимания на частоте, соответствующей степени спектру синусоиды, а не на ложных фоновых деталях или шуме.

Для SingleFrequency класс, сеть использует частоту для классификации. Для Pulse и Gaussian классы, сеть дополнительно фокусируется на правильной частотной части спектрограммы. Для этих трех классов сеть не путается с частотой фона, видимой около верхней части всех спектрограмм. Эта информация не полезна для различения этих классов (так как она присутствует во всех классах), поэтому сеть игнорирует ее. Напротив, для ThreeFrequency класс, постоянная фоновая частота релевантна классификационному решению сети. Для этого класса сеть не игнорирует эту частоту, но относится к ней с той же важностью, что и к трем фактическим частотам.

The imageLIME результаты показывают, что сеть правильно использует peaks в частотно-временных спектрограммах и не путается с ложной фоновой синусоидой для всех классов, кроме ThreeFrequency класс, где сеть не различает три частоты в сигнале и низкочастотный фон.

См. также

| | (Signal Processing Toolbox)

Похожие темы