Классификация временных рядов с помощью вейвлет и глубокого обучения

Этот пример показывает, как классифицировать сигналы электрокардиограммы (ЭКГ) человека с помощью непрерывного вейвлет (CWT) и глубокой сверточной нейронной сети (CNN).

Обучение глубокого CNN с нуля является в вычислительном отношении дорогим и требует большого объема обучающих данных. В различных приложениях достаточный объем обучающих данных недоступен, и синтез новых реалистичных примеров обучения невозможен. В этих случаях желательно использовать существующие нейронные сети, которые были обучены на больших наборах данных для концептуально аналогичных задач. Такое использование существующих нейронных сетей называется передачей обучения. В этом примере мы адаптируем два глубоких CNN, GoogLeNet и SqueezeNet, предварительно обученные для распознавания изображений, чтобы классифицировать формы волны ЭКГ на основе представления временной частоты.

GoogLeNet и SqueezeNet - глубокие CNN, изначально разработанные для классификации изображений в 1000 категориях. Мы повторно используем сетевую архитектуру CNN, чтобы классифицировать сигналы ЭКГ на основе изображений из CWT данных временных рядов. Данные, используемые в этом примере, являются общедоступными от PhysioNet.

Описание данных

В этом примере вы используете данные ЭКГ, полученные от трех групп людей: лиц с сердечной аритмией (ARR), лиц с застойным сердечным отказом (CHF) и лиц с нормальными синусовыми ритмами (NSR). Всего вы используете 162 записи ЭКГ из трех баз данных PhysioNet: базы данных аритмии MIT-BIH [3] [7], базы данных нормального синусового ритма MIT-BIH [3] и базы данных застойной сердечной недостаточности BIDMC [1] Более конкретно, 96 записей от лиц с аритмией, 30 записей от лиц с застойным сердечным отказом и 36 записей от лиц с нормальными синусовыми ритмами. Цель состоит в том, чтобы обучить классификатор для различения ARR, CHF и NSR.

Загрузка данных

Первый шаг - загрузка данных из репозитория GitHub. Чтобы загрузить данные с сайта, нажмите Code и выберите Download ZIP. Сохраните файл physionet_ECG_data-main.zip в папке, в которой у вас есть разрешение на запись. Инструкции для этого примера предполагают, что вы загрузили файл во временную директорию tempdir, в MATLAB. Измените последующие инструкции для распаковки и загрузки данных, если вы решите загрузить данные в папку, отличную от tempdir.

После загрузки данных с GitHub разархивируйте файл во временной директории.

unzip(fullfile(tempdir,'physionet_ECG_data-main.zip'),tempdir)

Unzipping создает папку physionet-ECG_data-main во временной директории. Эта папка содержит текстовый файл README.md и ECGData.zip. The ECGData.zip файл содержит

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat содержит данные, используемые в этом примере. Текстовый файл, Modified_physionet_data.txt, требуется политикой копирования PhysioNet и предоставляет исходные атрибуты для данных, а также описание шагов предварительной обработки, применяемых к каждой записи ЭКГ.

Разархивирование ECGData.zip в physionet-ECG_data-main. Загрузите файл данных в рабочее рабочее пространство MATLAB.

unzip(fullfile(tempdir,'physionet_ECG_data-main','ECGData.zip'),...
    fullfile(tempdir,'physionet_ECG_data-main'))
load(fullfile(tempdir,'physionet_ECG_data-main','ECGData.mat'))

ECGData массив структур с двумя полями: Data и Labels. The Data поле представляет собой 162 на 65536 матрицу, где каждая строка является дискретизацией записи ЭКГ в 128 герц. Labels - массив ячеек 162 на 1 с диагностическими метками, по одному для каждой строки Data. Три диагностические категории: 'ARR', 'CHF', и 'NSR'.

Чтобы сохранить предварительно обработанные данные каждой категории, сначала создайте директорию данных ЭКГ dataDir внутри tempdir. Затем создайте три подкаталога в 'data' названы в честь каждой категории ЭКГ. Функция помощника helperCreateECGDirectories делает это. helperCreateECGDirectories принимает ECGData, имя директории данных ECG и имя родительской директории в качестве входных параметров. Можно заменить tempdir с другой директорией, в котором у вас есть разрешение на запись. Исходный код для этой вспомогательной функции можно найти в разделе Вспомогательные функции в конце этого примера.

parentDir = tempdir;
dataDir = 'data';
helperCreateECGDirectories(ECGData,parentDir,dataDir)

Постройте график представителя каждой категории ЭКГ. Функция помощника helperPlotReps делает это. helperPlotReps принимает ECGData как вход. Исходный код для этой вспомогательной функции можно найти в разделе Вспомогательные функции в конце этого примера.

helperPlotReps(ECGData)

Создайте представления частоты и времени

После создания папок создайте частотно-временные представления сигналов ЭКГ. Эти представления называются скалограммами. Скалограмма является абсолютным значением коэффициентов CWT сигнала.

Чтобы создать скалограммы, предварительно вычислите банк фильтров CWT. Предварительное вычисление группы фильтров CWT является предпочтительным способом при получении CWT многих сигналов с использованием тех же параметров.

Прежде чем сгенерировать скалограммы, исследуйте одну из них. Создайте банк фильтров CWT с помощью cwtfilterbank (Wavelet Toolbox) для сигнала с 1000 выборками. Используйте банк фильтров, чтобы взять CWT первых 1000 выборок сигнала и получить скалограмму из коэффициентов.

Fs = 128;
fb = cwtfilterbank('SignalLength',1000,...
    'SamplingFrequency',Fs,...
    'VoicesPerOctave',12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs))
set(gca,'yscale','log');shading interp;axis tight;
title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')

Используйте функцию helper helperCreateRGBfromTF чтобы создать скалограммы в виде изображений RGB и записать их в соответствующий подкаталог в dataDir. Исходный код для этой вспомогательной функции находится в разделе Вспомогательные функции в конце этого примера. Чтобы быть совместимым с архитектурой GoogLeNet, каждое изображение RGB - массив размера 224 на 224 на 3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

Разделение на данные обучения и валидации

Загрузите скалограммные изображения как image datastore. The imageDatastore функция автоматически помечает изображения на основе имен папок и сохраняет данные как объект ImageDatastore. image datastore позволяет вам хранить большие данные об изображениях, включая данные, которые не помещаются в памяти, и эффективно считывать пакеты изображений во время обучения CNN.

allImages = imageDatastore(fullfile(parentDir,dataDir),...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');

Случайным образом разделите изображения на две группы: одну для обучения и другую для валидации. Используйте 80% изображений для обучения, а оставшуюся часть для валидации. В целях воспроизводимости мы устанавливаем случайное начальное значение на значение по умолчанию.

rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
Number of training images: 130
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
Number of validation images: 32

GoogLeNet

Груз

Загрузите предварительно обученную нейронную сеть GoogLeNet. Если модель Deep Learning Toolbox™ для пакета поддержки GoogLeNet Network не установлена, программное обеспечение предоставляет ссылку на необходимый пакет поддержки в Add-On Explorer. Чтобы установить пакет поддержки, щелкните ссылку и выберите Установить.

net = googlenet;

Извлечение и отображение графика слоев из сети.

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);

Проверьте первый элемент свойства слоев сети. Подтвердите, что для GoogLeNet требуется изображениям RGB размера 224 224 на 3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                Name: 'data'
           InputSize: [224 224 3]

   Hyperparameters
    DataAugmentation: 'none'
       Normalization: 'zerocenter'
                Mean: [224×224×3 single]

Изменение параметров сети GoogLeNet

Каждый слой в сетевой архитектуре может рассматриваться как фильтр. Более ранние слои идентифицируют более общие функции изображений, такие как blobs, ребра и colors. Последующие слои особого внимания на более специфических функциях в порядок для дифференциации категорий. GoogLeNet предварительно обучен, чтобы классифицировать изображения в 1000 категорий объектов. Вы должны переобучить GoogLeNet для нашей задачи классификации ЭКГ.

Чтобы предотвратить сверхподбор кривой, используется выпадающий слой. Выпадающий слой случайным образом устанавливает элементы входа для нуля с заданной вероятностью. См. dropoutLayer (Deep Learning Toolbox) для получения дополнительной информации. Вероятность по умолчанию - 0,5. Замените конечный слой отсева в сети, 'pool5-drop_7x7_s1', с выпадающим слоем вероятности 0,6.

newDropoutLayer = dropoutLayer(0.6,'Name','new_Dropout');
lgraph = replaceLayer(lgraph,'pool5-drop_7x7_s1',newDropoutLayer);

Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы переобучить GoogLeNet для классификации изображений RGB, замените эти два слоя новыми слоями, адаптированными к данным.

Замените полносвязный слой 'loss3-classifier' с новым полносвязным слоем с количеством фильтров, равным количеству классов. Чтобы учиться быстрее в новых слоях, чем в переданных слоях, увеличьте коэффициенты скорости обучения полносвязного слоя.

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,'Name','new_fc',...
    'WeightLearnRateFactor',5,'BiasLearnRateFactor',5);
lgraph = replaceLayer(lgraph,'loss3-classifier',newConnectedLayer);

Слой классификации задает выходные классы сети. Замените слой классификации новым слоем без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);

Установите опции обучения и обучите GoogLeNet

Настройка нейронной сети является итеративным процессом, который включает минимизацию функции потерь. Чтобы минимизировать функцию потерь, используется алгоритм градиентного спуска. В каждой итерации оценивается градиент функции потерь и обновляются веса алгоритма спуска.

Обучение можно настроить, установив различные опции. InitialLearnRate задает начальный размер шага в направлении отрицательного градиента функции потерь. MiniBatchSize задает размер подмножества набора обучающих данных, используемого в каждой итерации. Одна эпоха является полным проходом алгоритма настройки по всему набору обучающих данных. MaxEpochs задает максимальное количество эпох, используемых для обучения. Выбор правильного количества эпох не является тривиальной задачей. Уменьшение количества эпох имеет эффект недооценки модели и увеличения количества эпох, результатов в сверхподбор кривой.

Используйте trainingOptions (Deep Learning Toolbox) для определения опций обучения. Задайте MiniBatchSize до 10, MaxEpochs до 10, и InitialLearnRate до 0.0001. Визуализируйте процесс обучения путем настройки Plots на training-progress. Используйте стохастический градиентный спуск с оптимизатором импульса. По умолчанию обучение выполняется на графическом процессоре, если он доступен. Для использования графический процессор требуется Parallel Computing Toolbox™. Информацию о том, какие графические процессоры поддерживаются, см. в разделе Поддержка GPU Release (Parallel Computing Toolbox). В целях воспроизводимости задайте ExecutionEnvironment на cpu так что trainNetwork использовал центральный процессор. Установите значение по умолчанию для случайного начального значения. Время выполнения будет быстрее, если вы сможете использовать графический процессор.

options = trainingOptions('sgdm',...
    'MiniBatchSize',15,...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4,...
    'ValidationData',imgsValidation,...
    'ValidationFrequency',10,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');
rng default

Обучите сеть. Процесс обучения обычно занимает 1-5 минут на настольном центральном процессоре. Командное окно отображает обучающую информацию во время запуска. Результаты включают число эпох, число итерации, прошло время, точность мини-пакета, точность валидации и значение функции потерь для данных валидации.

trainedGN = trainNetwork(imgsTrain,lgraph,options);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        6.67% |       18.75% |       4.9207 |       2.4141 |      1.0000e-04 |
|       2 |          10 |       00:00:23 |       66.67% |       62.50% |       0.9589 |       1.3191 |      1.0000e-04 |
|       3 |          20 |       00:00:43 |       46.67% |       75.00% |       1.2973 |       0.5928 |      1.0000e-04 |
|       4 |          30 |       00:01:04 |       60.00% |       78.13% |       0.7219 |       0.4576 |      1.0000e-04 |
|       5 |          40 |       00:01:25 |       73.33% |       84.38% |       0.4750 |       0.3367 |      1.0000e-04 |
|       7 |          50 |       00:01:46 |       93.33% |       84.38% |       0.2714 |       0.2892 |      1.0000e-04 |
|       8 |          60 |       00:02:07 |       80.00% |       87.50% |       0.3617 |       0.2433 |      1.0000e-04 |
|       9 |          70 |       00:02:29 |       86.67% |       87.50% |       0.3246 |       0.2526 |      1.0000e-04 |
|      10 |          80 |       00:02:50 |      100.00% |       96.88% |       0.0701 |       0.1876 |      1.0000e-04 |
|      12 |          90 |       00:03:11 |       86.67% |      100.00% |       0.2836 |       0.1681 |      1.0000e-04 |
|      13 |         100 |       00:03:32 |       86.67% |       96.88% |       0.4160 |       0.1607 |      1.0000e-04 |
|      14 |         110 |       00:03:53 |       86.67% |       96.88% |       0.3237 |       0.1565 |      1.0000e-04 |
|      15 |         120 |       00:04:14 |       93.33% |       96.88% |       0.1646 |       0.1476 |      1.0000e-04 |
|      17 |         130 |       00:04:35 |      100.00% |       96.88% |       0.0551 |       0.1330 |      1.0000e-04 |
|      18 |         140 |       00:04:57 |       93.33% |       96.88% |       0.0927 |       0.1347 |      1.0000e-04 |
|      19 |         150 |       00:05:18 |       93.33% |       93.75% |       0.1666 |       0.1325 |      1.0000e-04 |
|      20 |         160 |       00:05:39 |       93.33% |       96.88% |       0.0873 |       0.1164 |      1.0000e-04 |
|======================================================================================================================|

Осмотрите последний слой обученной сети. Подтвердите, что слой Classification Output включает три класса.

trainedGN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Оценка точности GoogLeNet

Оцените сеть с помощью данных валидации.

[YPred,probs] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['GoogLeNet Accuracy: ',num2str(100*accuracy),'%'])
GoogLeNet Accuracy: 96.875%

Точность идентична точности валидации, сообщенной на обучающем рисунке визуализации. Скалограммы были разделены на наборы обучения и валидации. Оба наборов использовались для обучения GoogLeNet. Идеальным способом оценить результат обучения является то, чтобы сеть классифицировала данные, которые она не видела. Поскольку существует недостаточное количество данных, чтобы разделиться на обучение, валидацию и проверку, мы рассматриваем вычисленную точность валидации как точность сети.

Исследуйте активации GoogLeNet

Каждый слой CNN создает ответ или активацию на вход изображение. Однако в CNN есть только несколько слоев, которые подходят для редукции данных. Слои в начале сети захватывают основные функции изображений, такие как ребра и blobs. Чтобы увидеть это, визуализируйте веса сетевого фильтра с первого сверточного слоя. В первом слое 64 отдельных набора весов.

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,5);
figure
montage(wghts)
title('First Convolutional Layer Weights')

Можно изучить активации и узнать, какие функции учит GoogLeNet, сравнивая области активации с оригинальным изображением. Для получения дополнительной информации см. «Визуализация активаций сверточной нейронной сети» (Deep Learning Toolbox) и «Визуализация функций сверточной нейронной сети» (Deep Learning Toolbox).

Исследуйте, какие области в сверточных слоях активируются на изображении из ARR класс. Сравните с соответствующими областями в оригинальное изображение. Каждый слой сверточной нейронной сети состоит из многих 2-D массивов, называемых каналами. Передайте изображение через сеть и исследуйте выходные активации первого сверточного слоя, 'conv1-7x7_s2'.

convLayer = 'conv1-7x7_s2';

imgClass = 'ARR';
imgName = 'ARR_10.jpg';
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
montage(rescale(trainingFeaturesARR),'Size',[8 8])
title([imgClass,' Activations'])

Найдите самый сильный канал для этого изображения. Сравните самый сильный канал с оригинальным изображением.

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure;
imshowpair(imarr,arrMax,'montage')
title(['Strongest ',imgClass,' Channel: ',num2str(maxValueIndex)])

SqueezeNet

SqueezeNet - глубокий CNN, архитектура которого поддерживает изображения размера 227 227 3. Несмотря на то, что размеры изображений различны для GoogLeNet, вы не должны генерировать новые изображения RGB в размерностях SqueezeNet. Можно использовать исходные изображения RGB.

Груз

Загрузите предварительно обученную нейронную сеть SqueezeNet. Если модель Deep Learning Toolbox™ для пакета поддержки SqueezeNet Network не установлена, программное обеспечение предоставляет ссылку на необходимый пакет поддержки в Add-On Explorer. Чтобы установить пакет поддержки, щелкните ссылку и выберите Установить.

sqz = squeezenet;

Извлеките график слоев из сети. Подтвердите, что SqueezeNet имеет меньше слоев, чем GoogLeNet. Также подтвердите, что SqueezeNet сконфигурирован для изображений размера 227 227 3

lgraphSqz = layerGraph(sqz);
disp(['Number of Layers: ',num2str(numel(lgraphSqz.Layers))])
Number of Layers: 68
disp(lgraphSqz.Layers(1).InputSize)
   227   227     3

Изменение параметров сети SqueezeNet

Чтобы переобучить SqueezeNet для классификации новых изображений, внесите изменения, аналогичные тем, которые были сделаны для GoogLeNet.

Просмотрите последние шесть слои сети.

lgraphSqz.Layers(end-5:end)
ans = 
  6x1 Layer array with layers:

     1   'drop9'                             Dropout                 50% dropout
     2   'conv10'                            Convolution             1000 1x1x512 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'                       ReLU                    ReLU
     4   'pool10'                            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'                              Softmax                 softmax
     6   'ClassificationLayer_predictions'   Classification Output   crossentropyex with 'tench' and 999 other classes

Замените 'drop9' слой, последний слой отсева в сети, с слоем отсева вероятности 0,6.

tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,'Name','new_dropout');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

В отличие от GoogLeNet, последний обучаемый слой в SqueezeNet является сверточным слоем 1 на 1, 'conv10', и не полносвязный слой. Замените 'conv10' слой с новым сверточным слоем с количеством фильтров, равным количеству классов. Как это было сделано с GoogLeNet, увеличьте коэффициенты скорости обучения нового слоя.

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

Замените слой классификации новым слоем без меток классов.

tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

Осмотрите последние шесть слоев сети. Подтвердите, что отсева, сверточные и выходные слои были изменены.

lgraphSqz.Layers(63:68)
ans = 
  6x1 Layer array with layers:

     1   'new_dropout'       Dropout                 60% dropout
     2   'new_conv'          Convolution             3 1x1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'       ReLU                    ReLU
     4   'pool10'            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'              Softmax                 softmax
     6   'new_classoutput'   Classification Output   crossentropyex

Подготовка данных RGB для SqueezeNet

Изображения RGB имеют размерности, соответствующие архитектуре GoogLeNet. Создайте хранилища данных дополненных изображений, которые автоматически изменяют размер существующих изображений RGB для архитектуры SqueezeNet. Для получения дополнительной информации смотрите augmentedImageDatastore (Deep Learning Toolbox).

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Установите опции обучения и обучите SqueezeNet

Создайте новый набор опций обучения для использования с SqueezeNet. Установите значение по умолчанию для случайного начального числа и обучите сеть. Процесс обучения обычно занимает 1-5 минут на настольном центральном процессоре.

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions('sgdm',...
    'MiniBatchSize',miniBatchSize,...
    'MaxEpochs',maxEpochs,...
    'InitialLearnRate',ilr,...
    'ValidationData',augimgsValidation,...
    'ValidationFrequency',valFreq,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');

rng default
trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |       20.00% |       43.75% |       5.2508 |       1.2540 |          0.0003 |
|       1 |          13 |       00:00:11 |       60.00% |       50.00% |       0.9912 |       1.0519 |          0.0003 |
|       2 |          26 |       00:00:20 |       60.00% |       59.38% |       0.8554 |       0.8497 |          0.0003 |
|       3 |          39 |       00:00:30 |       60.00% |       59.38% |       0.8120 |       0.8328 |          0.0003 |
|       4 |          50 |       00:00:38 |       50.00% |              |       0.7885 |              |          0.0003 |
|       4 |          52 |       00:00:40 |       60.00% |       65.63% |       0.7091 |       0.7314 |          0.0003 |
|       5 |          65 |       00:00:49 |       90.00% |       87.50% |       0.4639 |       0.5893 |          0.0003 |
|       6 |          78 |       00:00:59 |       70.00% |       87.50% |       0.6021 |       0.4355 |          0.0003 |
|       7 |          91 |       00:01:08 |       90.00% |       90.63% |       0.2307 |       0.2945 |          0.0003 |
|       8 |         100 |       00:01:15 |       90.00% |              |       0.1827 |              |          0.0003 |
|       8 |         104 |       00:01:18 |       90.00% |       93.75% |       0.2139 |       0.2153 |          0.0003 |
|       9 |         117 |       00:01:28 |      100.00% |       90.63% |       0.0521 |       0.1964 |          0.0003 |
|      10 |         130 |       00:01:38 |       90.00% |       90.63% |       0.1134 |       0.2214 |          0.0003 |
|      11 |         143 |       00:01:47 |      100.00% |       90.63% |       0.0855 |       0.2095 |          0.0003 |
|      12 |         150 |       00:01:52 |       90.00% |              |       0.2394 |              |          0.0003 |
|      12 |         156 |       00:01:57 |      100.00% |       90.63% |       0.0606 |       0.1849 |          0.0003 |
|      13 |         169 |       00:02:06 |      100.00% |       90.63% |       0.0090 |       0.2071 |          0.0003 |
|      14 |         182 |       00:02:16 |      100.00% |       93.75% |       0.0127 |       0.3597 |          0.0003 |
|      15 |         195 |       00:02:25 |      100.00% |       93.75% |       0.0016 |       0.3414 |          0.0003 |
|======================================================================================================================|

Осмотрите последний слой сети. Подтвердите, что слой Classification Output включает три класса.

trainedSN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Вычислите точность SqueezeNet

Оцените сеть с помощью данных валидации.

[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['SqueezeNet Accuracy: ',num2str(100*accuracy),'%'])
SqueezeNet Accuracy: 93.75%

Заключение

В этом примере показано, как использовать передачу обучения и непрерывный вейвлет для классификации трех классов сигналов ЭКГ путем использования предварительно обученных CNNs GoogLeNet и SqueezeNet. Основанные на вейвлете частотно-частотные представления сигналов ЭКГ используются для создания скалограмм. Формируются изображения скалограмм. Изображения используются для тонкой настройки обоих глубоких CNNs. Были также исследованы активации различных слоев сети.

Этот пример иллюстрирует один возможный рабочий процесс, который можно использовать для классификации сигналов с помощью предварительно обученных моделей CNN. Другие рабочие процессы возможны. Развертывание классификатора сигналов на NVIDIA Jetson с помощью анализа волн и глубокого обучения (Wavelet Toolbox) и развертывание классификатора сигналов с помощью вейвлетов и глубокого обучения на Raspberry Pi (Wavelet Toolbox) показывают, как развернуть GoogLeNet и SqueezeNet являются моделями, предварительно обученными на подмножестве базы данных ImageNet [10], которое используется в ILSVRC [8]. Набор ImageNet содержит изображения объектов реального мира, таких как рыбы, птицы, приборы и грибки. Скалограммы попадают вне класса объектов реального мира. В порядок вписаться в архитектуру GoogLeNet и SqueezeNet, скалограммы также подверглись сокращению данных. Вместо подстройки предварительно обученных CNN, чтобы различить различные классы скалограмм, является опцией обучения CNN с нуля при исходных размерностях скалограммы.

Ссылки

  1. Baim, D. S., В. С. Колуччи, Э. С. Монрэд, Х. С. Смит, Р. Ф. Райт, А. Лэноу, Д. Ф. Готье, Б. Дж. Рэнсил, В. Гроссман и Э. Браунвальд. «Выживание пациентов с тяжёлым застойным сердечным отказом, получавших пероральный милринон». Журнал Американского колледжа кардиологов. Том 7, № 3, 1986, стр. 661-670.

  2. Энгин, М. «классификация биений ЭКГ с помощью нейро-нечеткой сети». Распознавание Букв. Том 25, № 15, 2004, стр. 1715-1722.

  3. Гольдбергер А. Л., Л. А. Н. Амарал, Л. Гласс, Ж. М. Хаусдорф, П. Ч. Иванов, Р. Г. Марк, Ж. Э. Миетус, Г. Б. Муди, К.-К. Пэн и Х. Э. Стэнли. PhysioBank, PhysioToolkit и PhysioNet: компоненты нового исследовательского ресурса комплексных физиологических сигналов. Циркуляция. Том 101, номер 23: e215-e220. [Тиражные электронные страницы; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (13 июня). doi: 10.1161/01.CIR.101.23.e215.

  4. Леонардуцци, Р. Ф., Г. Шлоттхауэр, и М. Э. Торрес. «Вейвлет на основе мультифрактального анализа вариабельности сердечного ритма во время ишемии миокарда». В Инженерном Обществе Медицины и Биологии (EMBC), Ежегодная Международная Конференция IEEE, 110-113. Буэнос-Айрес, Аргентина: IEEE, 2010.

  5. Ли, Т. и М. Чжоу. «Классификация ЭКГ с использованием вейвлета пакетной энтропии и случайных лесов». Энтропия. Том 18, № 8, 2016, стр. 285.

  6. Махарадж, Э. А., и А. М. Алонсо. Дискриминантный анализ многомерных временных рядов: Применение к диагностике на основе сигналов ЭКГ. Вычислительная статистика и анализ данных. Том 70, 2014, с. 67-87.

  7. Moody, G. B., and R. G. Mark. «The влияния of the MIT-BIH Arrhythmia Database». IEEE Engineering in Medicine and Biology Magazine. Том 20. № 3, май-июнь 2001, с. 45-50. (PMID: 11446209)

  8. Russakovsky, O., J. Deng, and H. Su et al. «Большой масштабный вызов визуального распознавания ImageNet». Международный журнал компьютерного зрения. Том 115, № 3, 2015, стр. 211-252.

  9. Чжао, Ц., и Л. Чжан. «редукция данных и классификация ЭКГ с использованием вейвлета преобразования и машин опорных векторов». На Международной конференции IEEE по нейронным сетям и мозгу, 1089-1092. Пекин, Китай: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

Вспомогательные функции

helperCreateECGDataDirectories создает директорию данных в родительской директории, затем создает три подкаталога в директории данных. Подкаталоги названы в честь каждого класса сигнала ЭКГ, найденного в ECGData.

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps строит графики первых тысяч выборок представителя каждого класса сигнала ЭКГ, обнаруженного в ECGData.

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTF использует cwtfilterbank (Wavelet Toolbox), чтобы получить непрерывное вейвлет сигналов ECG и генерирует скалограммы из вейвлет-коэффициентов. Функция helper изменяет размер скалограмм и записывает их на диск как изображения jpeg.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank('SignalLength',signalLength,'VoicesPerOctave',12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(im2uint8(rescale(cfs)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = strcat(char(labels(ii)),'_',num2str(ii),'.jpg');
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end