Исследуйте классификации спектрограмм Используя LIME

В этом примере показано, как использовать локально поддающиеся толкованию объяснения модели агностические (LIME), чтобы исследовать робастность глубокой сверточной нейронной сети, обученной классифицировать спектрограммы. LIME является методом для визуализации, которую части наблюдения вносят в решение классификации о сети. Этот пример использует imageLIME функция, чтобы понять, какие функции в данных о спектрограмме являются самыми важными для классификации.

В этом примере вы создаете и обучаете нейронную сеть, чтобы классифицировать четыре вида симулированных данных временных рядов:

  • Sine wave одной частоты

  • Суперпозиция трех синусоид

  • Широкий Гауссов peaks во временных рядах

  • Гауссовы импульсы во временных рядах

Чтобы сделать эту проблему более реалистичной, временные ряды включают добавленные сигналы соединения: постоянная низкочастотная фоновая синусоида и большая сумма высокочастотного шума. Шумные данные временных рядов являются сложной проблемой классификации последовательностей. Можно приблизиться к проблеме первым преобразованием данных временных рядов в спектрограмму частоты времени, чтобы показать базовые функции в данных временных рядов. Можно затем ввести спектрограммы к сети классификации изображений.

Сгенерируйте формы волны и спектрограммы

Сгенерируйте данные временных рядов для этих четырех классов. Этот пример использует функцию помощника generateSpectrogramData сгенерировать временные ряды и соответствующие данные о спектрограмме. Функции помощника, используемые в этом примере, присоединяются как вспомогательные файлы.

numObsPerClass = 500;

classes = categorical(["SingleFrequency","ThreeFrequency","Gaussian","Pulse"]);
numClasses = length(classes);

[noisyTimeSeries,spectrograms,labels] = generateSpectrogramData(numObsPerClass,classes);

Вычислите размер изображений спектрограммы и количество наблюдений.

inputSize = size(spectrograms, [1 2]);
numObs = size(spectrograms,4);

Отобразите сгенерированные данные на графике

Постройте подмножество данных временных рядов с добавленным шумом. Поскольку шум имеет сопоставимую амплитуду к сигналу, данные кажутся шумными во временном интервале. Эта функция делает классификацию сложной проблемой.

figure
numPlots = 12;

for i=1:numPlots
    subplot(3,4,i)
    plot(noisyTimeSeries(i,:))
    title(labels(i))
end

Постройте спектрограммы частоты времени зашумленных данных в том же порядке как графики временных рядов. Горизонтальная ось время, и вертикальная ось является частотой.

figure
for i=1:12
    subplot(3,4,i)
    imshow(spectrograms(:,:,1,i))
    hold on
    colormap parula
    title(labels(i))
    hold off
end

Функции от каждого класса ясно отображаются, демонстрируя, почему преобразование от временного интервала до изображений спектрограммы может быть выгодным для этого типа проблемы. Например, SingleFrequency класс имеет один пик на основной частоте, видимой как горизонтальная планка в спектрограмме. Для ThreeFrequency класс, эти три частоты отображаются.

Все классы отображают слабую полосу в низкой частоте (около верхней части изображения), соответствуя фоновой синусоиде.

Разделение данных

Используйте splitlabels функционируйте, чтобы разделить данные на данные об обучении и валидации. Используйте 80% данных для обучения и 20% для валидации.

splitIndices = splitlabels(labels,0.8);

trainLabels = labels(splitIndices{1});
trainSpectrograms = spectrograms(:,:,:,splitIndices{1});

valLabels = labels(splitIndices{2});
valSpectrograms = spectrograms(:,:,:,splitIndices{2});

Задайте архитектуру нейронной сети

Создайте сверточную нейронную сеть с блоками свертки, нормализации партии. и слоев ReLU.

dropoutProb = 0.2;
numFilters = 8;

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(3,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(3,'Stride',2,'Padding','same')
    
    convolution2dLayer(3,2*numFilters,'Padding','same')
    batchNormalizationLayer
    
    convolution2dLayer(3,4*numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    globalMaxPooling2dLayer
    
    dropoutLayer(dropoutProb)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

Задайте опции обучения

Задайте опции для обучения с помощью оптимизатора SGDM. Переставьте данные каждая эпоха путем установки 'Shuffle' опция к 'every-epoch'. Контролируйте процесс обучения путем установки 'Plots' опция к 'training-progress'. Чтобы подавить многословный выход, установите 'Verbose' к false.

options = trainingOptions('sgdm', ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{valSpectrograms,valLabels});

Обучение сети

Обучите сеть, чтобы классифицировать изображения спектрограммы.

net = trainNetwork(trainSpectrograms,trainLabels,layers,options);

Точность

Классифицируйте наблюдения валидации с помощью обучившего сеть.

predLabels = classify(net,valSpectrograms);

Исследуйте производительность сети путем графического вывода матрицы беспорядка с confusionchart.

figure
confusionchart(valLabels,predLabels,'Normalization','row-normalized')

Сеть точно классифицирует спектрограммы валидации, с близко к 100%-й точности для большинства классов.

Исследуйте сетевые предсказания

Используйте imageLIME функция, чтобы понять, какие функции в данных изображения являются самыми важными для классификации.

Метод LIME сегментирует изображение на несколько функций и генерирует синтетические наблюдения случайным образом включая или, исключая функции. Каждый пиксель в исключенной функции заменяется значением среднего пикселя изображения. Сеть классифицирует эти синтетические наблюдения и использует получившуюся музыку к предсказанному классу, наряду с присутствием или отсутствием функции, как ответы и предикторы, чтобы обучить проблему регрессии с более простой моделью — в этом примере, дереве регрессии. Дерево регрессии пытается аппроксимировать поведение сети на одном наблюдении. Это учится, какие функции важны и значительно влияют на счет класса.

Задайте пользовательскую карту сегментации

По умолчанию, imageLIME суперпиксельная сегментация использования, чтобы разделить изображение на функции. Эта опция работает хорошо на естественные изображения, но является менее эффективной для данных о спектрограмме. Можно задать пользовательскую карту сегментации путем установки 'Segmentation' аргумент значения имени к числовому массиву тот же размер как изображение, где каждым элементом является целое число, соответствующее индексу функции, в которой находится пиксель.

Для данных о спектрограмме изображения спектрограммы имеют намного более прекрасные функции в y-размерности (частота), чем x-размерность (время). Сгенерируйте карту сегментации с 240 сегментами, в 40 6 сетка, чтобы обеспечить более высокое разрешение частоты. Сверхдискретизируйте сетку к размеру изображения при помощи imresize функция, задавая метод повышающей дискретизации как 'nearest'.

featureIdx = 1:240;
segmentationMap = reshape(featureIdx,6,40)';
segmentationMap = imresize(segmentationMap,inputSize,'nearest');

Вычислите карту LIME

Постройте спектрограмму и вычислите карту LIME для двух наблюдений от каждого класса.

obsToShowPerClass = 2;

for j=1:obsToShowPerClass
    figure

    for i=1:length(classes)

        idx = find(valLabels == classes(i),obsToShowPerClass);

        % Read the test image and label.
        testSpectrogram = valSpectrograms(:,:,:,idx(j));
        testLabel = valLabels(idx(j));

        % Compute the LIME importance map.
        map = imageLIME(net,testSpectrogram,testLabel, ...
            'NumSamples',4096, ...
            'Segmentation',segmentationMap);

        % Rescale the map to the size of the image.
        mapRescale = uint8(255*rescale(map));

        % Plot the spectrogram image next to the LIME map.
        subplot(2,2,i)
        imshow(imtile({testSpectrogram,mapRescale}))
        title(string(testLabel))
        colormap parula
    end
end

Карты LIME демонстрируют, что для большинства классов, сеть фокусируется на соответствующих функциях классификации. Например, для SingleFrequency класс, сеть фокусируется на частоте, соответствующей спектру мощности синусоиды а не на побочных фоновых деталях или шуме.

Для SingleFrequency класс, сеть использует частоту, чтобы классифицировать. Для Pulse и Gaussian классы, сеть дополнительно фокусируется на правильной части частоты спектрограммы. Для этих трех классов сеть не перепутана фоновой частотой, видимой около верхней части всех спектрограмм. Эта информация не полезна для различения этих классов (когда это присутствует во всех классах), таким образом, сеть игнорирует его. В отличие от этого для ThreeFrequency класс, постоянная фоновая частота относится к решению классификации о сети. Для этого класса сеть не игнорирует эту частоту, но обрабатывает его с подобной важностью для трех фактических частот.

imageLIME результаты демонстрируют, что сеть правильно использует peaks в спектрограммах частоты времени и не перепутана побочной фоновой синусоидой для всех классов за исключением ThreeFrequency класс, где сеть не различает эти три частоты в сигнале и низкочастотном фоне.

Смотрите также

| (Signal Processing Toolbox) |

Похожие темы