В этом примере показано, как удалить электроокулограмму (EOG) шум от электроэнцефалограммы (EEG), сигналы с помощью EEGdenoiseNet тестируют набора данных в сравнении с эталоном [1] и регрессия глубокого обучения. Набор данных EEGdenoiseNet содержит 4 514 чистых сегментов EEG и 3 400 окулярных сегментов артефакта, которые могут использоваться, чтобы синтезировать шумные сегменты EEG с основной истиной, чистой EEG (набор данных также содержит мускульные сегменты артефакта, но они не будут использоваться в этом примере).
Этот пример использует чистые и EOG-загрязненные сигналы EEG обучить модель долгой краткосрочной памяти (LSTM) удалять артефакты EOG. Модель регрессии была обучена с необработанными входными сигналами и с сигналами, преобразованными кратковременным преобразованием Фурье (STFT). Модель STFT улучшает производительность особенно в ухудшенных значениях ОСШ.
Набор данных EEGdenoiseNet содержит 4 514 чистых сегментов EEG и 3 400 сегментов EOG, которые могут использоваться, чтобы сгенерировать три набора данных для обучения, проверки и тестирования модели глубокого обучения. Частота дискретизации всех сегментов сигнала составляет 256 Гц. Для удобства набор данных был загружен на это местоположение: https://ssd.mathworks.com/supportfiles/SPT/data/EEGEOGDenoisingData.zip
Загрузите набор данных с помощью downloadSupportFile
функция.
% Download the data datasetZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/EEGEOGDenoisingData.zip'); datasetFolder = fullfile(fileparts(datasetZipFile),'EEG_EOG_Denoising_Dataset'); if ~exist(datasetFolder,'dir') unzip(datasetZipFile,fileparts(datasetZipFile)); end
После загрузки данных, местоположения в datasetFolder
содержит два файла MAT:
EEG_all_epochs.mat
содержит матрицу с 4 514 чистыми сегментами EEG длины 512 выборок
EOG_all_epochs.mat
содержит матрицу с 3 400 сегментами EOG длины 512 выборок
Используйте createDataset
функция помощника, чтобы сгенерировать обучение, валидацию, и тестирующий наборы данных. Функциональные объединения чистят EEG и сигналы EOG сгенерировать пары чистых и шумных сегментов EEG с различными отношениями сигнал-шум (SNR). Для любого EEG и пары EOG можно использовать следующее уравнение комбинации, чтобы получить шумный сегмент с данным ОСШ:
Вы варьируетесь параметр управлять степенью артефакта и достигнуть особого значения ОСШ.
Создать обучающий набор данных createDataset
комбинирует первые 2 720 пар сегментов EEG и EOG десять раз каждый со случайным SNRs в [-7, 2] интервал дБ для в общей сложности 27 200 учебных пар. Каждая учебная пара хранится в файле MAT в папке под названием train
. Каждый файл MAT включает:
Чистый сегмент EEG (сохраненный под переменной под названием EEG)
Сегмент EOG (сохраненный под переменной под названием EOG
)
Шумный сегмент EEG (сохраненный под переменной под названием noisyEEG
)
ОСШ шумного сегмента (сохраненный под переменной под названием SNR
)
Значение частоты дискретизации сегментов сигнала (сохраненный под переменной под названием Fs
)
Создать набор данных createDataset
валидации комбинирует следующие 340 пар сегментов EEG и EOG десять раз каждый со случайным SNRs в [–7, 2] интервал дБ для в общей сложности 3 400 сегментов валидации. Данные о валидации хранятся в файлах MAT в папке под названием
validate
. Каждый файл MAT содержит те же переменные как те описанные для набора обучающих данных.
Наконец, чтобы создать тестовый набор данных createDataset
комбинирует следующие 340 пар сегментов EEG и EOG десять раз каждый с детерминированными значениями ОСШ –7, –6, –5, –4, –3, –2, –1, 0, 1, и 2 дБ. Тестовые данные хранятся в файлах MAT в папке под названием test
. Тестовые файлы MAT с тем же значением ОСШ сгруппированы под общей подпапкой, чтобы облегчить анализировать эффективность шумоподавления обученной модели для данного ОСШ. Например, файлы с тестовыми сигналами с ОСШ-3 дБ хранятся в папке с именем data_SNR_-3
.
Вызовите createDataset
функция, чтобы создать набор данных (это может занять несколько секунд). Установите createDatasetFlag
ко лжи, если у вас уже есть набор данных в datasetFolder
и хочу пропустить этот шаг.
createDatasetFlag = true; if createDatasetFlag createDataset (datasetFolder); end
Сгенерированный набор данных является довольно большим (~430 Мбайт), таким образом, удобно использовать хранилища данных, чтобы получить доступ к данным, не имея необходимость читать все это целиком в память. Создайте хранилища данных сигнала, чтобы получить доступ к данным об обучении и валидации. Используйте SignalVariableNames
параметр, чтобы задать переменные, которые вы хотите считать из файлов MAT (в порядке вы хотите их чтение). Также задайте ReadOutputOrientation
когда "строка", чтобы гарантировать данные совместима с сетью LSTM.
ds_Train = signalDatastore(fullfile(datasetFolder,"train"),SignalVariableNames=["noisyEEG","EEG"],ReadOutputOrientation="row")
ds_Train = signalDatastore with properties: Files:{ ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_1.mat'; ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_10.mat'; ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_100.mat' ... and 27197 more } Folders: {'/home/fboucher/Documents/MATLAB/Examples/R2021b/supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train'} AlternateFileSystemRoots: [0×0 string] ReadSize: 1 SignalVariableNames: ["noisyEEG" "EEG"] ReadOutputOrientation: "row"
ds_Validate = signalDatastore(fullfile(datasetFolder,"validate"),SignalVariableNames=["noisyEEG","EEG"],ReadOutputOrientation="row");
Считайте данные из первого учебного файла и постройте чистые и шумные сигналы EEG. Вызов, чтобы предварительно просмотреть или считать методы datastore уступает 1x2 массив ячеек с первым элементом, содержащим шумный сегмент EEG и второй элемент, содержащий чистый сегмент EEG.
data = preview(ds_Train)
data=1×2 cell array
{[293.5459 312.8255 158.2003 54.3755 -122.9328 -245.2081 -263.6249 -231.0821 -265.4603 -326.4666 -382.5909 -399.4084 -271.8957 -35.5523 110.8435 163.8853 276.0972 411.6580 484.9692 579.9671 704.4224 715.7357 624.8036 565.6525 510.3934 346.0724 90.4347 -155.6278 -310.0931 -338.3046 -310.9277 -343.7870 -422.5112 -462.0371 -499.4786 -578.8647 -588.3763 -497.5128 -492.0669 -614.5556 -661.3109 -579.5487 -480.8576 -331.8418 -131.9308 -64.3486 -92.1508 22.4674 179.0017 131.3366 21.0167 47.0048 72.9048 8.7219 -27.1450 -37.5581 -67.9045 -21.2165 68.1191 12.8107 -151.7520 -234.3676 -217.8296 -185.2244 -164.2496 -186.3218 -259.2529 -330.6853 -405.8061 -503.9750 -501.7492 -288.2392 16.4508 253.6420 466.4391 717.8048 904.1583 974.9127 1.0648e+03 1.2244e+03 1.3335e+03 1.3712e+03 1.4139e+03 1.4144e+03 1.3391e+03 1.3401e+03 1.5020e+03 1.6821e+03 1.7543e+03 1.7582e+03 1.7833e+03 1.8473e+03 1.8685e+03 1.7850e+03 1.6640e+03 1.5618e+03 1.4678e+03 1.4627e+03 1.5648e+03 1.5608e+03 1.4272e+03 1.4523e+03 1.6068e+03 1.5567e+03 1.3658e+03 1.3158e+03 1.3204e+03 1.2457e+03 1.2205e+03 1.2400e+03 1.1601e+03 1.0654e+03 1.0684e+03 1.0682e+03 1.0135e+03 995.4703 1.0462e+03 1.1054e+03 1.0883e+03 991.5936 946.6780 964.2951 876.7881 744.0202 754.4313 725.2564 462.8552 282.7478 422.9020 566.9241 473.6554 335.4952 296.2394 260.1373 181.3557 165.5355 299.4533 452.6924 393.9423 205.3760 191.1296 311.4351 300.4170 211.6092 214.6472 217.2602 164.5827 195.2764 297.2532 330.1215 322.6244 340.3149 298.3817 183.1748 170.3522 309.5276 374.3551 265.5430 226.5715 340.0638 315.7039 136.6732 165.8833 344.5902 248.3717 -27.9482 -56.8646 109.7881 162.1398 105.1915 88.9519 110.0767 69.5323 -58.8971 -116.2790 -0.3552 46.8582 -126.7452 -229.6075 -141.9099 -154.0355 -296.3941 -354.4280 -386.1300 -419.9606 -274.0139 -66.4493 -44.1554 -85.8933 -88.3121 -162.2650 -186.2820 -37.5754 38.3777 -60.2507 -153.7257 -286.3040 -518.1376 -640.6470 -661.4028 -784.5911 -908.6936 -953.8550 -1.1159e+03 -1.3089e+03 -1.2774e+03 -1.2213e+03 -1.3786e+03 -1.5198e+03 -1.4902e+03 -1.4564e+03 -1.4443e+03 -1.3975e+03 -1.4399e+03 -1.5792e+03 -1.6399e+03 -1.6302e+03 -1.6650e+03 -1.6882e+03 -1.6382e+03 -1.5771e+03 -1.5110e+03 -1.4109e+03 -1.3619e+03 -1.4156e+03 -1.4516e+03 -1.3837e+03 -1.2984e+03 -1.2583e+03 -1.1969e+03 -1.0818e+03 -949.8370 -785.9617 -593.8683 -484.6091 -493.6924 -486.5894 -376.9322 -254.3787 -246.5404 -358.0243 -457.8937 -455.3159 -407.6992 -323.9139 -143.8645 -47.9140 -238.7738 -424.6188 -264.3816 -33.7708 -63.5278 -127.1567 -26.2163 21.3820 -79.8131 -156.4246 -169.0789 -204.3369 -265.7965 -319.7542 -345.6645 -347.2635 -346.8830 -298.6344 -193.2218 -175.2466 -263.6097 -282.8454 -246.9951 -273.7662 -255.3390 -160.0054 -156.2782 -160.6513 -21.3900 25.6993 -149.2867 -243.8079 -170.5436 -154.2558 -159.1055 -70.4898 -55.6186 -184.9485 -282.5962 -283.6852 -229.9518 -144.0285 -136.0738 -251.5972 -323.3831 -270.5175 -208.4219 -197.8861 -240.1977 -352.8216 -455.6938 -466.6105 -475.8638 -562.5412 -620.9217 -541.7099 -397.5501 -349.9958 -409.6095 -420.5214 -369.7480 -409.6981 -468.8235 -365.6404 -287.5636 -484.4784 -723.7460 -684.6297 -471.0508 -334.9867 -365.6236 -522.4700 -638.5261 -585.0912 -480.4407 -440.5247 -381.5005 -282.3495 -236.0896 -240.9890 -269.0734 -319.2842 -320.4942 -287.2268 -363.4684 -514.1044 -569.9932 -541.1071 -496.5918 -411.5588 -337.3943 -348.2844 -369.4418 -337.1498 -291.6518 -251.9061 -228.0908 -212.9140 -159.0688 -167.3810 -345.9673 -435.5995 -223.3164 -35.1722 -90.2426 -77.3739 76.3905 48.6435 -101.3630 -96.3987 -72.6505 -149.6118 -83.4186 118.1417 134.9427 11.3522 52.0115 234.5030 340.0810 350.6371 371.3420 399.1974 318.5930 109.3049 -26.2028 45.9424 146.2610 156.2677 236.4964 355.1251 247.6928 17.5845 -13.5088 51.5868 -16.7073 -56.6772 55.5756 108.1590 52.7701 78.7909 165.4370 175.4213 124.4899 25.4918 -105.4091 -121.8664 -17.4514 34.9487 52.8456 115.3902 88.9050 -18.9950 21.7647 156.9900 163.9505 119.2607 136.0510 123.1712 92.3660 90.3993 3.7770 -101.5293 -30.4743 77.8997 85.4559 200.3163 381.0650 305.4995 144.0639 243.1847 344.7070 166.2973 30.2100 181.1394 362.4509 376.9039 315.4304 277.4177 286.2399 295.2399 250.6823 252.2333 380.9921 475.7625 430.3416 383.2322 386.2697 341.6953 283.4970 299.4474 331.3131 303.5586 236.8910 185.7528 198.4428 256.0914 265.2153 197.2732 106.0651 42.6344 86.7558 248.4670 328.6696 217.0654 115.8127 134.7095 123.4219 88.0805 150.4634 172.4213 51.2537 2.9601 92.3141 124.4305 141.9237 263.2068 290.5031 176.4156 219.6978 370.3300 299.7149 172.6634 282.4104 368.8589 218.3248 102.9427 103.7194 35.4401 13.2908 120.6106 84.7856 -83.2411 -58.2985 78.9139 75.5260 48.0121 82.4030 47.5345 25.0876 139.5917 236.7479 227.8253 186.5794 90.9769 -2.7247 27.1747 18.1979 -137.4836 -170.3707 -35.5876 -33.7408 -118.0647 -38.4254 85.5209 98.7000 111.0841 147.8175 155.6366 195.0901 233.1084 198.1136 180.9826 170.8508 42.2914 -65.1522]} {[184.5071 182.3164 41.0644 -55.5457 -155.6309 -221.9838 -282.9218 -354.7277 -437.7731 -487.9534 -520.1615 -506.2143 -364.6989 -163.5820 -50.7511 26.8912 194.8870 390.9182 517.5795 593.4634 612.2013 553.5140 510.8928 534.3316 489.5771 283.2516 34.5120 -127.4337 -241.1786 -363.7183 -434.6263 -427.8548 -443.9207 -516.2359 -550.6356 -534.8068 -542.8537 -555.2995 -537.8079 -556.3362 -615.4577 -604.2526 -484.4763 -320.7209 -172.4580 -73.5406 -1.8568 99.4586 184.0984 160.4643 74.4949 10.5636 -49.4240 -109.0210 -128.1498 -128.4013 -118.2926 -59.3559 -7.2351 -44.3925 -118.2087 -159.5074 -187.7227 -207.9291 -217.5568 -252.9733 -292.6216 -298.9978 -342.3380 -450.3554 -489.6292 -382.5939 -218.6430 -74.2266 52.2691 138.1272 139.1030 70.6933 1.0145 -42.4630 -103.4616 -185.9607 -225.8533 -252.2135 -363.8540 -489.5758 -475.7173 -386.3537 -378.1933 -409.8378 -372.8535 -295.8953 -228.4253 -173.7187 -188.4879 -285.0080 -298.9211 -127.3412 35.8319 -8.3398 -118.2478 -78.7464 21.3938 24.1924 11.0743 63.1065 85.5681 67.8093 107.6461 164.9735 164.7395 168.9144 203.8057 210.9855 219.1537 276.0164 339.8352 392.1866 433.9953 414.4877 353.6295 330.8243 318.6424 277.9611 254.5348 223.2527 137.1390 95.9530 146.3726 145.0219 39.1259 -47.4125 -62.2298 -56.4713 -42.3134 4.4354 81.0353 126.2602 81.0397 30.0799 91.9103 167.9870 122.0614 58.4736 77.8752 70.7421 -2.9413 -26.6743 13.9483 41.7856 60.6159 94.5616 115.1145 107.3929 102.8801 137.3549 181.8851 180.3480 195.3455 292.6388 324.4309 217.9083 212.5276 355.4434 318.5107 82.7695 22.4755 169.3209 225.7422 174.6523 180.9261 205.8876 153.3403 75.5621 69.4078 132.9424 141.7501 24.3821 -82.4083 -65.1453 -27.1313 -50.0239 -96.2128 -162.3494 -188.0827 -70.3356 61.6508 38.1663 -23.7790 -24.5535 -55.9957 -65.6618 39.2241 115.0479 112.8966 166.4775 181.8789 34.0374 -78.0041 -51.5329 -63.2166 -109.3926 -83.1151 -97.0703 -158.8065 -84.3664 44.6814 32.8693 0.5971 78.9700 142.9246 150.2819 184.4547 175.9413 58.7671 -35.6207 -25.9139 -8.4974 -49.1827 -108.2087 -141.7154 -134.3698 -96.2263 -64.7217 -60.1042 -57.6134 -32.5431 1.9650 28.5149 55.7130 85.1667 111.3595 160.2484 226.6730 227.6562 137.4926 40.9107 -18.8243 -61.1854 -76.2282 -83.4148 -133.8459 -187.3029 -179.6742 -120.5781 -15.8016 81.0734 43.3321 -51.1707 16.9843 152.7320 125.5967 59.9723 124.7820 155.0875 57.7959 0.3257 14.7888 2.8780 -0.1661 25.2573 1.9835 -45.9279 -48.0225 -20.6491 18.3378 34.3649 -44.5040 -146.3061 -117.5063 -0.3587 53.6563 31.0214 18.0257 65.7692 124.1761 83.0233 -24.2939 -43.9572 -13.6894 -65.5889 -98.3299 -22.0335 -18.6543 -161.4866 -259.3167 -249.0026 -226.8589 -190.8240 -142.0858 -153.4440 -180.1550 -130.3496 -45.2894 12.9008 41.2162 21.2666 -25.4712 -20.7545 26.6362 39.2410 19.3302 27.9147 77.1813 112.9529 86.1450 36.0395 27.5247 27.7840 1.0752 5.4876 18.5464 -87.9264 -262.8144 -283.4608 -118.9496 6.6664 -54.1971 -179.8054 -185.7806 -107.2994 -87.5112 -78.2404 25.9276 108.8198 57.6744 -12.1159 -32.8426 -82.5263 -132.0208 -117.6162 -118.9881 -175.0695 -216.7788 -229.4338 -199.6173 -103.7726 -54.1239 -114.6663 -136.1396 -77.4445 -92.9215 -176.1955 -185.7606 -127.9888 -90.1109 -124.5307 -231.0845 -268.6486 -154.2345 -92.5942 -180.1112 -189.5059 -101.0242 -131.3268 -182.7982 -100.3321 -102.8661 -248.7238 -223.4905 -38.4470 -4.9460 -60.2417 11.5309 94.4714 73.1159 70.3551 126.3687 161.1489 162.7253 124.9805 57.2503 45.2784 76.0252 58.4429 51.2018 101.1252 83.8566 16.0189 65.7301 143.5281 73.0663 7.3477 92.5118 134.1665 42.8936 27.1547 126.4113 168.2120 144.8497 128.5777 81.5150 30.5906 49.1867 68.8646 53.4856 80.5853 87.3083 -4.4480 -26.8726 89.1834 120.0574 29.3774 19.3590 46.0898 -38.7869 -106.3669 -58.0703 -42.3876 -84.2495 -49.9379 35.0590 82.8349 113.9988 121.8874 95.8216 112.4760 144.4695 71.4178 -20.7198 50.3754 217.5505 288.2074 248.4080 198.0217 161.5141 108.6792 60.6000 79.7520 160.4816 222.2459 238.6432 242.2156 217.8590 154.8832 115.3816 121.1145 118.1581 87.5098 51.0366 34.3849 69.6186 121.3643 106.6053 50.6188 16.5712 -29.5941 -59.3618 28.4726 133.0784 88.4984 3.6516 29.6105 71.9211 49.0177 28.2880 14.6080 -9.6634 13.2494 48.8916 21.2872 17.1130 98.5594 147.5359 125.7440 132.8005 160.0092 133.7969 88.0235 70.9290 66.9099 57.3967 16.1320 -59.2551 -94.5876 -69.0312 -68.5262 -107.6983 -123.3063 -125.9344 -127.8381 -78.5704 -20.5095 -45.0608 -77.1851 -2.5161 99.9790 109.9419 60.5306 23.1728 8.3185 18.5414 20.5966 -58.5031 -182.2469 -212.0151 -135.8594 -89.1736 -113.9629 -101.7947 5.2091 129.6184 191.2036 200.8505 201.9036 194.7194 161.0028 114.6600 83.6711 40.1643 -29.6090 -84.8615]}
plot([data{2}.' data{1}.'],LineWidth=2) legend('Clean EEG','EEG with EOG artifact') axis tight
Производительность сети регрессии обычно улучшается, если сигналы ввода и вывода нормированы. Можно преобразовать хранилища данных сигнала, чтобы применить нормализацию к каждому сигналу, когда это читается из диска. normalizeData helper
функция перечислена в конце этого примера. Это просто вычитает среднее значение сигнала и делит результат на стандартное отклонение сигнала.
ds_Train_T = transform(ds_Train,@normalizeData); ds_Validate_T = transform(ds_Validate,@normalizeData);
Обучите сеть к сигналам denoise путем передачи шумных сигналов EEG в сетевой вход и запроса желаемого EEG чистые сигналы основной истины при сетевом выходе. Архитектура длинной кратковременной памяти (LSTM) выбрана, потому что это способно к изучению функций от последовательностей времени.
Определить сетевую архитектуру: номер функций определяется одной, когда одна последовательность вводится к сети, и одна последовательность выводится от сети. Используйте слой уволенного, чтобы уменьшать сверхподбор кривой модели на обучающих данных. Используйте слой регрессии в качестве выходного слоя, поскольку модель обучается выполнить регрессию. Обратите внимание на то, что нормализация должна быть применена к сигналам ввода и вывода, таким образом, более удобно использовать преобразованные хранилища данных, чем использовать Normalization
опция sequenceInputLayer
то единственное нормирует входные параметры.
numFeatures = 1; numHiddenUnits = 100; layers = [ sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits) dropoutLayer(0.2) fullyConnectedLayer(numFeatures) regressionLayer];
Задайте параметры опции обучения: используйте оптимизатор Адама и примите решение переставить данные в каждую эпоху. Кроме того, задайте datastore валидации ds_Validate_T
как источник для данных о валидации.
maxEpochs = 5; miniBatchSize = 150; options = trainingOptions('adam', ... MaxEpochs=maxEpochs, ... MiniBatchSize=miniBatchSize, ... InitialLearnRate=0.005, ... GradientThreshold=1, ... Plots="training-progress", ... Shuffle="every-epoch", ... Verbose=false,... ValidationData=ds_Validate_T ,... ValidationFrequency=100, ... OutputNetwork="best-validation-loss");
Используйте trainNetwork
функция, чтобы обучить модель. Можно непосредственно передать преобразованный, обучают datastore в функцию, потому что datastore выводит 1x2 массив ячеек, с сигналами ввода и вывода, в каждом вызове read
метод.
Учебные шаги займут несколько минут. Можно пропустить эти шаги путем загрузки предварительно обученных сетей с помощью селектора ниже. Если вы хотите обучить сеть, когда пример запускается, выберите 'Train Networks
'. Если вы хотите пропустить учебные шаги, выберите 'Download Networks
'и файл MAT, содержащий две предварительно обученных сети-rawNet
, и stftNet-
будет загружен в вашу машину.
trainingFlag = "Train networks"; if trainingFlag == "Train networks" rawNet = trainNetwork (ds_Train_T, слои, опции); else % Download the pre-trained networks modelsZipFile = matlab.internal.examples.downloadSupportFile ('SPT','data/EEGEOGDenoisingNetworks.zip'); modelsFolder = fullfile (fileparts (datasetZipFile),'EEG_EOG_Denoising_Networks'); if ~exist (modelsFolder,'dir') разархивируйте (modelsZipFile, fileparts (modelsZipFile)); end modelsFile = fullfile (modelsFolder,'trainedNetworks.mat'); загрузите (modelsFile) end
Используйте тестовый набор данных, чтобы анализировать эффективность шумоподавления rawNet
сеть. Вспомните, что тестовый набор данных содержит несколько тестовых файлов для каждого значения ОСШ в [-7,-6,-5,-4,-3,-2,-1, 0, 1, 2] дБ. Показатель производительности выбран в качестве среднеквадратической ошибки (MSE) между чистым базовым сигналом EEG и denoised сигналом EEG. MSE чистого сигнала EEG и шумного сигнала EEG также вычисляется, чтобы показать худшему случаю MSE, когда никакое шумоподавление не применяется. В каждом ОСШ вычисляют 340 значений MSE для каждого из 340 доступных тестовых сегментов EEG и получают средний MSE.
Создайте signalDatastore
использовать тестовые данные и использовать преобразованный datastore, чтобы установить нормализацию данных. Поскольку данные являются теперь внутренними подпапками тестовой папки, задайте IncludeSubfolders
как верный. Далее, используйте функцию folders2labels, чтобы получить список имен папок для каждого файла в тестовом наборе данных так, чтобы можно было получить данные для каждого ОСШ.
ds_Test = signalDatastore(fullfile(datasetFolder,"test"),SignalVariableNames=["noisyEEG","EEG"],IncludeSubfolders=true,ReadOutputOrientation="row"); ds_Test_T = transform(ds_Test,@normalizeData); % Get labels that contain the SNR value for each file in the datastore labels = folders2labels(ds_Test)
labels = 3400×1 categorical
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
data_SNR_-1
⋮
Для каждого значения ОСШ, denoise тестовые сигналы и вычисляют среднее значение MSE. Используйте subset
функция datastore, чтобы получить datastore, указывающий на данные для каждого ОСШ. К denoise сигнал вызывают predict
функция, передающая обучивший сеть и зашумленные данные как входные параметры.
SNRs = (-7:2); MSE_Denoised_rawNet = zeros(numel(SNRs),1); % Measure denoising performance MSE_No_Denoise = zeros(numel(SNRs),1); % Measure worst-case MSE when no denoising is applied for idx = 1:numel(SNRs) lblIdx = find(labels == "data_SNR_"+num2str(SNRs(idx))); ds_Test_SNR = subset(ds_Test_T,lblIdx); % New datastore pointing to files with current SNR value % Denoise the data using the predict function of the trained model pred = predict(rawNet,ds_Test_SNR); % Use an array datastore to loop over the 340 denoised signals for the % current SNR value. Transform the datastore to add the normalization % step. ds_Pred = transform(arrayDatastore(pred,OutputType="same"),@normalizeData); mse = 0; mseWorstCase = 0; cnt = 0; while hasdata(ds_Pred) testData = read(ds_Test_SNR); denoisedData = read(ds_Pred); % MSE performance of denoiser - testData{2} contains clean EEG, % testData{1} contains noisy EEG. mse = mse + sum((testData{2} - denoisedData{1}).^2)/numel(denoisedData{1}); % Worst-case MSE performance when no denoising is applied. % Convert data to single precession as denoisedData is single % precision. mseWorstCase = mseWorstCase + sum((single(testData{2}) - single(testData{1})).^2)/numel(testData{1}); cnt = cnt+1; end % Average MSE of denoised signals MSE_Denoised_rawNet(idx) = mse/cnt; % Worst-case average MSE MSE_No_Denoise(idx) = mseWorstCase/cnt; end
Постройте средние результаты MSE.
plot(SNRs,[MSE_No_Denoise,MSE_Denoised_rawNet],LineWidth=2) xlabel("SNR") ylabel("Average MSE") title("Denoising Performance") legend("Worst-case scenario (no-denoising)","Denoising with rawNet model")
Общий подход, чтобы улучшать производительность модели глубокого обучения должен использовать извлеченные функции вместо исходных необработанных данных сигнала. Функции обеспечивают представление входных данных, которое облегчает для сети изучать самые важные аспекты сигналов.
Выберите кратковременное преобразование Фурье (STFT) с продолжительностью окна 64 выборок и продолжительность перекрытия 63 выборок. Это преобразование эффективно создаст 33 комплексных функции с продолжительностью 449 выборок каждый. Сети LSTM не поддерживают комплексные входные параметры, таким образом, комплексные функции могут быть разделены на действительные и мнимые компоненты путем укладки действительной части функций сверх мнимой части функций, чтобы дать к 66 действительным функциям каждая длина 449 выборок.
winLength = 64; overlapLength = 63;
transformSTFT
функция помощника, перечисленная в конце этого примера, нормирует входной сигнал и затем вычисляет его STFT. Функция складывает действительные и мнимые компоненты, чтобы создать действительную выходную матрицу. Далее, если графический процессор доступен, функция перемещает данные в графический процессор, чтобы ускорить расчеты STFT и смягчить увеличенную сложность вычисления преобразований. Если вы не хотите использовать графический процессор, установите useGPUFlag
к false
.
useGPUFlag = true;
Вычислите и постройте STFT пары чистых и шумных сигналов EEG с помощью transformSTFT helper function.
data = preview(ds_Train); P = transformSTFT(data,winLength,overlapLength,useGPUFlag); figure subplot(1,2,1) h = imagesc(P{2}); h.Parent.CLim = [-40 57]; title('STFT of clean EEG signal') ylabel("Stacked real and imaginary features") subplot(1,2,2) h = imagesc(P{1}); h.Parent.CLim = [-40 57]; ylabel("Stacked real and imaginary features") title('STFT of noisy EEG signal')
Идея состоит в том, чтобы обучить сеть так, чтобы она могла произвести denoised STFT представления сигнала на основе входных параметров STFT, соответствующих сигналам с шумом. Обратите внимание на то, что целевым результатом является сигнал denoised, не его denoised STFT представление, таким образом, последний шаг должен быть добавлен, чтобы вычислить инверсию STFT (ISTFT), чтобы восстановить сигнал denoised, как изображено на блок-схеме ниже.
Функция помощника, transformISTFT
, перечисленный в конце этого примера берет denoised сеть STFT выход, преобразует сложенные действительные и мнимые функции назад, чтобы объединить функции и вычисляет обратный STFT. Как последний шаг функция нормирует получившийся сигнал. Если графический процессор доступен и useGPUF
задержка верна, функция выполняет все расчеты в графическом процессоре, чтобы уменьшать время вычислений.
Создайте обучаются, валидация, и тестируют хранилища данных, чтобы применить STFT использование transformSTFT
функция.
ds_Train_STFT = transform(ds_Train,@(d,wl,ol,fl)transformSTFT(d,winLength,overlapLength,useGPUFlag)); ds_Validate_STFT = transform(ds_Validate,@(d,wl,ol,fl)transformSTFT(d,winLength,overlapLength,useGPUFlag)); ds_Test_STFT = transform(ds_Test,@(d,wl,ol,fl)transformSTFT(d,winLength,overlapLength,useGPUFlag));
Обновите сетевую архитектуру с учетом 66 функций ввода и вывода и задайте новые данные о валидации в опциях обучения. Любой сетевой параметр или опция неизменны.
numFeatures = winLength + 2; layers = [ sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits) dropoutLayer(0.2) fullyConnectedLayer(numFeatures) regressionLayer]; options.ValidationData = ds_Validate_STFT;
Обучите сеть если trainingFlag
"Train networks"
.
if trainingFlag == "Train networks" stftNet = trainNetwork(ds_Train_STFT,layers,options); end
Используйте обучивший сеть для denoise сигналов EEG с помощью тестовых данных. Вычислите средние значения MSE путем сравнения denoised и уберите базовые сигналы EEG.
MSE_Denoised_stftNet = zeros(numel(SNRs),1); % Measure denoising performance for idx = 1:numel(SNRs) lblIdx = find(labels == "data_SNR_"+num2str(SNRs(idx))); % New datastores pointing to files with current SNR value ds_Test_SNR = subset(ds_Test_T,lblIdx); % Raw EEG signals to compute MSE ds_Test_STFT_SNR = subset(ds_Test_STFT,lblIdx); % STFT transforms % Denoise the data using the predict function of the trained model. pred = predict(stftNet,ds_Test_STFT_SNR); % Use an array datastore to loop over the 340 denoised signals for the % current SNR value. Transform the datastore to compute the inverse % STFT and recover the actual denoised signal. ds_Pred = transform(arrayDatastore(pred,OutputType="same"),@(P,wl,ol)transformISTFT(P,winLength,overlapLength)); mse = 0; cnt = 0; while hasdata(ds_Pred) testData = read(ds_Test_SNR); denoisedData = read(ds_Pred); % MSE performance of denoiser - testData{2} contains clean EEG mse = mse + sum((testData{2} - denoisedData).^2)/numel(denoisedData); cnt = cnt+1; end % Average MSE of denoised signals MSE_Denoised_stftNet(idx) = mse/cnt; end
Постройте средний MSE, полученный без шумоподавления, шумоподавления с сетью, обученной с необработанными входными сигналами, и шумоподавление с сетью, обученной с STFT, преобразовало сигналы. Вы видите, что сложение шага STFT улучшало производительность особенно в более низких значениях ОСШ.
figure plot(SNRs,[MSE_No_Denoise,MSE_Denoised_rawNet,MSE_Denoised_stftNet],LineWidth=2) xlabel("SNR") ylabel("Average MSE") title("Denoising Performance") legend("Worst-case scenario (no denoising)","Denoising with rawNet model","Denoising with stftNet model")
Постройте шумный и сигналы denoised для различного SNRs. getRandomEEG
функция помощника, перечисленная в конце этого примера, получает случайный сигнал EEG с заданным ОСШ от тестового набора данных.
SNR = -7; % dB данные = getRandomEEG (datasetFolder, ОСШ); noisyEEG = normalizeData (данные {1}); cleanEEG = normalizeData (данные {2}); stftInput = transformSTFT (noisyEEG, winLength, overlapLength, useGPUFlag); denoisedEEG = transformISTFT (предсказывают (stftNet, stftInput), winLength, overlapLength); график ([cleanEEG'. denoisedEEG'. noisyEEG '.], LineWidth=2) заголовок"EEG denoising (SNR = " + ОСШ + " dB)") легенда"Clean EEG", "Denoised EEG","Noisy EEG") ось tight
В этом примере вы изучили, как обучить глубокую сеть, чтобы выполнить регрессию для шумоподавления сигнала. Вы сравнили две модели, один обученный с необработанными чистыми и шумными сигналами EEG, другой обученный с функциями, извлеченными с помощью кратковременного преобразования Фурье. Вы узнали, что можно использовать комплексные функции путем укладки их действительных и мнимых компонентов и обработки их как независимых действительных функций. Использование последовательностей STFT обеспечивает большее повышение производительности в худшем SNRs, и оба подхода сходятся в эффективности, когда ОСШ улучшается.
[1] Хаомин Чжан, Минци Чжао, Чэнь Вэй, Данте Мантини, Зэруи Ли, Цюаньин Лю. "Набор данных сравнительного теста для решений для глубокого обучения шумоподавления EEG". https://arxiv.org/abs/2009.11662
normalizeData
- эта функция нормирует входные сигналы путем вычитания среднего значения и деления на стандартное отклонение.
function y = normalizeData(x) % This function is only intended to support examples in the Signal % Processing Toolbox. It may be changed or removed in a future release. if iscell(x) y = cell(1,numel(x)); y{1} = (x{1}-mean(x{1}))/std(x{1}); if numel(x) == 2 y{2} = (x{2}-mean(x{2}))/std(x{2}); end else y = (x - mean(x))/std(x); end end
transformSTFT
- эта функция нормирует сигналы во входе data
и вычисляет их кратковременное преобразование Фурье. Это преобразует комплексные результаты STFT в действительную матрицу путем укладки действительных и мнимых компонентов один сверху другого.
function P = transformSTFT(data,winLength,overlapLength,useGPUFlag) % This function is only intended to support examples in the Signal % Processing Toolbox. It may be changed or removed in a future release. if ~iscell(data) data = {data}; end P = cell(1,numel(data)); x = data{1}; if useGPUFlag x = gpuArray(x); end x = normalizeData(x); y = stft(x,Window=rectwin(winLength),OverlapLength=overlapLength,FrequencyRange="onesided"); P{1} = [real(y);imag(y)]; if numel(data) == 2 x = data{2}; if useGPUFlag x = gpuArray(x); end x = normalizeData(x); y = stft(x,Window=rectwin(winLength),OverlapLength=overlapLength,FrequencyRange="onesided"); P{2} = [real(y);imag(y)]; end end
transformISTFT
- эта функция берет матрицу со сложенными действительными и мнимыми элементами STFT и комбинирует их назад к комплексной матрице STFT. Функция затем вычисляет обратный STFT, преобразовывают, и нормирует получившиеся восстановленные сигналы.
function data = transformISTFT(P,winLength,overlapLength) % This function is only intended to support examples in the Signal % Processing Toolbox. It may be changed or removed in a future release. PP = P{1}; NumRows = size(PP,1); X = PP(1:NumRows/2,:)+1i*PP(1+NumRows/2:end,:); data = istft(X,Window=rectwin(winLength),OverlapLength=overlapLength,ConjugateSymmetric=true,FrequencyRange="onesided").'; data = normalizeData(data); end
createDataset
- эта функция объединения чистит сегменты сигнала EEG с сегментами EOG, чтобы создать обучение, валидацию и тестирующий наборы данных, чтобы обучить EEG denoiser нейронная сеть.
function createDataset(dataDir) % This function is only intended to support examples in the Signal % Processing Toolbox. It may be changed or removed in a future release. % Create training, validation, and testing datasets consisting of clean EEG % signals and noisy EEG signals contaminated by EOG segments. load(fullfile(dataDir,"EEG_all_epochs.mat"),"EEG_all_epochs"); load(fullfile(dataDir,"EOG_all_epochs.mat"),"EOG_all_epochs"); EEG_all_epochs = EEG_all_epochs(1:3400,:).'; EOG_all_epochs = EOG_all_epochs.'; Fs = 256; trainingPercentage = 80; validationPercentage = 10; N = size(EEG_all_epochs,2); % Create a training dataset consisting of mat files containing two signals % - a clean EEG signal, and an EEG signal contaminated by EOG artifacts. % Combine each of 2720 pairs of EEG and EOG segments ten times with random % SNRs in the range -7dB to 2dB to obtain 27200 training segments. EEG_training = EEG_all_epochs(:,1:N*trainingPercentage/100); EOG_training = EOG_all_epochs(:,1:N*trainingPercentage/100); M = size(EEG_training,2); cnt = 0; if ~exist(fullfile(dataDir,"train"),'dir') mkdir(fullfile(dataDir,"train")) end for idx = 1:M for kk = 1:10 cnt = cnt + 1; EEG = EEG_training(:,idx).'; EOG = EOG_training(:,idx).'; [noisyEEG,SNR] = createNoisySegment(EEG,EOG,[-7,2]); save(fullfile(dataDir,"train","data_" + num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR"); end end % Create a validation dataset by combining 340 pairs of EEG and EOG % segments ten times with random SNRs in (-7:2) dB EEG_validation = EEG_all_epochs(:,1+N*trainingPercentage/100:N*trainingPercentage/100+N*validationPercentage/100); EOG_validation = EOG_all_epochs(:,1+N*trainingPercentage/100:N*trainingPercentage/100+N*validationPercentage/100); M = size(EEG_validation,2); cnt = 0; if ~exist(fullfile(dataDir,"validate"),'dir') mkdir(fullfile(dataDir,"validate")) end for idx = 1:M for kk = 1:10 cnt = cnt + 1; EEG = EEG_validation(:,idx).'; EOG = EOG_validation(:,idx).'; [noisyEEG,SNR] = createNoisySegment(EEG,EOG,[-7,2]); save(fullfile(dataDir,"validate","data_" + num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR"); end end % Create a test dataset by combining 340 pairs of EEG and EOG segments ten % times with 10 SNR values [-7 -6 -5 -4 -3 -2 -1 0 1 2] dB. Store the % training sets in folders with names that identify the SNR value so that % it is easy to analyze performance by accessing files with a specific SNR. EEG_test = EEG_all_epochs(:,1+N*trainingPercentage/100+N*validationPercentage/100:end); EOG_test = EOG_all_epochs(:,1+N*trainingPercentage/100+N*validationPercentage/100:end); M = size(EEG_test,2); SNRVect = (-7:2); for kk = 1:numel(SNRVect) cnt = 0; if ~exist(fullfile(dataDir,"test","data_SNR_" + num2str(SNRVect(kk))),'dir') mkdir(fullfile(dataDir,"test","data_SNR_" + num2str(SNRVect(kk)))); end for idx = 1:M cnt = cnt + 1; EEG = EEG_test(:,idx).'; EOG = EOG_test(:,idx).'; [noisyEEG,SNR] = createNoisySegment(EEG,EOG,SNRVect(kk)); save(fullfile(dataDir,"test","data_SNR_" + num2str(SNR)+"/" + "data_"+num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR"); end end end function [y,SNROut] = createNoisySegment(eeg,artifact,SNR) % Combine EEG and artifact signals with a specified SNR in dB. If SNR is a % two-element vector, its value is chosen randomly from a uniform % distribution over the interval [SNR(1) SNR(2)] if numel(SNR) == 2 SNR = SNR(1) + (SNR(2)-SNR(1)).*rand(1,1); end k = 10^(SNR/10); lambda = (1/k)*rms(eeg)/rms(artifact); y = eeg + lambda * artifact; SNROut = SNR; end
getRandomEEG -
this function reads the data from a
случайный тестовый файл EEG с заданным ОСШ.
function data = getRandomEEG(datasetFolder,SNR) sds = signalDatastore(fullfile(datasetFolder,"test","data_SNR_"+num2str(SNR)),SignalVariableNames=["noisyEEG","EEG"],IncludeSubfolders=true); n = numel(sds.Files); idx = randi(n,1); data = read(subset(sds,idx)); end
folders2labels
(Signal Processing Toolbox) | signalDatastore
(Signal Processing Toolbox) | trainNetwork