Генерация кода для глубокого обучения модель Simulink, чтобы классифицировать сигналы ECG

Этот пример демонстрирует, как можно использовать мощные методы обработки сигналов и Сверточные нейронные сети вместе, чтобы классифицировать сигналы ECG. Мы также продемонстрируем, как код CUDA® может быть сгенерирован из модели Simulink®. Этот пример использует предварительно обученную сеть CNN от Классифицировать Временных рядов Используя пример Анализа и Глубокого обучения Вейвлета Wavelet Toolbox™, чтобы классифицировать сигналы ECG на основе изображений от CWT данных временных рядов. Для получения информации об обучении смотрите, Классифицируют Временные ряды Используя Анализ Вейвлета и Глубокое обучение (Wavelet Toolbox).

Этот пример иллюстрирует следующие концепции:

  • Смоделируйте приложение классификации в Simulink путем предварительной обработки и выполнения преобразований вейвлета данных о ECG и Image Classifier блок из Deep Learning Toolbox™ для загрузки предварительно обученной сети и выполнения классификации данных о ECG.

  • Сконфигурируйте модель для генерации кода.

  • Сгенерируйте исполняемый файл CUDA для модели Simulink.

Сторонние необходимые условия

Проверьте среду графического процессора

Чтобы проверить, что компиляторы и библиотеки, необходимые для выполнения этого примера, настраиваются правильно, используйте coder.checkGpuInstall функция.

envCfg = coder.gpuEnvConfig('host');
envCfg.DeepLibTarget = 'cudnn';
envCfg.DeepCodegen = 1;
envCfg.Quiet = 1;
coder.checkGpuInstall(envCfg);

Описание данных о ECG

Этот пример использует данные о ECG из Физиосетевой базы данных. Это содержит данные из трех групп людей:

  1. Люди с сердечной аритмией (ARR)

  2. Люди с застойной сердечной недостаточностью (CHF)

  3. Люди с нормальными ритмами пазухи (NSR)

Это включает 96 записей от людей с ARR, 30 записей от людей со швейцарским франком и 36 записей от людей с NSR. ecg_signals MAT-файл содержит тестовые данные о ECG в формате временных рядов. Классификатор изображений в этом примере различает ARR, швейцарский франк и NSR.

Алгоритмический рабочий процесс

Блок-схему для алгоритмического рабочего процесса модели Simulink показывают.

Глубокое обучение ECG модель Simulink

Модель Simulink для классификации сигналов ECG показывают. Когда модель запускается, Video Viewer блок отображает классифицированный сигнал ECG.

open_system('ecg_dl_cwt');

Подсистема предварительной обработки ECG

ECG Preprocessing подсистема содержит MATLAB Function блокируйтесь, который выполняет CWT, чтобы получить scalogram сигнала ECG и затем обрабатывает scalogram, чтобы получить изображение и Image Classifier блокируйтесь, который загружает предварительно обученную сеть от trainedNet.mat и выполняет предсказание для классификации изображений на основе CNN глубокого обучения SqueezeNet.

open_system('ecg_dl_cwt/ECG Preprocessing');

ScalogramFromECG функциональный блок задает функцию под названием ecg_to_scalogram это:

  • Использование 65 536 выборок данных о ECG с двойной точностью, как введено.

  • Создайте представление частоты времени от данных о ECG путем применяния преобразования Вейвлета.

  • Получите scalogram из коэффициентов вейвлета.

  • Преобразуйте scalogram в изображение размера (227x227x3).

Функциональная подпись ecg_to_scalogram показан.

type ecg_to_scalogram
function ecg_image  = ecg_to_scalogram(ecg_signal)

% Copyright 2020 The MathWorks, Inc.

persistent jetdata;
if(isempty(jetdata))
    jetdata = colourmap(128,'single');
end
% Obtain wavelet coefficients from ECG signal
cfs = cwt_ecg(ecg_signal);  
% Obtain scalogram from wavelet coefficients
image = ind2rgb(im2uint8(rescale(cfs)),jetdata);
ecg_image = im2uint8(imresize(image,[227,227]));

end

Постобработка ECG

ECG Postprocessing Блок MATLAB function задает label_prob_image функция, которая находит метку для изображения scalogram на основе самого высокого счета от баллов выведенной классификатором изображений. Это выводит изображение scalogram с меткой и доверием, распечатанным на нем.

type label_prob_image
function final_image = label_prob_image(ecg_image, scores, labels)

% Copyright 2020 The MathWorks, Inc.

scores = double(scores);
% Obtain maximum confidence 
[prob,index] = max(scores);
confidence = prob*100;
% Obtain label corresponding to maximum confidence
label = erase(char(labels(index)),'_label');
text = cell(2,1);
text{1} = ['Classification: ' label];
text{2} = ['Confidence: ' sprintf('%0.2f',confidence) '%'];
position = [135 20 0 0; 130 40 0 0];
final_image = insertObjectAnnotation(ecg_image,'rectangle',position,text,'TextBoxOpacity',0.9,'FontSize',9);

end

Запустите симуляцию

Открытое диалоговое окно Configuration Parameters.

В Целевой панели Симуляции выберите ускорение GPU. В группе Глубокого обучения выберите целевую библиотеку как cuDNN.

Чтобы проверить алгоритм и отобразить метки и оценку достоверности тестового сигнала ECG, загруженного в рабочей области, запустите симуляцию.

set_param('ecg_dl_cwt', 'SimulationMode', 'Normal');
sim('ecg_dl_cwt');

Сгенерируйте и создайте модель Simulink

В панели Генерации кода выберите Language как C++ и включите, Генерируют код графического процессора.

Генерация Открытого кода> панель графического процессора Кода. В Библиотеках подкатегории включите cuBLAS, cuSOLVER и cuFFT.

Сгенерируйте и создайте модель Simulink на хосте графический процессор при помощи rtwbuild команда. Генератор кода помещает файлы в папку сборки, подпапку под названием ecg_dl_cwt_ert_rtw под вашей текущей рабочей папкой.

status = evalc("rtwbuild('ecg_dl_cwt')");

Сгенерированный код CUDA®

Подпапка под названием ecg_dl_cwt_ert_rtw содержит сгенерированные Коды С++, соответствующие различным блокам в модели Simulink и определенных операциях, выполняемых в тех блоках. Например, файл trainedNet0_ecg_dl_cwt0.h содержит класс C++, который содержит определенные атрибуты, такие как numLayers и функции членства, такие как getBatchSize(), predict(). Этот класс представляет предварительно обученный SqueezeNet который загрузился в модели Simulink.

Очистка

Закройте модель Simulink.

close_system('ecg_dl_cwt/ECG Preprocessing');
close_system('ecg_dl_cwt');
Для просмотра документации необходимо авторизоваться на сайте