Классифицируйте временные ряды Используя анализ вейвлета и глубокое обучение

В этом примере показано, как классифицировать человеческую электрокардиограмму (ECG) сигналы с помощью непрерывного вейвлета преобразовывает (CWT) и глубокой сверточной нейронной сети (CNN).

Обучение глубокий CNN с нуля является в вычислительном отношении дорогим и требует большой суммы обучающих данных. В различных приложениях достаточная сумма обучающих данных не доступна, и синтезирование новых реалистических учебных примеров не выполнимо. В этих случаях, усиливая существующие нейронные сети, которые были обучены на больших наборах данных для концептуально подобных задач, желательно. Это усиление существующих нейронных сетей называется передачей обучения. В этом примере мы адаптируем два глубоких CNNs, GoogLeNet и SqueezeNet, предварительно обученный распознаванию изображений классифицировать формы волны ECG на основе представления частоты времени.

GoogLeNet и SqueezeNet являются глубоким CNNs, первоначально спроектированным, чтобы классифицировать изображения на 1 000 категорий. Мы снова используем сетевую архитектуру CNN, чтобы классифицировать сигналы ECG на основе изображений от CWT данных временных рядов. Данные, используемые в этом примере, общедоступны от PhysioNet.

Описание данных

В этом примере вы используете данные о ECG, полученные из трех групп людей: люди с сердечной аритмией (ARR), люди с застойной сердечной недостаточностью (CHF) и люди с нормальными ритмами пазухи (NSR). Всего вы используете 162 записи ECG от трех Физиосетевых баз данных: База данных Аритмии MIT-BIH [3] [7], MIT-BIH Нормальная База данных Ритма Пазухи [3] и База данных Застойной сердечной недостаточности BIDMC [1] [3]. А именно, 96 записей от людей с аритмией, 30 записей от людей с застойной сердечной недостаточностью и 36 записей от людей с нормальными ритмами пазухи. Цель состоит в том, чтобы обучить классификатор различать ARR, швейцарский франк и NSR.

Загрузите данные

Первый шаг должен загрузить данные из репозитория GitHub. Чтобы загрузить данные из веб-сайта, нажмите Code и выберите Download ZIP. Сохраните файл physionet_ECG_data-master.zip в папке, где у вас есть разрешение записи. Инструкции для этого примера принимают, что вы загрузили файл на свою временную директорию, tempdir, в MATLAB. Измените последующие инструкции для того, чтобы разархивировать и загрузить данные, если вы принимаете решение загрузить данные в папке, отличающейся от tempdir.

После загрузки данных GitHub разархивируйте файл в своей временной директории.

unzip(fullfile(tempdir,'physionet_ECG_data-master.zip'),tempdir)

Разархивация создает папку physionet-ECG_data-master в вашей временной директории. Эта папка содержит текстовый файл README.md и ECGData.zip. ECGData.zip файл содержит

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat содержит данные, используемые в этом примере. Текстовый файл, Modified_physionet_data.txt, требуется копированием PhysioNet политики и обеспечивает исходные приписывания для данных, а также описание шагов предварительной обработки применилось к каждой записи ECG.

Разархивируйте ECGData.zip в physionet-ECG_data-master. Загрузите файл данных в свое рабочее пространство MATLAB.

unzip(fullfile(tempdir,'physionet_ECG_data-master','ECGData.zip'),...
    fullfile(tempdir,'physionet_ECG_data-master'))
load(fullfile(tempdir,'physionet_ECG_data-master','ECGData.mat'))

ECGData массив структур с двумя полями: Data и Labels. Data поле 162 65536 матрица, где каждая строка является записью ECG, произведенной на уровне 128 герц. Labels 162 1 массив ячеек диагностических меток, один для каждой строки Data. Три диагностических категории: 'ARR', 'CHF', и 'NSR'.

Чтобы хранить предварительно обработанные данные каждой категории, сначала создайте директорию dataDir данных о ECG в tempdir. Затем создайте три подкаталога в 'data' названный в честь каждой категории ECG. Функция помощника helperCreateECGDirectories делает это. helperCreateECGDirectories принимает ECGData, имя директории данных о ECG и имя родительского каталога как входные параметры. Можно заменить tempdir с другой директорией, где у вас есть разрешение записи. Можно найти исходный код для этой функции помощника в разделе Supporting Functions в конце этого примера.

parentDir = tempdir;
dataDir = 'data';
helperCreateECGDirectories(ECGData,parentDir,dataDir)

Постройте представителя каждой категории ECG. Функция помощника helperPlotReps делает это. helperPlotReps принимает ECGData как введено. Можно найти исходный код для этой функции помощника в разделе Supporting Functions в конце этого примера.

helperPlotReps(ECGData)

Создайте представления частоты времени

После создания папок создайте представления частоты времени сигналов ECG. Эти представления называются scalograms. scalogram является абсолютным значением коэффициентов CWT сигнала.

Чтобы создать scalograms, предварительно вычислите набор фильтров CWT. Предварительное вычисление набора фильтров CWT является предпочтительным методом при получении CWT многих сигналов с помощью тех же параметров.

Прежде, чем сгенерировать scalograms, исследуйте одного из них. Создайте набор фильтров CWT с помощью cwtfilterbank для сигнала с 1 000 выборок. Используйте набор фильтров, чтобы взять CWT первых 1 000 выборок сигнала и получить scalogram из коэффициентов.

Fs = 128;
fb = cwtfilterbank('SignalLength',1000,...
    'SamplingFrequency',Fs,...
    'VoicesPerOctave',12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs))
set(gca,'yscale','log');shading interp;axis tight;
title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')

Используйте функцию помощника helperCreateRGBfromTF чтобы создать scalograms как, RGB отображает, и запишите им в соответствующий подкаталог в dataDir. Исходный код для этой функции помощника находится в разделе Supporting Functions в конце этого примера. Чтобы быть совместимым с архитектурой GoogLeNet, каждое изображение RGB является массивом размера 224 224 3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

Разделитесь на данные об обучении и валидации

Загрузите изображения scalogram как datastore изображений. imageDatastore функционируйте автоматически помечает изображения на основе имен папок и хранит данные как объект ImageDatastore. Datastore изображений позволяет вам сохранить большие данные изображения, включая данные, которые не умещаются в памяти, и эффективно считать пакеты изображений во время обучения CNN.

allImages = imageDatastore(fullfile(parentDir,dataDir),...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');

Случайным образом разделите изображения на две группы, один для обучения и другого для валидации. Используйте 80% изображений для обучения и остатка для валидации. В целях воспроизводимости мы устанавливаем случайный seed на значение по умолчанию.

rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
Number of training images: 130
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
Number of validation images: 32

GoogLeNet

Загрузка

Загрузите предварительно обученную нейронную сеть GoogLeNet. Если Модель Deep Learning Toolbox™ для пакета Сетевой поддержки GoogLeNet не установлена, программное обеспечение обеспечивает ссылку на необходимый пакет поддержки в Add-On Explorer. Чтобы установить пакет поддержки, щелкните по ссылке, и затем нажмите Install.

net = googlenet;

Извлеките и отобразите график слоев от сети.

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);

Смотрите первый элемент свойства слоев сети. Подтвердите, что GoogLeNet требует изображений RGB размера 224 224 3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                Name: 'data'
           InputSize: [224 224 3]

   Hyperparameters
    DataAugmentation: 'none'
       Normalization: 'zerocenter'
                Mean: [224×224×3 single]

Измените параметры сети GoogLeNet

Каждый слой в сетевой архитектуре может быть рассмотрен фильтром. Более ранние слои идентифицируют более типичные функции изображений, такие как блобы, ребра и цвета. Последующие слои фокусируются на более определенных функциях для того, чтобы дифференцировать категории. GoogLeNet предварительно обучен классифицировать изображения в 1 000 категорий объектов. Необходимо переобучить GoogLeNet для нашей проблемы классификации ECG.

Чтобы предотвратить сверхподбор кривой, слой уволенного используется. Слой уволенного случайным образом обнуляет входные элементы с данной вероятностью. Смотрите dropoutLayer (Deep Learning Toolbox) для получения дополнительной информации. Вероятность по умолчанию 0.5. Замените итоговый слой уволенного в сети, 'pool5-drop_7x7_s1', со слоем уволенного вероятности 0.6.

newDropoutLayer = dropoutLayer(0.6,'Name','new_Dropout');
lgraph = replaceLayer(lgraph,'pool5-drop_7x7_s1',newDropoutLayer);

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить GoogLeNet, чтобы классифицировать изображения RGB, замените эти два слоя на новые слои, адаптированные к данным.

Замените полносвязный слой 'loss3-classifier' с новым полносвязным слоем с количеством фильтров равняются количеству классов. Чтобы учиться быстрее в новых слоях, чем в переданных слоях, увеличьте факторы скорости обучения полносвязного слоя.

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,'Name','new_fc',...
    'WeightLearnRateFactor',5,'BiasLearnRateFactor',5);
lgraph = replaceLayer(lgraph,'loss3-classifier',newConnectedLayer);

Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork автоматически устанавливает выходные классы слоя в учебное время.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);

Установите опции обучения и обучите GoogLeNet

Обучение нейронной сети является итеративным процессом, который включает минимизацию функции потерь. Чтобы минимизировать функцию потерь, алгоритм градиентного спуска используется. В каждой итерации оценен градиент функции потерь, и веса алгоритма спуска обновляются.

Обучение может быть настроено путем установки различных вариантов. InitialLearnRate задает начальный размер шага в направлении отрицательного градиента функции потерь. MiniBatchSize задает как большой из подмножества набора обучающих данных, чтобы использовать в каждой итерации. Одна эпоха является всей передачей алгоритма настройки по целому набору обучающих данных. MaxEpochs задает максимальное количество эпох, чтобы использовать для обучения. Выбор правильного номера эпох не является тривиальной задачей. Сокращение числа эпох оказывает влияние underfitting модель и увеличение числа результатов эпох в сверхподборе кривой.

Используйте trainingOptions (Deep Learning Toolbox) функция, чтобы задать опции обучения. Установите MiniBatchSize к 10, MaxEpochs к 10, и InitialLearnRate к 0,0001. Визуализируйте процесс обучения установкой Plots к training-progress. Используйте стохастический градиентный спуск с оптимизатором импульса. По умолчанию обучение сделано на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™. Чтобы видеть, который поддерживаются графические процессоры, смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В целях воспроизводимости, набор ExecutionEnvironment к cpu так, чтобы trainNetwork используемый центральный процессор. Установите случайный seed на значение по умолчанию. Время выполнения будет быстрее, если вы сможете использовать графический процессор.

options = trainingOptions('sgdm',...
    'MiniBatchSize',15,...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4,...
    'ValidationData',imgsValidation,...
    'ValidationFrequency',10,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');
rng default

Обучите сеть. Учебный процесс обычно занимает 1-5 минут на настольном центральном процессоре. Командное окно отображает учебную информацию во время запуска. Результаты включают номер эпохи, номер итерации, время протекло, мини-пакетная точность, точность валидации и значение функции потерь для данных о валидации.

trainedGN = trainNetwork(imgsTrain,lgraph,options);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        6.67% |       18.75% |       4.9207 |       2.4141 |      1.0000e-04 |
|       2 |          10 |       00:00:23 |       66.67% |       62.50% |       0.9589 |       1.3191 |      1.0000e-04 |
|       3 |          20 |       00:00:43 |       46.67% |       75.00% |       1.2973 |       0.5928 |      1.0000e-04 |
|       4 |          30 |       00:01:04 |       60.00% |       78.13% |       0.7219 |       0.4576 |      1.0000e-04 |
|       5 |          40 |       00:01:25 |       73.33% |       84.38% |       0.4750 |       0.3367 |      1.0000e-04 |
|       7 |          50 |       00:01:46 |       93.33% |       84.38% |       0.2714 |       0.2892 |      1.0000e-04 |
|       8 |          60 |       00:02:07 |       80.00% |       87.50% |       0.3617 |       0.2433 |      1.0000e-04 |
|       9 |          70 |       00:02:29 |       86.67% |       87.50% |       0.3246 |       0.2526 |      1.0000e-04 |
|      10 |          80 |       00:02:50 |      100.00% |       96.88% |       0.0701 |       0.1876 |      1.0000e-04 |
|      12 |          90 |       00:03:11 |       86.67% |      100.00% |       0.2836 |       0.1681 |      1.0000e-04 |
|      13 |         100 |       00:03:32 |       86.67% |       96.88% |       0.4160 |       0.1607 |      1.0000e-04 |
|      14 |         110 |       00:03:53 |       86.67% |       96.88% |       0.3237 |       0.1565 |      1.0000e-04 |
|      15 |         120 |       00:04:14 |       93.33% |       96.88% |       0.1646 |       0.1476 |      1.0000e-04 |
|      17 |         130 |       00:04:35 |      100.00% |       96.88% |       0.0551 |       0.1330 |      1.0000e-04 |
|      18 |         140 |       00:04:57 |       93.33% |       96.88% |       0.0927 |       0.1347 |      1.0000e-04 |
|      19 |         150 |       00:05:18 |       93.33% |       93.75% |       0.1666 |       0.1325 |      1.0000e-04 |
|      20 |         160 |       00:05:39 |       93.33% |       96.88% |       0.0873 |       0.1164 |      1.0000e-04 |
|======================================================================================================================|

Смотрите последний слой обучившего сеть. Подтвердите, что слой Classification Output включает эти три класса.

trainedGN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Оцените точность GoogLeNet

Оцените сеть с помощью данных о валидации.

[YPred,probs] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['GoogLeNet Accuracy: ',num2str(100*accuracy),'%'])
GoogLeNet Accuracy: 96.875%

Точность идентична точности валидации, о которой сообщают относительно учебной фигуры визуализации. scalograms были разделены в наборы обучения и валидации. Оба набора использовались, чтобы обучить GoogLeNet. Идеальный способ оценить результат обучения состоит в том, чтобы иметь сеть, классифицируют данные, которые это не видело. С тех пор существует недостаточный объем данных, чтобы разделиться на обучение, валидацию и тестирование, мы обрабатываем вычисленную точность валидации как сетевую точность.

Исследуйте активации GoogLeNet

Каждый слой CNN производит ответ или активацию, к входному изображению. Однако существует только несколько слоев в CNN, которые подходят для извлечения признаков изображений. Слои в начале сетевого получения основные функции изображений, такие как ребра и блобы. Чтобы видеть это, визуализируйте сетевые веса фильтра из первого сверточного слоя. Существует 64 отдельных набора весов в первом слое.

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,5);
figure
montage(wghts)
title('First Convolutional Layer Weights')

Можно исследовать активации и обнаружить, какие функции GoogLeNet изучает путем сравнения областей активации с оригинальным изображением. Для получения дополнительной информации смотрите, Визуализируют Активации Сверточной нейронной сети (Deep Learning Toolbox) и Визуализируют Функции Сверточной нейронной сети (Deep Learning Toolbox).

Исследуйте, какие области в сверточных слоях активируются на изображении от ARR класс. Сравните с соответствующими областями в оригинальном изображении. Каждый слой сверточной нейронной сети состоит из названных каналов многих 2D массивов. Передайте изображение через сеть и исследуйте выходные активации первого сверточного слоя, 'conv1-7x7_s2'.

convLayer = 'conv1-7x7_s2';

imgClass = 'ARR';
imgName = 'ARR_10.jpg';
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
montage(rescale(trainingFeaturesARR),'Size',[8 8])
title([imgClass,' Activations'])

Найдите самый сильный канал для этого изображения. Сравните самый сильный канал с оригинальным изображением.

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure;
imshowpair(imarr,arrMax,'montage')
title(['Strongest ',imgClass,' Channel: ',num2str(maxValueIndex)])

SqueezeNet

SqueezeNet является глубоким CNN, архитектура которого поддерживает изображения размера 227 227 3. Даже при том, что размеры изображения отличаются для GoogLeNet, вы не должны генерировать новые изображения RGB в размерностях SqueezeNet. Можно использовать исходные изображения RGB.

Загрузка

Загрузите предварительно обученную нейронную сеть SqueezeNet. Если Модель Deep Learning Toolbox™ для пакета Сетевой поддержки SqueezeNet не установлена, программное обеспечение обеспечивает ссылку на необходимый пакет поддержки в Add-On Explorer. Чтобы установить пакет поддержки, щелкните по ссылке, и затем нажмите Install.

sqz = squeezenet;

Извлеките график слоев из сети. Подтвердите, что SqueezeNet имеет меньше слоев, чем GoogLeNet. Также подтвердите, что SqueezeNet сконфигурирован для изображений размера 227 227 3

lgraphSqz = layerGraph(sqz);
disp(['Number of Layers: ',num2str(numel(lgraphSqz.Layers))])
Number of Layers: 68
disp(lgraphSqz.Layers(1).InputSize)
   227   227     3

Измените параметры сети SqueezeNet

Чтобы переобучить SqueezeNet, чтобы классифицировать новые изображения, делайте изменения похожими на сделанных для GoogLeNet.

Смотрите последние шесть слоев сети.

lgraphSqz.Layers(end-5:end)
ans = 
  6x1 Layer array with layers:

     1   'drop9'                             Dropout                 50% dropout
     2   'conv10'                            Convolution             1000 1x1x512 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'                       ReLU                    ReLU
     4   'pool10'                            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'                              Softmax                 softmax
     6   'ClassificationLayer_predictions'   Classification Output   crossentropyex with 'tench' and 999 other classes

Замените 'drop9' слой, последний слой уволенного в сети, со слоем уволенного вероятности 0.6.

tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,'Name','new_dropout');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

В отличие от GoogLeNet, последний learnable слой в SqueezeNet является сверточным слоем 1 на 1, 'conv10', и не полносвязный слой. Замените 'conv10' слой с новым сверточным слоем с количеством фильтров равняется количеству классов. Как был сделан с GoogLeNet, увеличьте факторы скорости обучения нового слоя.

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

Замените слой классификации на новый без меток класса.

tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

Смотрите последние шесть слоев сети. Подтвердите уволенного, сверточного, и выведите слои, были изменены.

lgraphSqz.Layers(63:68)
ans = 
  6x1 Layer array with layers:

     1   'new_dropout'       Dropout                 60% dropout
     2   'new_conv'          Convolution             3 1x1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'       ReLU                    ReLU
     4   'pool10'            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]
     5   'prob'              Softmax                 softmax
     6   'new_classoutput'   Classification Output   crossentropyex

Подготовьте данные о RGB к SqueezeNet

Изображения RGB имеют размерности, подходящие для архитектуры GoogLeNet. Создайте увеличенные хранилища данных изображений, которые автоматически изменяют размер существующих изображений RGB для архитектуры SqueezeNet. Для получения дополнительной информации смотрите augmentedImageDatastore (Deep Learning Toolbox).

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Установите опции обучения и обучите SqueezeNet

Создайте новый набор опций обучения, чтобы использовать с SqueezeNet. Установите случайный seed на значение по умолчанию и обучите сеть. Учебный процесс обычно занимает 1-5 минут на настольном центральном процессоре.

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions('sgdm',...
    'MiniBatchSize',miniBatchSize,...
    'MaxEpochs',maxEpochs,...
    'InitialLearnRate',ilr,...
    'ValidationData',augimgsValidation,...
    'ValidationFrequency',valFreq,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');

rng default
trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);

Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:01 |       20.00% |       43.75% |       5.2508 |       1.2540 |          0.0003 |
|       1 |          13 |       00:00:11 |       60.00% |       50.00% |       0.9912 |       1.0519 |          0.0003 |
|       2 |          26 |       00:00:20 |       60.00% |       59.38% |       0.8554 |       0.8497 |          0.0003 |
|       3 |          39 |       00:00:30 |       60.00% |       59.38% |       0.8120 |       0.8328 |          0.0003 |
|       4 |          50 |       00:00:38 |       50.00% |              |       0.7885 |              |          0.0003 |
|       4 |          52 |       00:00:40 |       60.00% |       65.63% |       0.7091 |       0.7314 |          0.0003 |
|       5 |          65 |       00:00:49 |       90.00% |       87.50% |       0.4639 |       0.5893 |          0.0003 |
|       6 |          78 |       00:00:59 |       70.00% |       87.50% |       0.6021 |       0.4355 |          0.0003 |
|       7 |          91 |       00:01:08 |       90.00% |       90.63% |       0.2307 |       0.2945 |          0.0003 |
|       8 |         100 |       00:01:15 |       90.00% |              |       0.1827 |              |          0.0003 |
|       8 |         104 |       00:01:18 |       90.00% |       93.75% |       0.2139 |       0.2153 |          0.0003 |
|       9 |         117 |       00:01:28 |      100.00% |       90.63% |       0.0521 |       0.1964 |          0.0003 |
|      10 |         130 |       00:01:38 |       90.00% |       90.63% |       0.1134 |       0.2214 |          0.0003 |
|      11 |         143 |       00:01:47 |      100.00% |       90.63% |       0.0855 |       0.2095 |          0.0003 |
|      12 |         150 |       00:01:52 |       90.00% |              |       0.2394 |              |          0.0003 |
|      12 |         156 |       00:01:57 |      100.00% |       90.63% |       0.0606 |       0.1849 |          0.0003 |
|      13 |         169 |       00:02:06 |      100.00% |       90.63% |       0.0090 |       0.2071 |          0.0003 |
|      14 |         182 |       00:02:16 |      100.00% |       93.75% |       0.0127 |       0.3597 |          0.0003 |
|      15 |         195 |       00:02:25 |      100.00% |       93.75% |       0.0016 |       0.3414 |          0.0003 |
|======================================================================================================================|

Смотрите последний слой сети. Подтвердите, что слой Classification Output включает эти три класса.

trainedSN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Оцените точность SqueezeNet

Оцените сеть с помощью данных о валидации.

[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp(['SqueezeNet Accuracy: ',num2str(100*accuracy),'%'])
SqueezeNet Accuracy: 93.75%

Заключение

В этом примере показано, как использовать передачу обучения, и непрерывный анализ вейвлета, чтобы классифицировать три класса ECG сигнализирует путем усиления предварительно обученного CNNs GoogLeNet и SqueezeNet. Основанные на вейвлете представления частоты времени сигналов ECG используются, чтобы создать scalograms. Изображения RGB scalograms сгенерированы. Изображения используются, чтобы подстроить обоих глубоко CNNs. Активации различных слоев сети также исследовались.

Этот пример иллюстрирует один возможный рабочий процесс, который можно использовать для классификации сигналов с помощью предварительно обученных моделей CNN. Другие рабочие процессы возможны. Разверните Классификатор Сигнала на NVIDIA Джетсон Используя Анализ Вейвлета и Глубокое обучение и Разверните Классификатор Сигнала Используя Вейвлеты, и Глубокое обучение на Raspberry Pi показывают, как развернуть код на оборудование для классификации сигнала. GoogLeNet и SqueezeNet являются моделями, предварительно обученными на подмножестве базы данных ImageNet [10], который используется в ImageNet крупномасштабной визуальной проблеме распознавания (ILSVRC) [8]. Набор ImageNet содержит изображения реальных объектов, такие как рыба, птицы, устройства и грибы. Scalograms выходят за пределы класса реальных объектов. Для того, чтобы поместиться в архитектуру GoogLeNet и SqueezeNet, scalograms также подвергся снижению объема данных. Вместо того, чтобы подстроить предварительно обучил CNNs отличать различные классы scalograms, обучение, CNN с нуля в исходных scalogram размерностях является опцией.

Ссылки

  1. Baim, D. S. В. С. Колуччи, Э. С. Монрэд, Х. С. Смит, Р. Ф. Райт, А. Лэноу, Д. Ф. Готье, Б. Дж. Рэнсил, В. Гроссман и Э. Бронвалд. "Выживание пациентов с тяжелой застойной сердечной недостаточностью отнеслось с устным milrinone". Журнал американского Колледжа Кардиологии. Издание 7, Номер 3, 1986, стр 661–670.

  2. Engin, M. "ECG разбил классификацию с помощью нейронечеткой сети". Буквы Распознавания образов. Издание 25, Номер 15, 2004, pp.1715-1722.

  3. Голдбергер А. Л., Л. А. Н. Амарал, L. Стекло, Дж. М. Гаусдорф, P. Ch. Иванов, Р. Г. Марк, Дж. Э. Митус, Г. Б. Муди, C.-K. Пенг и Х. Э. Стэнли. "PhysioBank, PhysioToolkit и PhysioNet: Компоненты Нового Ресурса Исследования для Комплексных Физиологических Сигналов". Циркуляция. Издание 101, Номер 23: e215–e220. [Циркуляция Электронные Страницы; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (13 июня). doi: 10.1161/01. CIR.101.23.e215.

  4. Leonarduzzi, R. F. Г. Шлоттоер и М. Э. Торрес. "Вейвлет основанный на лидере мультифрактальный анализ изменчивости сердечного ритма во время миокардиальной ишемии". В Разработке в Обществе Медицины и Биологии (EMBC), Ежегодной Международной конференции IEEE, 110–113. Буэнос-Айрес, Аргентина: IEEE, 2010.

  5. Литий, T. и М. Чжоу. "Классификация ECG с помощью пакета вейвлета энтропийные и случайные леса". Энтропия. Издание 18, Номер 8, 2016, p.285.

  6. Махарадж, E. A., и утра Алонсо. "Дискриминантный анализ многомерных временных рядов: Приложение к диагнозу на основе сигналов ECG". Вычислительная Статистика и Анализ данных. Издание 70, 2014, стр 67–87.

  7. Капризный, G. B. и Р. Г. Марк. "Удар Базы данных Аритмии MIT-BIH". Разработка IEEE в Журнале Медицины и Биологии. Издание 20. Номер 3, мочь-июнь 2001, стр 45–50. (PMID: 11446209)

  8. Russakovsky, O., Цз. Дэн и Х. Су и др. "Крупный масштаб ImageNet Визуальная проблема Распознавания". Международный журнал Компьютерного зрения. Издание 115, Номер 3, 2015, стр 211–252.

  9. Чжао, Q. и Л. Чжан. "Извлечение признаков ECG и классификация с помощью вейвлета преобразовывают и машины опорных векторов". На Международной конференции IEEE по вопросам Нейронных сетей и Мозга, 1089–1092. Пекин, Китай: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

Вспомогательные Функции

helperCreateECGDataDirectories создает директорию данных в родительском каталоге, затем создает три подкаталога в директории данных. Подкаталоги называют в честь каждого класса сигнала ECG, найденного в ECGData.

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps строит первую тысячу выборок представителя каждого класса сигнала ECG, найденного в ECGData.

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTF использует cwtfilterbank получить непрерывное преобразование вейвлета ECG сигнализирует и генерирует scalograms от коэффициентов вейвлета. Функция помощника изменяет размер scalograms и пишет им в диск как jpeg изображения.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank('SignalLength',signalLength,'VoicesPerOctave',12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(im2uint8(rescale(cfs)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = strcat(char(labels(ii)),'_',num2str(ii),'.jpg');
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end

Смотрите также

| | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы