exponenta event banner

Классификация временных рядов с использованием вейвлет-анализа и глубокого обучения

Этот пример показывает, как классифицировать сигналы электрокардиограммы человека (ЭКГ) с использованием непрерывного вейвлет-преобразования (CWT) и глубокой сверточной нейронной сети (CNN).

Обучение глубокому CNN с нуля является дорогостоящим с точки зрения вычислений и требует большого количества обучающих данных. В различных применениях достаточный объем обучающих данных недоступен, и синтезировать новые реалистичные обучающие примеры невозможно. В этих случаях желательно использовать существующие нейронные сети, которые были обучены большим наборам данных для концептуально аналогичных задач. Это использование существующих нейронных сетей называется обучением передаче. В этом примере мы адаптируем два глубоких CNN, GoogLeNet и SqueeNet, предварительно подготовленных для распознавания изображения, для классификации форм сигналов ЭКГ на основе частотно-временного представления.

GoogLeNet и SqueeeNet являются глубокими CNN, изначально предназначенными для классификации изображений в 1000 категорий. Мы повторно используем сетевую архитектуру CNN для классификации сигналов ЭКГ на основе изображений из CWT данных временных рядов. Данные, используемые в этом примере, являются общедоступными из PhysioNet.

Описание данных

В этом примере используются данные ЭКГ, полученные от трех групп людей: лиц с сердечной аритмией (ARR), лиц с застойной сердечной недостаточностью (CHF) и лиц с нормальными синусовыми ритмами (СМП). Всего используется 162 записей ЭКГ из трех баз данных PhysioNet: база данных аритмии MIT-BIH [3] [7], база данных нормального синусового ритма MIT-BIH [3] и база данных застойной сердечной недостаточности BIDMC [1] [3]. Более конкретно, 96 записей от людей с аритмией, 30 записей от людей с застойной сердечной недостаточностью и 36 записей от людей с нормальными синусовыми ритмами. Цель состоит в том, чтобы обучить классификатора различать ARR, CHF и NSR.

Загрузить данные

Первым шагом является загрузка данных из репозитория GitHub. Чтобы загрузить данные с веб-сайта, нажмите Code и выбрать Download ZIP. Сохранить файл physionet_ECG_data-main.zip в папке, в которой имеется разрешение на запись. В инструкциях для этого примера предполагается, что файл загружен во временный каталог. tempdir, в MATLAB. Измените последующие инструкции по распаковке и загрузке данных, если вы решили загрузить данные в папку, отличную от tempdir.

После загрузки данных из GitHub распакуйте файл во временном каталоге.

unzip(fullfile(tempdir,'physionet_ECG_data-main.zip'),tempdir)

Распаковка создает папку physionet-ECG_data-main во временном каталоге. Эта папка содержит текстовый файл README.md и ECGData.zip. ECGData.zip файл содержит

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat содержит данные, используемые в этом примере. Текстовый файл, Modified_physionet_data.txt, требуется политикой копирования PhysioNet и предоставляет исходные атрибуты для данных, а также описание этапов предварительной обработки, применяемых к каждой записи ЭКГ.

Расстегнуть молнию ECGData.zip в physionet-ECG_data-main. Загрузите файл данных в рабочую область MATLAB.

unzip(fullfile(tempdir,'physionet_ECG_data-main','ECGData.zip'),...
    fullfile(tempdir,'physionet_ECG_data-main'))
load(fullfile(tempdir,'physionet_ECG_data-main','ECGData.mat'))

ECGData - структурный массив с двумя полями: Data и Labels. Data поле представляет собой матрицу 162 на 65536, где каждая строка представляет собой запись ЭКГ, дискретизированную при 128 герц. Labels представляет собой массив диагностических меток типа 162 на 1, по одной для каждой строки Data. Тремя диагностическими категориями являются: 'ARR', 'CHF', и 'NSR'.

Для хранения предварительно обработанных данных каждой категории сначала создайте каталог данных ЭКГ. dataDir внутри tempdir. Затем создайте три подкаталога в 'data' по каждой категории ЭКГ. Вспомогательная функция helperCreateECGDirectories делает это. helperCreateECGDirectories принимает ECGDataимя каталога данных ЭКГ и имя родительского каталога в качестве входных аргументов. Вы можете заменить tempdir с другим каталогом, в котором имеется разрешение на запись. Исходный код для этой вспомогательной функции можно найти в разделе «Вспомогательные функции» в конце этого примера.

parentDir = tempdir;
dataDir = 'data';
helperCreateECGDirectories(ECGData,parentDir,dataDir)

Постройте график представителя каждой категории ЭКГ. Вспомогательная функция helperPlotReps делает это. helperPlotReps принимает ECGData в качестве входных данных. Исходный код для этой вспомогательной функции можно найти в разделе «Вспомогательные функции» в конце этого примера.

helperPlotReps(ECGData)

Создание частотно-временных представлений

После создания папок создайте частотно-временные представления сигналов ЭКГ. Эти представления называются скалограммами. Скалограмма - это абсолютное значение коэффициентов CWT сигнала.

Чтобы создать скалограммы, предварительно вычислите банк фильтров CWT. Предварительное вычисление набора фильтров CWT является предпочтительным способом при получении CWT многих сигналов с использованием одних и тех же параметров.

Перед формированием скалограмм осмотрите одну из них. Создание банка фильтров CWT с помощью cwtfilterbank (Vavelet Toolbox) для сигнала с 1000 выборок. Используйте набор фильтров, чтобы взять CWT первых 1000 выборок сигнала и получить скалограмму из коэффициентов.

Fs = 128;
fb = cwtfilterbank('SignalLength',1000,...
    'SamplingFrequency',Fs,...
    'VoicesPerOctave',12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs))
set(gca,'yscale','log');shading interp;axis tight;
title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')

Использовать функцию помощника helperCreateRGBfromTF для создания скалограмм в виде изображений RGB и записи их в соответствующий подкаталоги в dataDir. Исходный код для этой вспомогательной функции находится в разделе «Вспомогательные функции» в конце этого примера. Чтобы быть совместимым с архитектурой GoogLeNet, каждое изображение RGB - множество размера 224 на 224 на 3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

Разделить на данные обучения и проверки

Загрузите изображения скалограммы как хранилище данных изображения. imageDatastore функция автоматически помечает изображения на основе имен папок и сохраняет данные как объект ImageDatastore. Хранилище данных изображения позволяет хранить большие данные изображения, включая данные, которые не помещаются в память, и эффективно считывать пакеты изображений во время обучения CNN.

allImages = imageDatastore(fullfile(parentDir,dataDir),...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');

Случайное разделение изображений на две группы, одна для обучения, а другая для проверки. Используйте 80% изображений для обучения, а остальные - для проверки. В целях воспроизводимости для случайного начального числа устанавливается значение по умолчанию.

rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
Number of training images: 130
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
Number of validation images: 32

GoogLeNet

Груз

Загрузите предварительно обученную нейронную сеть GoogLeNet. Если Deep Learning Toolbox™ Model для пакета поддержки сети GoogLeNet не установлен, программное обеспечение предоставляет ссылку на требуемый пакет поддержки в проводнике Add-On. Чтобы установить пакет поддержки, щелкните ссылку и нажмите кнопку Установить.

net = googlenet;

Извлеките и отобразите график слоев из сети.

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);

Проверьте первый элемент свойства network Layers. Убедитесь, что GoogLeNet требует RGB-изображений размером 224-на-224-на-3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                Name: 'data'
           InputSize: [224 224 3]

   Hyperparameters
    DataAugmentation: 'none'
       Normalization: 'zerocenter'
                Mean: [224×224×3 single]

Изменение параметров сети GoogLeNet

Каждый уровень сетевой архитектуры может считаться фильтром. Более ранние слои идентифицируют более распространенные элементы изображений, такие как блобы, ребра и цвета. Последующие слои фокусируются на более специфических элементах, чтобы дифференцировать категории. Приложение GoogLeNet предварительно подготовлено для классификации изображений по 1000 категориям объектов. Вы должны переподготовить GoogLeNet для нашей проблемы классификации ЭКГ.

Для предотвращения переоборудования используется слой отсева. Уровень отсева случайным образом устанавливает входные элементы в ноль с заданной вероятностью. Посмотрите dropoutLayer для получения дополнительной информации. Вероятность по умолчанию равна 0,5. Замените окончательный уровень отсева в сети, 'pool5-drop_7x7_s1', с уровнем вероятности отсева 0,6.

newDropoutLayer = dropoutLayer(0.6,'Name','new_Dropout');
lgraph = replaceLayer(lgraph,'pool5-drop_7x7_s1',newDropoutLayer);

Сверточные уровни сетевого извлеченного изображения отличаются тем, что последний обучаемый уровень и конечный уровень классификации используют для классификации входного изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, извлекаемые сетью, в вероятности классов, значение потерь и прогнозируемые метки. Чтобы переобучить GoogLeNet для классификации изображений RGB, замените эти два слоя на новые, адаптированные к данным.

Замена полностью подключенного слоя 'loss3-classifier' с новым полностью соединенным слоем с числом фильтров, равным числу классов. Чтобы быстрее учиться в новых слоях, чем в перенесенных слоях, увеличьте коэффициенты скорости обучения полностью подключенного уровня.

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,'Name','new_fc',...
    'WeightLearnRateFactor',5,'BiasLearnRateFactor',5);
lgraph = replaceLayer(lgraph,'loss3-classifier',newConnectedLayer);

Уровень классификации определяет выходные классы сети. Замените классификационный слой новым без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);

Настройка параметров обучения и обучение GoogLeNet

Обучение нейронной сети - это итеративный процесс, который включает минимизацию функции потерь. Для минимизации функции потерь используется алгоритм градиентного спуска. В каждой итерации оценивается градиент функции потерь и обновляются веса алгоритма спуска.

Обучение можно настроить, задав различные параметры. InitialLearnRate задает начальный размер шага в направлении отрицательного градиента функции потерь. MiniBatchSize определяет размер подмножества обучающего набора, используемого в каждой итерации. Одна эпоха - это полный проход обучающего алгоритма по всему обучающему набору. MaxEpochs указывает максимальное количество эпох, используемых для обучения. Выбор правильного количества эпох не является тривиальной задачей. Уменьшение числа эпох приводит к недооснащению модели, а увеличение числа эпох приводит к переоснащению.

Используйте trainingOptions для определения вариантов обучения. Набор MiniBatchSize к 10, MaxEpochs до 10, и InitialLearnRate до 0,0001. Визуализация хода обучения с помощью настройки Plots кому training-progress. Используйте стохастический градиентный спуск с оптимизатором импульса. По умолчанию обучение выполняется на графическом процессоре, если он доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений. Сведения о том, какие графические процессоры поддерживаются, см. в разделе Поддержка графических процессоров по выпуску (Панель инструментов параллельных вычислений). Для воспроизводимости, комплект ExecutionEnvironment кому cpu чтобы trainNetwork использовал ЦП. Задайте для случайного начального значения значение по умолчанию. Время выполнения будет быстрее, если вы сможете использовать графический процессор.

options = trainingOptions('sgdm',...
    'MiniBatchSize',15,...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4,...
    'ValidationData',imgsValidation,...
    'ValidationFrequency',10,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');
rng default

Обучение сети. Процесс обучения обычно занимает 1-5 минут на настольном процессоре. В окне команд отображается информация об обучении во время выполнения. Результаты включают номер эпохи, номер итерации, прошедшее время, точность мини-партии, точность проверки и значение функции потери для данных проверки.

trainedGN = trainNetwork(imgsTrain,lgraph,options);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        6.67% |       18.75% |       4.9207 |       2.4141 |      1.0000e-04 |
|       2 |          10 |       00:00:23 |       66.67% |       62.50% |       0.9589 |       1.3191 |      1.0000e-04 |
|       3 |          20 |       00:00:43 |       46.67% |       75.00% |       1.2973 |       0.5928 |      1.0000e-04 |
|       4 |          30 |       00:01:04 |       60.00% |       78.13% |       0.7219 |       0.4576 |      1.0000e-04 |
|       5 |          40 |       00:01:25 |       73.33% |       84.38% |       0.4750 |       0.3367 |      1.0000e-04 |
|       7 |          50 |       00:01:46 |       93.33% |       84.38% |       0.2714 |       0.2892 |      1.0000e-04 |
|       8 |          60 |       00:02:07 |       80.00% |       87.50% |       0.3617 |       0.2433 |      1.0000e-04 |
|       9 |          70 |       00:02:29 |       86.67% |       87.50% |       0.3246 |       0.2526 |      1.0000e-04 |
|      10 |          80 |       00:02:50 |      100.00% |       96.88% |       0.0701 |       0.1876 |      1.0000e-04 |
|      12 |          90 |       00:03:11 |       86.67% |      100.00% |       0.2836 |       0.1681 |      1.0000e-04 |
|      13 |         100 |       00:03:32 |       86.67% |       96.88% |       0.4160 |       0.1607 |      1.0000e-04 |
|      14 |         110 |       00:03:53 |       86.67% |       96.88% |       0.3237 |       0.1565 |      1.0000e-04 |
|      15 |         120 |       00:04:14 |       93.33% |       96.88% |       0.1646 |       0.1476 |      1.0000e-04 |
|      17 |         130 |       00:04:35 |      100.00% |       96.88% |       0.0551 |       0.1330 |      1.0000e-04 |
|      18 |         140 |       00:04:57 |       93.33% |       96.88% |       0.0927 |       0.1347 |      1.0000e-04 |
|      19 |         150 |       00:05:18 |       93.33% |       93.75% |       0.1666 |       0.1325 |      1.0000e-04 |
|      20 |         160 |       00:05:39 |       93.33% |       96.88% |       0.0873 |       0.1164 |      1.0000e-04 |
|======================================================================================================================|

Проверьте последний уровень обученной сети. Подтвердите, что уровень «Classification Output» включает три класса.

trainedGN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Оценка точности GoogLeNet

Анализ сети с использованием данных проверки.

[YPred,probs] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['GoogLeNet Accuracy: ',num2str(100*accuracy),'%'])
GoogLeNet Accuracy: 96.875%

Точность идентична точности проверки, сообщенной на рисунке визуализации обучения. Скалограммы были разделены на учебные и валидационные коллекции. Обе коллекции использовались для обучения GoogLeNet. Идеальным способом оценки результатов обучения является отсутствие данных классификации сети. Поскольку недостаточно данных для разделения на обучение, проверку и тестирование, мы рассматриваем вычисленную точность проверки как сетевую точность.

Обзор активаций GoogLeNet

Каждый уровень CNN создает отклик или активацию на входное изображение. Однако в CNN имеется только несколько слоев, пригодных для извлечения признаков изображения. Слои в начале сети захватывают основные элементы изображения, такие как края и блобы. Чтобы увидеть это, визуализируйте веса сетевого фильтра из первого сверточного уровня. В первом слое 64 отдельных набора весов.

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,5);
figure
montage(wghts)
title('First Convolutional Layer Weights')

Вы можете проверить активации и обнаружить, какие функции узнает GoogLeNet, сравнивая области активации с исходным изображением. Дополнительные сведения см. в разделах Визуализация активизаций сверточной нейронной сети и Визуализация особенностей сверточной нейронной сети.

Проверьте, какие области в сверточных слоях активизируются на изображении из ARR класс. Сравните с соответствующими областями на исходном изображении. Каждый уровень сверточной нейронной сети состоит из множества 2-D массивов, называемых каналами. Пропустите изображение через сеть и проверьте выходные активации первого сверточного уровня, 'conv1-7x7_s2'.

convLayer = 'conv1-7x7_s2';

imgClass = 'ARR';
imgName = 'ARR_10.jpg';
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
montage(rescale(trainingFeaturesARR),'Size',[8 8])
title([imgClass,' Activations'])

Найдите самый сильный канал для этого изображения. Сравните самый сильный канал с исходным изображением.

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure;
imshowpair(imarr,arrMax,'montage')
title(['Strongest ',imgClass,' Channel: ',num2str(maxValueIndex)])

SqueezeNet

SqueezeNet - глубокий CNN, архитектура которого поддерживает изображения размера 227 на 227 на 3. Несмотря на то, что размеры изображения отличаются для GoogLeNet, нет необходимости создавать новые изображения RGB для размеров SqueeeNet. Можно использовать исходные изображения RGB.

Груз

Загрузите предварительно обученную нейронную сеть SqueeExNet. Если Deep Learning Toolbox™ Model для пакета поддержки сети SqueeeNet не установлен, программное обеспечение предоставляет ссылку на требуемый пакет поддержки в проводнике Add-On. Чтобы установить пакет поддержки, щелкните ссылку и нажмите кнопку Установить.

sqz = squeezenet;

Извлеките график слоев из сети. Подтвердите, что SqueeENet имеет меньше слоев, чем GoogLeNet. Также убедитесь, что SqueeEcNet настроен для изображений размером 227-на-227-на-3

lgraphSqz = layerGraph(sqz);
disp(['Number of Layers: ',num2str(numel(lgraphSqz.Layers))])
Number of Layers: 68
disp(lgraphSqz.Layers(1).InputSize)
   227   227     3

Изменение параметров сети SqueeENet

Чтобы переобучить SqueeEcNet для классификации новых изображений, внесите изменения, аналогичные тем, которые сделаны для GoogLeNet.

Проверьте последние шесть сетевых уровней.

lgraphSqz.Layers(end-5:end)
ans = 
  6x1 Layer array with layers:

     1   'drop9'                             Dropout                 50% dropout
     2   'conv10'                            Convolution             1000 1x1x512 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'                       ReLU                    ReLU
     4   'pool10'                            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'                              Softmax                 softmax
     6   'ClassificationLayer_predictions'   Classification Output   crossentropyex with 'tench' and 999 other classes

Замените 'drop9' уровень, последний уровень отсева в сети, с уровнем отсева вероятности 0,6.

tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,'Name','new_dropout');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

В отличие от GoogLeNet, последний обучаемый уровень в SqueeEcNet является сверточным уровнем 1 на 1, 'conv10', а не полностью подключенного слоя. Замените 'conv10' слой с новым сверточным слоем с числом фильтров, равным числу классов. Как это было сделано с GoogLeNet, увеличить коэффициенты скорости обучения нового слоя.

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

Замените классификационный слой новым без меток классов.

tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

Проверьте последние шесть уровней сети. Подтвердите изменение уровней отсева, свертки и вывода.

lgraphSqz.Layers(63:68)
ans = 
  6x1 Layer array with layers:

     1   'new_dropout'       Dropout                 60% dropout
     2   'new_conv'          Convolution             3 1x1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'       ReLU                    ReLU
     4   'pool10'            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'              Softmax                 softmax
     6   'new_classoutput'   Classification Output   crossentropyex

Подготовка данных RGB для SqueeNet

Образы RGB имеют размеры, соответствующие архитектуре GoogLeNet. Создайте хранилища данных дополненных изображений, которые автоматически изменяют размер существующих образов RGB для архитектуры SqueeNet. Дополнительные сведения см. в разделе augmentedImageDatastore.

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Настройка параметров обучения и SqueeNet поезда

Создайте новый набор параметров обучения для использования с SqueeEcNet. Установите для случайного начального значения значение по умолчанию и выполните обучение сети. Процесс обучения обычно занимает 1-5 минут на настольном процессоре.

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions('sgdm',...
    'MiniBatchSize',miniBatchSize,...
    'MaxEpochs',maxEpochs,...
    'InitialLearnRate',ilr,...
    'ValidationData',augimgsValidation,...
    'ValidationFrequency',valFreq,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');

rng default
trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |       20.00% |       43.75% |       5.2508 |       1.2540 |          0.0003 |
|       1 |          13 |       00:00:11 |       60.00% |       50.00% |       0.9912 |       1.0519 |          0.0003 |
|       2 |          26 |       00:00:20 |       60.00% |       59.38% |       0.8554 |       0.8497 |          0.0003 |
|       3 |          39 |       00:00:30 |       60.00% |       59.38% |       0.8120 |       0.8328 |          0.0003 |
|       4 |          50 |       00:00:38 |       50.00% |              |       0.7885 |              |          0.0003 |
|       4 |          52 |       00:00:40 |       60.00% |       65.63% |       0.7091 |       0.7314 |          0.0003 |
|       5 |          65 |       00:00:49 |       90.00% |       87.50% |       0.4639 |       0.5893 |          0.0003 |
|       6 |          78 |       00:00:59 |       70.00% |       87.50% |       0.6021 |       0.4355 |          0.0003 |
|       7 |          91 |       00:01:08 |       90.00% |       90.63% |       0.2307 |       0.2945 |          0.0003 |
|       8 |         100 |       00:01:15 |       90.00% |              |       0.1827 |              |          0.0003 |
|       8 |         104 |       00:01:18 |       90.00% |       93.75% |       0.2139 |       0.2153 |          0.0003 |
|       9 |         117 |       00:01:28 |      100.00% |       90.63% |       0.0521 |       0.1964 |          0.0003 |
|      10 |         130 |       00:01:38 |       90.00% |       90.63% |       0.1134 |       0.2214 |          0.0003 |
|      11 |         143 |       00:01:47 |      100.00% |       90.63% |       0.0855 |       0.2095 |          0.0003 |
|      12 |         150 |       00:01:52 |       90.00% |              |       0.2394 |              |          0.0003 |
|      12 |         156 |       00:01:57 |      100.00% |       90.63% |       0.0606 |       0.1849 |          0.0003 |
|      13 |         169 |       00:02:06 |      100.00% |       90.63% |       0.0090 |       0.2071 |          0.0003 |
|      14 |         182 |       00:02:16 |      100.00% |       93.75% |       0.0127 |       0.3597 |          0.0003 |
|      15 |         195 |       00:02:25 |      100.00% |       93.75% |       0.0016 |       0.3414 |          0.0003 |
|======================================================================================================================|

Проверьте последний уровень сети. Подтвердите, что уровень «Classification Output» включает три класса.

trainedSN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Оценка точности SqueeENet

Анализ сети с использованием данных проверки.

[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['SqueezeNet Accuracy: ',num2str(100*accuracy),'%'])
SqueezeNet Accuracy: 93.75%

Заключение

Этот пример показывает, как использовать обучение передаче и непрерывный вейвлет-анализ для классификации трех классов сигналов ЭКГ, используя предварительно обученные CNN GoogLeNet и SqueeEcNet. Для создания скалограмм используют вейвлет-основанные частотно-временные представления ЭКГ-сигналов. Формируют изображения скалограмм в формате RGB. Изображения используются для точной настройки обоих глубоких CNN. Были также изучены возможности активации различных сетевых уровней.

Этот пример иллюстрирует один возможный рабочий процесс, который можно использовать для классификации сигналов с использованием предварительно подготовленных моделей CNN. Возможны и другие рабочие процессы. Развертывание классификатора сигналов на NVIDIA Jetson Использование вейвлет-анализа и глубокого обучения (Wavelet Toolbox) и Развертывание классификатора сигналов с помощью вейвлетов и глубокого обучения на Raspberry Pi (Wavelet Toolbox) показывают, как развернуть код на оборудовании для классификации сигналов. GoogLeNet и SqueeNet являются моделями, предварительно подготовленными в подмножестве базы данных ImageNet [10], которое используется в вызове масштабного визуального распознавания ImageNet (ILSVRC) [8]. Коллекция ImageNet содержит изображения реальных объектов, таких как рыбы, птицы, приборы и грибы. Скалограммы выходят за пределы класса реальных объектов. Для того, чтобы вписаться в архитектуру GoogLeNet и SqueEcNet, скалограммы также подверглись сокращению данных. Вместо точной настройки предварительно обученных CNN для различения различных классов скалограмм, обучение CNN с нуля при исходных размерах скалограммы является опцией.

Ссылки

  1. Баим, Д. С., В. С. Колуччи, Э. С. Монрад, Х. С. Смит, Р. Ф. Райт, А. Лануэ, Д. Ф. Готье, Б. Дж. Рансил, В. Гроссман и Э. Браунвальд. «Выживаемость пациентов с тяжелой застойной сердечной недостаточностью, получавших пероральный милринон». Журнал Американского колледжа кардиологов. Том 7, номер 3, 1986, стр. 661-670.

  2. Энгин, М. «Классификация биений ЭКГ с использованием нейро-нечеткой сети». Буквы распознавания образов. Том 25, номер 15, 2004, ст.1715-1722.

  3. Гольдбергер А. Л., Л. А. Н. Амарал, Л. Гласс, Ж. М. Хаусдорф, П. Ч. Иванов, Р. Г. Марк, Ж. Е. Миетус, Г. Б. Муди, К. -К. Пэн и Х. Э. Стэнли. «PhysioBank, PhysioToolkit и PhysioNet: компоненты нового исследовательского ресурса для сложных физиологических сигналов». Циркуляция. Том 101, номер 23: e215-e220. [Электронные страницы тиража ;http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (13 июня). дои: 10.1161/01.CIR.101.23.e215.

  4. Леонардуцци, Р. Ф., Г. Шлоттауэр и М. Э. Торрес. «Многофактальный анализ вариабельности сердечного ритма во время ишемии миокарда на основе вейвлет-лидеров». В области инженерии в Обществе медицины и биологии (EMBC), Ежегодная международная конференция IEEE, 110-113. Буэнос-Айрес, Аргентина: IEEE, 2010.

  5. Ли, Т. и М. Чжоу. «Классификация ЭКГ с использованием вейвлет-пакетной энтропии и случайных лесов». Энтропия. Том 18, номер 8, 2016, стр. 285.

  6. Махарадж, Э. А. и А. М. Алонсо. «Дискриминантный анализ многомерных временных рядов: Применение к диагностике на основе сигналов ЭКГ». Вычислительная статистика и анализ данных. Том 70, 2014, стр. 67-87.

  7. Муди, Г. Б. и Р. Г. Марк. «Влияние базы данных аритмии MIT-BIH». IEEE Engineering in Medicine and Biology Magazine. Том 20. Номер 3, май-июнь 2001 года, стр. 45-50. (PMID: 11446209)

  8. Руссаковский, О., Дж. Денг и Х. Су и др. «Задача масштабного визуального распознавания ImageNet». Международный журнал компьютерного зрения. Том 115, номер 3, 2015, стр. 211-252.

  9. Чжао, К. и Л. Чжан. «Извлечение и классификация элементов ЭКГ с использованием вейвлет-преобразования и поддержки векторных машин». В IEEE Международная конференция по нейронным сетям и мозгу, 1089-1092. Пекин, Китай: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

Вспомогательные функции

helperCreateECGDataDirectory создает каталог данных внутри родительского каталога, затем создает три подкаталога внутри каталога данных. Подкаталоги именуются после каждого класса сигнала ЭКГ, найденного в ECGData.

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps строит графики первой тысячи образцов представителя каждого класса сигнала ЭКГ, обнаруженного в ECGData.

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRUREstartTF использует cwtfilterbank (Vavelet Toolbox) для получения непрерывного вейвлет-преобразования ЭКГ-сигналов и генерирует скалограммы из вейвлет-коэффициентов. Вспомогательная функция изменяет размер скалограмм и записывает их на диск в виде изображений jpeg.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank('SignalLength',signalLength,'VoicesPerOctave',12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(im2uint8(rescale(cfs)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = strcat(char(labels(ii)),'_',num2str(ii),'.jpg');
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end

См. также

| | | | | | (инструментарий вейвлета)

Связанные темы