Commit 9c04cece, author: hypox64

add simple_test.py

Parent 44c1ecd5
# candock
This repository is a log of my graduation project. Its goal is to try several deep neural network architectures (such as LSTM, ResNet, and DFCNN) for automatic sleep stage classification from single-channel EEG. We believe the code can also be used to classify other physiological signals (such as ECG and EMG). We hope it helps your research.<br>
## How to run
If you need to run this code (to train your own model or to test with a pretrained model), see the following page:<br>
[How to run codes](https://github.com/HypoX64/candock/blob/master/how_to_run.md)<br>
## Datasets
Three sleep datasets were used for testing: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edf]](https://www.physionet.org/physiobank/database/sleep-edf/) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/) <br>
For the CinC Challenge 2018 dataset, the C4-M1 channel is used.<br>For the sleep-edfx and sleep-edf datasets, the Fpz-Cz channel is used.<br>
...@@ -25,6 +28,11 @@ ...@@ -25,6 +28,11 @@
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_Stage3.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)<br>
* multi_scale_resnet_1d network architecture<br>
This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet)<br>
The modified network architecture is shown below:<br>
![image](https://github.com/HypoX64/candock/blob/master/image/multi_scale_resnet_1d_network.png)<br>
* About cross-validation<br>
To make comparison with methods from the literature easier, two cross-validation schemes are used:<br>
1. Within a single dataset, 5-fold cross-validation is used<br>
...@@ -48,9 +56,10 @@ ...@@ -48,9 +56,10 @@
| Network | Label average recall | Label average accuracy | error rate |
| :----------------------- | :------------------- | ---------------------- | ---------- |
| lstm | | | |
| resnet18_1d | 0.8263 | 0.9601 | 0.0997 |
| DFCNN+resnet18 | 0.8261 | 0.9594 | 0.1016 |
| DFCNN+multi_scale_resnet | 0.8196 | 0.9631 | 0.0922 |
| multi_scale_resnet_1d | 0.8400 | 0.9595 | 0.1013 |
* sleep-edfx(only sleep time)<br>
...@@ -78,3 +87,4 @@ ...@@ -78,3 +87,4 @@
* 2019/04/04 For rigor, k-fold plus cross-subject validation still needs to be added...
* 2019/04/05 Qingming Festival... After reading the literature, I will follow what most people do and use 5-fold cross-validation plus cross-dataset validation, which makes side-by-side comparison with other methods easier. Still, I have to complain: others use k-fold only because their datasets are too small. Running k-fold on hundreds of GB of data is really unnecessary. The results will hardly differ, and it is a complete waste of compute.
* 2019/04/09 Back in my hometown. Ahh!!!! My graduation thesis... I can't finish it!
* 2019/04/13 Back at school grinding on the thesis. Somehow the repo gained two stars in the few days I was away... well, I'll take it.
\ No newline at end of file
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
import torch import torch
import random import random
import dsp import dsp
import transformer
# import pyedflib # import pyedflib
import mne import mne
...@@ -41,7 +42,7 @@ def trimdata(data,num): ...@@ -41,7 +42,7 @@ def trimdata(data,num):
def reducesample(data,mult): def reducesample(data,mult):
return data[::mult] return data[::mult]
def loaddata(dirpath,signal_name,BID = 'median',filter = True): def loaddata(dirpath,signal_name,BID,filter = True):
#load #load
signals = loadsignals(dirpath,signal_name) signals = loadsignals(dirpath,signal_name)
if filter: if filter:
...@@ -52,9 +53,11 @@ def loaddata(dirpath,signal_name,BID = 'median',filter = True): ...@@ -52,9 +53,11 @@ def loaddata(dirpath,signal_name,BID = 'median',filter = True):
stages = reducesample(stages,2) stages = reducesample(stages,2)
#Balance individualized differences #Balance individualized differences
if BID == 'median': if BID == 'median':
signals = (signals*8/(np.median(abs(signals)))).astype(np.int16) signals = (signals*10/(np.median(abs(signals))))
elif BID == 'std': elif BID == '5_95_th':
signals = (signals*55/(np.std(signals))).astype(np.int16) tmp = np.sort(signals.reshape(-1))
th_5 = tmp[int(0.05*len(tmp))]
signals=transformer.Normalize(signals,1000,0,th_5)
#trim #trim
signals = trimdata(signals,3000) signals = trimdata(signals,3000)
stages = trimdata(stages,3000) stages = trimdata(stages,3000)
...@@ -70,9 +73,9 @@ def loaddata(dirpath,signal_name,BID = 'median',filter = True): ...@@ -70,9 +73,9 @@ def loaddata(dirpath,signal_name,BID = 'median',filter = True):
stages = np.delete(stages,i-cnt,axis =0) stages = np.delete(stages,i-cnt,axis =0)
cnt += 1 cnt += 1
# print(stages.shape,signals.shape) # print(stages.shape,signals.shape)
return signals,stages return signals.astype(np.float16),stages.astype(np.int16)
def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = True): def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID):
filenames = os.listdir(filedir) filenames = os.listdir(filedir)
for filename in filenames: for filename in filenames:
if str(filenum) in filename and 'Hypnogram' in filename: if str(filenum) in filename and 'Hypnogram' in filename:
...@@ -105,8 +108,6 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T ...@@ -105,8 +108,6 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T
signals.append(eeg[events[i][0]:events[i][0]+3000]) signals.append(eeg[events[i][0]:events[i][0]+3000])
stages=np.array(stages) stages=np.array(stages)
signals=np.array(signals) signals=np.array(signals)
if BID == 'median':
signals = signals*13/np.median(np.abs(signals))
# #select sleep time # #select sleep time
if opt.select_sleep_time: if opt.select_sleep_time:
...@@ -123,10 +124,17 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T ...@@ -123,10 +124,17 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T
cnt += 1 cnt += 1
print('shape:',signals.shape,stages.shape) print('shape:',signals.shape,stages.shape)
return signals.astype(np.int16),stages.astype(np.int16) if BID == 'median':
signals = signals*10/np.median(np.abs(signals))
elif BID == '5_95_th':
tmp = np.sort(signals.reshape(-1))
th_5 = tmp[int(0.05*len(tmp))]
signals=transformer.Normalize(signals,1000,0,th_5)
return signals.astype(np.float16),stages.astype(np.int16)
def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = 'C4-M1',num = 100 ,BID = 'median',shuffle = True): def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = 'C4-M1',num = 100 ,BID = 'median' ,shuffle = True):
print('load dataset, please wait...') print('load dataset, please wait...')
filenames = os.listdir(filedir) filenames = os.listdir(filedir)
...@@ -140,7 +148,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = ' ...@@ -140,7 +148,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = '
for i,filename in enumerate(filenames[:num],0): for i,filename in enumerate(filenames[:num],0):
try: try:
signal,stage = loaddata(os.path.join(filedir,filename),signal_name,BID = None) signal,stage = loaddata(os.path.join(filedir,filename),signal_name,BID = BID)
if i == 0: if i == 0:
signals =signal.copy() signals =signal.copy()
stages =stage.copy() stages =stage.copy()
...@@ -159,7 +167,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = ' ...@@ -159,7 +167,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = '
cnt = 0 cnt = 0
for filename in filenames: for filename in filenames:
if 'PSG' in filename: if 'PSG' in filename:
signal,stage = loaddata_sleep_edf(opt,filedir,filename[2:6],signal_name = signal_name) signal,stage = loaddata_sleep_edf(opt,filedir,filename[2:6],signal_name,BID)
if cnt == 0: if cnt == 0:
signals =signal.copy() signals =signal.copy()
stages =stage.copy() stages =stage.copy()
......
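The `median` branch of the BID (balance individualized differences) step above rescales a recording so that its median absolute amplitude becomes 10. The following is a minimal pure-Python sketch of that idea only. `median_abs`, `median_scale`, and the sample values are hypothetical names and data chosen for illustration and are not part of the repository.

```python
def median_abs(values):
    """Median of the absolute values of a non-empty sequence."""
    ordered = sorted(abs(v) for v in values)
    mid = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2


def median_scale(signal, target=10.0):
    """Rescale `signal` so that its median absolute amplitude equals `target`."""
    med = median_abs(signal)
    if med == 0:
        raise ValueError("median absolute amplitude is zero; cannot rescale")
    return [v * target / med for v in signal]


# Two hypothetical recordings of the same waveform at different gains.
quiet = [1.0, -2.0, 3.0, -4.0, 5.0]
loud = [v * 7 for v in quiet]

scaled_quiet = median_scale(quiet)
scaled_loud = median_scale(loud)
```

After scaling, both recordings have the same amplitude range, which is the point of the step: differences in overall gain between subjects or devices are removed before the signal reaches the network.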
## Prerequisites
- Linux, Windows, or macOS
- CPU or NVIDIA GPU + CUDA CuDNN
- Python 3.5+
- PyTorch 1.0+
## Dependencies
This code depends on torchvision, numpy, scipy, h5py, matplotlib, mne = 18.0, opencv-python, requests, hashlib, memory_profiler, available via pip install.<br>
For example:<br>
```bash
pip3 install matplotlib
```
But for mne, you may run:<br>
```bash
pip3 install -U https://api.github.com/repos/mne-tools/mne-python/zipball/master
```
## Getting Started
### Clone this repo:
```bash
git clone https://github.com/HypoX64/candock
cd candock
```
### Train
* download datasets
```bash
python3 download_dataset.py
```
* choose your options and run
```bash
python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edf --signal_name 'EEG Fpz-Cz' --sample_num 8 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 50 --lr 0.0005 --select_sleep_time
```
* Notes<br>
To train on the CPU, add --no_cuda.
### Simple Test
```bash
python3 simple_test.py --pretrained --no_cuda
```
\ No newline at end of file
...@@ -2,9 +2,9 @@ Confusion matrix: ...@@ -2,9 +2,9 @@ Confusion matrix:
Pred Pred
S1 S2 S3 REM WAKE S3 S2 S1 REM WAKE
S2 S2
True S3 True S1
REM REM
WAKE WAKE
...@@ -12,6 +12,17 @@ True S3 ...@@ -12,6 +12,17 @@ True S3
5-Fold Cross-Validation Results 5-Fold Cross-Validation Results
-------------------sleep-edf------------------- -------------------sleep-edf-------------------
resnet18_1d
final: test avg_recall:0.8263 avg_acc:0.9601 error:0.0997
confusion_mat:
[[1133 148 7 3 0]
[ 296 2968 120 195 3]
[ 3 60 368 149 18]
[ 3 99 135 1342 18]
[ 9 7 190 37 7729]]
statistics of dataset [S3 S2 S1 R W]:
[1291 3582 598 1597 7972]
DFCNN+resnet18 DFCNN+resnet18
final: test avg_recall:0.8261 avg_acc:0.9594 error:0.1016 final: test avg_recall:0.8261 avg_acc:0.9594 error:0.1016
confusion_mat: confusion_mat:
...@@ -23,6 +34,16 @@ confusion_mat: ...@@ -23,6 +34,16 @@ confusion_mat:
statistics of dataset [S3 S2 S1 R W]: statistics of dataset [S3 S2 S1 R W]:
[ 268 725 106 314 1611] [ 268 725 106 314 1611]
multi_scale_resnet_1d
final: test avg_recall:0.84 avg_acc:0.9595 error:0.1013
confusion_mat:
[[1159 114 10 3 1]
[ 335 2961 153 143 5]
[ 4 45 436 102 13]
[ 0 100 241 1244 9]
[ 2 2 214 27 7717]]
statistics of dataset [S3 S2 S1 R W]:
[1287 3597 600 1594 7962]
-------------------sleep-edfx(only sleep time)------------------- -------------------sleep-edfx(only sleep time)-------------------
......
...@@ -32,12 +32,14 @@ class Route(nn.Module): ...@@ -32,12 +32,14 @@ class Route(nn.Module):
self.block1 = ResidualBlock(64, 64, kernel_size, stride=1) self.block1 = ResidualBlock(64, 64, kernel_size, stride=1)
self.block2 = ResidualBlock(64, 128, kernel_size) self.block2 = ResidualBlock(64, 128, kernel_size)
self.block3 = ResidualBlock(128, 256, kernel_size) self.block3 = ResidualBlock(128, 256, kernel_size)
self.block4 = ResidualBlock(256, 512, kernel_size)
self.avgpool = nn.AdaptiveAvgPool1d(1) self.avgpool = nn.AdaptiveAvgPool1d(1)
def forward(self, x): def forward(self, x):
x = self.block1(x) x = self.block1(x)
x = self.block2(x) x = self.block2(x)
x = self.block3(x) x = self.block3(x)
x = self.block4(x)
x = self.avgpool(x) x = self.avgpool(x)
return x return x
...@@ -45,7 +47,7 @@ class Multi_Scale_ResNet(nn.Module): ...@@ -45,7 +47,7 @@ class Multi_Scale_ResNet(nn.Module):
def __init__(self, inchannel, num_classes): def __init__(self, inchannel, num_classes):
super(Multi_Scale_ResNet, self).__init__() super(Multi_Scale_ResNet, self).__init__()
self.pre_conv = nn.Sequential( self.pre_conv = nn.Sequential(
nn.Conv1d(inchannel, 64, kernel_size=7, stride=2, padding=3, bias=False), nn.Conv1d(inchannel, 64, kernel_size=15, stride=2, padding=7, bias=False),
nn.BatchNorm1d(64), nn.BatchNorm1d(64),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.MaxPool1d(kernel_size=3, stride=2, padding=1) nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
...@@ -53,7 +55,7 @@ class Multi_Scale_ResNet(nn.Module): ...@@ -53,7 +55,7 @@ class Multi_Scale_ResNet(nn.Module):
self.Route1 = Route(3) self.Route1 = Route(3)
self.Route2 = Route(5) self.Route2 = Route(5)
self.Route3 = Route(7) self.Route3 = Route(7)
self.fc = nn.Linear(256*3, num_classes) self.fc = nn.Linear(512*3, num_classes)
def forward(self, x): def forward(self, x):
x = self.pre_conv(x) x = self.pre_conv(x)
......
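Changing `pre_conv` from `kernel_size=7, padding=3` to `kernel_size=15, padding=7` widens the receptive field without changing the output length, since both settings satisfy padding = (kernel_size - 1) / 2. The sketch below checks this with the standard output-length formula for an undilated 1-D convolution or pooling layer. The input length of 3000 samples (30 s at 100 Hz) is an assumption, because the actual length depends on `transformer.ToInputShape`, which this diff does not show. `out_len` is a hypothetical helper.

```python
def out_len(length, kernel_size, stride, padding):
    """Output length of an undilated 1-D conv or pooling layer."""
    return (length + 2 * padding - kernel_size) // stride + 1


ASSUMED_INPUT_LEN = 3000  # assumption: one 30 s epoch sampled at 100 Hz

old_conv = out_len(ASSUMED_INPUT_LEN, kernel_size=7, stride=2, padding=3)
new_conv = out_len(ASSUMED_INPUT_LEN, kernel_size=15, stride=2, padding=7)
after_pool = out_len(new_conv, kernel_size=3, stride=2, padding=1)

# Each Route ends in AdaptiveAvgPool1d(1), so it yields one value per channel.
# With block4 added, each of the three routes outputs 512 channels.
fc_in_features = 512 * 3

print(old_conv, new_conv, after_pool, fc_in_features)
```

Because of the adaptive pooling at the end of each route, the `fc` input size depends only on the channel count, which is why the diff changes it from `256*3` to `512*3` when `block4` is added.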
...@@ -2,6 +2,7 @@ import argparse ...@@ -2,6 +2,7 @@ import argparse
import os import os
import numpy as np import numpy as np
import torch import torch
#python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edf --signal_name 'EEG Fpz-Cz' --sample_num 8 --model_name multi_scale_resnet_1d --batchsize 32 --network_save_freq 100 --epochs 40 --lr 0.0005
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 200 --model_name resnet18 --batchsize 32 --epochs 10 --fold_num 5 --pretrained #python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 200 --model_name resnet18 --batchsize 32 --epochs 10 --fold_num 5 --pretrained
#python3 train_new.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name LSTM --batchsize 32 --network_save_freq 100 --epochs 10 #python3 train_new.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name LSTM --batchsize 32 --network_save_freq 100 --epochs 10
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name resnet18 --batchsize 32 #python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name resnet18 --batchsize 32
...@@ -19,6 +20,8 @@ class Options(): ...@@ -19,6 +20,8 @@ class Options():
self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn') self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models') self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
self.parser.add_argument('--Cross_Validation', type=str, default='k_fold',help='k-fold | subject')
self.parser.add_argument('--BID', type=str, default='None',help='Balance individualized differences 5_95_th | median |None')
self.parser.add_argument('--fold_num', type=int, default=5,help='k-fold') self.parser.add_argument('--fold_num', type=int, default=5,help='k-fold')
self.parser.add_argument('--batchsize', type=int, default=16,help='batchsize') self.parser.add_argument('--batchsize', type=int, default=16,help='batchsize')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/', self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
...@@ -27,7 +30,7 @@ class Options(): ...@@ -27,7 +30,7 @@ class Options():
self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train') self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train')
self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1|EEG Fpz-Cz') self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1|EEG Fpz-Cz')
self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load') self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load')
self.parser.add_argument('--model_name', type=str, default='resnet18',help='Choose model') self.parser.add_argument('--model_name', type=str, default='lstm',help='Choose model')
self.parser.add_argument('--epochs', type=int, default=50,help='end epoch') self.parser.add_argument('--epochs', type=int, default=50,help='end epoch')
self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal') self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal')
self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network') self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
...@@ -43,6 +46,10 @@ class Options(): ...@@ -43,6 +46,10 @@ class Options():
if self.opt.dataset_name == 'sleep-edf': if self.opt.dataset_name == 'sleep-edf':
self.opt.sample_num = 8 self.opt.sample_num = 8
if self.opt.no_cuda:
self.opt.no_cudnn = True
if self.opt.fold_num == 0:
self.opt.fold_num = 1
# if self.opt.weight_mod == 'normal': # if self.opt.weight_mod == 'normal':
......
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import util
import transformer
import dataloader
from options import Options
from creatnet import CreatNet
'''
@hypox64
19/04/13
'''
opt = Options().getparse()
net=CreatNet(opt.model_name)
if not opt.no_cuda:
net.cuda()
if not opt.no_cudnn:
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
if opt.pretrained:
net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth'))
# N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4
stage_map={0:'stage3',1:'stage2',2:'stage1',3:'REM',4:'Wake'}
def runmodel(eegdata):
eegdata = eegdata.reshape(1,-1)
eegdata = transformer.ToInputShape(eegdata,opt.model_name,test_flag =True)
eegdata = transformer.ToTensor(eegdata,no_cuda =opt.no_cuda)
with torch.no_grad():
out = net(eegdata)
pred = torch.max(out, 1)[1]
pred_stage=pred.data.cpu().numpy()
return pred_stage[0]
'''
You can substitute your own input data here,
but it must meet the following conditions:
1. one epoch (30 s) of recording
2. sampling rate fs = 100 Hz
3. amplitude in microvolts (uV)
'''
eegdata = np.load('./datasets/simple_test_data.npy')
print('the shape of eegdata:',eegdata.shape)
stage = runmodel(eegdata)
print('the sleep stage of this signal is:',stage_map[stage])
plt.plot(eegdata)
plt.show()
...@@ -8,26 +8,46 @@ def stage(stages): ...@@ -8,26 +8,46 @@ def stage(stages):
for i in range(len(stages)): for i in range(len(stages)):
stage_cnt[stages[i]] += 1 stage_cnt[stages[i]] += 1
stage_cnt_per = stage_cnt/len(stages) stage_cnt_per = stage_cnt/len(stages)
util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt)) util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt),True)
print('statistics of dataset [S3 S2 S1 R W]:\n',stage_cnt,'\n',stage_cnt_per)
return stage_cnt,stage_cnt_per return stage_cnt,stage_cnt_per
def result(mat): def Kappa(mat):
mat=mat/10000 # avoid overflow
mat_length=np.sum(mat)
wide=mat.shape[0]
po=0.0;pe=0.0
for i in range(wide):
po=po+mat[i][i]
pe=pe+np.sum(mat[:,i])*np.sum(mat[i,:])
po=po/mat_length
pe=pe/(mat_length*mat_length)
k=(po-pe)/(1-pe)
return k
def result(mat,print_sub=False):
wide=mat.shape[0] wide=mat.shape[0]
sub_acc = np.zeros(wide) sub_acc = np.zeros(wide)
sub_recall = np.zeros(wide) sub_recall = np.zeros(wide)
sub_sp = np.zeros(wide)
err = 0 err = 0
for i in range(wide): for i in range(wide):
if np.sum(mat[i]) == 0 : TP = mat[i,i]
sub_recall[i] = 0 FN = np.sum(mat[i])- mat[i,i]
else: TN = (np.sum(mat)-np.sum(mat[i])-np.sum(mat[:,i])+mat[i,i])
sub_recall[i]=mat[i,i]/np.sum(mat[i]) FP = np.sum(mat[:,i]) - mat[i,i]
err += mat[i,i] err += mat[i,i]
sub_acc[i] = (np.sum(mat)-((np.sum(mat[i])+np.sum(mat[:,i]))-2*mat[i,i]))/np.sum(mat) sub_acc[i]=(TP+TN)/(TP+FN+TN+FP)
sub_recall[i]=(TP)/np.clip((TP+FN), 1e-5, 1e10)
sub_sp[i] = TN/np.clip((TN+FP), 1e-5, 1e10)
if print_sub == True:
print('sub_recall:',sub_recall,'\nsub_acc:',sub_acc,'\nsub_sp:',sub_sp)
avg_recall = np.mean(sub_recall) avg_recall = np.mean(sub_recall)
avg_acc = np.mean(sub_acc) avg_acc = np.mean(sub_acc)
avg_sp = np.mean(sub_sp)
err = 1-err/np.sum(mat) err = 1-err/np.sum(mat)
return avg_recall,avg_acc,err k = Kappa(mat)
return round(avg_recall,4),round(avg_acc,4),round(avg_sp,4),round(err,4),round(k, 4)
def stagefrommat(mat): def stagefrommat(mat):
wide=mat.shape[0] wide=mat.shape[0]
...@@ -36,34 +56,35 @@ def stagefrommat(mat): ...@@ -36,34 +56,35 @@ def stagefrommat(mat):
stage_num[i]=np.sum(mat[i]) stage_num[i]=np.sum(mat[i])
util.writelog('statistics of dataset [S3 S2 S1 R W]:\n'+str(stage_num),True) util.writelog('statistics of dataset [S3 S2 S1 R W]:\n'+str(stage_num),True)
def show(plot_result,epoch): def show(plot_result,epoch):
train_recall = np.array(plot_result['train']) train = np.array(plot_result['train'])
test_recall = np.array(plot_result['test']) test = np.array(plot_result['test'])
plt.figure('running recall') plt.figure('running recall')
plt.clf() plt.clf()
train_recall_x = np.linspace(0,epoch,len(train_recall)) train_x = np.linspace(0,epoch,len(train))
test_recall_x = np.linspace(0,int(epoch),len(test_recall)) test_x = np.linspace(0,int(epoch),len(test))
plt.xlabel('Epoch') plt.xlabel('Epoch')
plt.ylabel('%') plt.ylabel('%')
plt.ylim((0,1)) plt.ylim((0,100))
if epoch <10: if epoch <10:
plt.xlim((0,10)) plt.xlim((0,10))
else: else:
plt.xlim((0,epoch)) plt.xlim((0,epoch))
plt.plot(train_recall_x,train_recall,label='train',linewidth = 2.0,color = 'red') plt.plot(train_x,train*100,label='train',linewidth = 2.0,color = 'red')
plt.plot(test_recall_x,test_recall,label='test', linewidth = 2.0,color = 'blue') plt.plot(test_x,test*100,label='test', linewidth = 2.0,color = 'blue')
plt.legend(loc=4) plt.legend(loc=1)
plt.savefig('./running_recall.png') plt.title('Running err.',fontsize='large')
plt.savefig('./running_err.png')
# plt.draw() # plt.draw()
# plt.pause(0.01) # plt.pause(0.01)
def main(): def main():
plot_result={'train': [0.2303303787268332, 0.2119345588626961, 0.20542007990053074, 0.20353191245282734, 0.2032570804016917, 0.20269640625503033, 0.2020943574651975, 0.2108357726067258, 0.21750990713964172, 0.23142651474994708, 0.2318236991596459, 0.22924187151697578, 0.22830716248841004, 0.2331831179181414, 0.23604422314519158, 0.23734486777406488, 0.23929925551037354, 0.2451802483014293, 0.24753448439761755, 0.24964581836870603, 0.2506097959967858, 0.2497704229822455], 'test': [0.28670433145009416, 0.29533625933982305, 0.2927783086111587, 0.28665535025585603, 0.2884532914652956]} mat=[[37980,1322,852,2,327],[3922,8784,3545,0,2193],[1756,5136,99564,1091,991],[18,1,7932,4063,14],[1361,1680,465,0,23931]]
show(plot_result,10) mat = np.array(mat)
avg_recall,avg_acc,avg_sp,err,k = result(mat)
print(avg_recall,avg_acc,avg_sp,err,k)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
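As a cross-check of the metric definitions in `result` and `Kappa`, the following pure-Python sketch recomputes them from the `resnet18_1d` confusion matrix logged for sleep-edf (rows are true labels, columns are predictions). `summarize` is a hypothetical helper written for this illustration. It follows the same TP/FN/FP/TN definitions but is not the repository's implementation.

```python
def summarize(mat):
    """Macro-averaged recall, accuracy, specificity, error rate and Cohen's kappa."""
    n = len(mat)
    total = sum(sum(row) for row in mat)
    row_sums = [sum(row) for row in mat]
    col_sums = [sum(mat[r][c] for r in range(n)) for c in range(n)]

    recalls, accs, specs = [], [], []
    for i in range(n):
        tp = mat[i][i]
        fn = row_sums[i] - tp
        fp = col_sums[i] - tp
        tn = total - tp - fn - fp
        recalls.append(tp / max(tp + fn, 1e-5))  # guard against empty classes
        specs.append(tn / max(tn + fp, 1e-5))
        accs.append((tp + tn) / total)

    po = sum(mat[i][i] for i in range(n)) / total  # observed agreement
    pe = sum(row_sums[i] * col_sums[i] for i in range(n)) / (total * total)
    return {
        "avg_recall": sum(recalls) / n,
        "avg_acc": sum(accs) / n,
        "avg_sp": sum(specs) / n,
        "err": 1 - po,
        "kappa": (po - pe) / (1 - pe),
    }


# resnet18_1d confusion matrix from the sleep-edf log, label order [S3 S2 S1 R W].
resnet18_1d = [
    [1133, 148, 7, 3, 0],
    [296, 2968, 120, 195, 3],
    [3, 60, 368, 149, 18],
    [3, 99, 135, 1342, 18],
    [9, 7, 190, 37, 7729],
]
metrics = summarize(resnet18_1d)
print({name: round(value, 4) for name, value in metrics.items()})
```

The recall, accuracy, and error values reproduce the logged `avg_recall:0.8263 avg_acc:0.9601 error:0.0997`. Note that the per-class accuracy counts true negatives, so it stays high even for the weakly recognized S1 class.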
import numpy as np
import time
import util
import os import os
import time import time
import transformer
import dataloader import numpy as np
# import models
from creatnet import CreatNet
import torch import torch
from torch import nn, optim from torch import nn, optim
import statistics
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import heatmap
from options import Options
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import util
import transformer
import dataloader
import statistics
import heatmap
from creatnet import CreatNet
from options import Options
opt = Options().getparse() opt = Options().getparse()
localtime = time.asctime(time.localtime(time.time())) localtime = time.asctime(time.localtime(time.time()))
util.writelog('\n\n'+str(localtime)+'\n'+str(opt)) util.writelog('\n\n'+str(localtime)+'\n'+str(opt))
t1 = time.time() t1 = time.time()
signals,stages = dataloader.loaddataset(opt,opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=True,BID=None)
signals,stages = dataloader.loaddataset(opt,opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=False,BID=opt.BID)
stage_cnt,stage_cnt_per = statistics.stage(stages) stage_cnt,stage_cnt_per = statistics.stage(stages)
signals,stages = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
if opt.Cross_Validation =='k_fold':
signals,stages = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
train_sequences,test_sequences = transformer.k_fold_generator(len(stages),opt.fold_num)
elif opt.Cross_Validation =='subject':
util.writelog('train statistics:',True)
stage_cnt,stage_cnt_per = statistics.stage(stages[:int(0.8*len(stages))])
util.writelog('test statistics:',True)
stage_cnt,stage_cnt_per = statistics.stage(stages[int(0.8*len(stages)):])
signals,stages = transformer.batch_generator_subject(signals,stages,opt.batchsize,shuffle = False)
train_sequences,test_sequences = transformer.k_fold_generator(len(stages),1)
batch_length = len(stages) batch_length = len(stages)
print('length of batch:',batch_length) print('length of batch:',batch_length)
train_sequences,test_sequences = transformer.k_fold_generator(batch_length,opt.fold_num)
show_freq = int(len(train_sequences[0])/5) show_freq = int(len(train_sequences[0])/5)
util.show_menory() util.show_menory()
t2 = time.time() t2 = time.time()
...@@ -64,7 +74,7 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={},mode = 'part'): ...@@ -64,7 +74,7 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={},mode = 'part'):
confusion_mat = np.zeros((5,5), dtype=int) confusion_mat = np.zeros((5,5), dtype=int)
for i, sequence in enumerate(sequences, 1): for i, sequence in enumerate(sequences, 1):
signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =True) signal=transformer.ToInputShape(signals[sequence],opt.model_name,opt.BID,test_flag =True)
signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda) signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
with torch.no_grad(): with torch.no_grad():
out = net(signal) out = net(signal)
...@@ -77,11 +87,10 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={},mode = 'part'): ...@@ -77,11 +87,10 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={},mode = 'part'):
if mode =='part': if mode =='part':
plot_result['test'].append(statistics.result(confusion_mat)[0]) plot_result['test'].append(statistics.result(confusion_mat)[0])
else: else:
recall,acc,error = statistics.result(confusion_mat) recall,acc,sp,err,k = statistics.result(confusion_mat)
plot_result['test'].append(recall) plot_result['test'].append(err)
heatmap.draw(confusion_mat,name = 'test') heatmap.draw(confusion_mat,name = 'test')
print('test avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error) print('avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'avg_sp:','%.4f' % sp,'error:','%.4f' % err,'Kappa:','%.4f' % k)
#util.writelog('epoch:'+str(epoch)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)))
return plot_result,confusion_mat return plot_result,confusion_mat
print('begin to train ...') print('begin to train ...')
...@@ -92,7 +101,7 @@ for fold in range(opt.fold_num): ...@@ -92,7 +101,7 @@ for fold in range(opt.fold_num):
net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth')) net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth'))
if not opt.no_cuda: if not opt.no_cuda:
net.cuda() net.cuda()
plot_result={'train':[0],'test':[0]} plot_result={'train':[1.],'test':[1.]}
confusion_mats = [] confusion_mats = []
for epoch in range(opt.epochs): for epoch in range(opt.epochs):
...@@ -102,7 +111,7 @@ for fold in range(opt.fold_num): ...@@ -102,7 +111,7 @@ for fold in range(opt.fold_num):
net.train() net.train()
for i, sequence in enumerate(train_sequences[fold], 1): for i, sequence in enumerate(train_sequences[fold], 1):
signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =False) signal=transformer.ToInputShape(signals[sequence],opt.model_name,opt.BID,test_flag =False)
signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda) signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
out = net(signal) out = net(signal)
...@@ -117,7 +126,7 @@ for fold in range(opt.fold_num): ...@@ -117,7 +126,7 @@ for fold in range(opt.fold_num):
for x in range(len(pred)): for x in range(len(pred)):
confusion_mat[stage[x]][pred[x]] += 1 confusion_mat[stage[x]][pred[x]] += 1
if i%show_freq==0: if i%show_freq==0:
plot_result['train'].append(statistics.result(confusion_mat)[0]) plot_result['train'].append(statistics.result(confusion_mat)[3])
heatmap.draw(confusion_mat,name = 'train') heatmap.draw(confusion_mat,name = 'train')
# plot_result=evalnet(net,signals_eval,stages_eval,plot_result,show_freq,mode = 'part') # plot_result=evalnet(net,signals_eval,stages_eval,plot_result,show_freq,mode = 'part')
statistics.show(plot_result,epoch+i/(batch_length*0.8)) statistics.show(plot_result,epoch+i/(batch_length*0.8))
...@@ -134,18 +143,15 @@ for fold in range(opt.fold_num): ...@@ -134,18 +143,15 @@ for fold in range(opt.fold_num):
net.cuda() net.cuda()
t2=time.time() t2=time.time()
if epoch+1==1:
print('cost time: %.2f' % (t2-t1),'s') print('cost time: %.2f' % (t2-t1),'s')
pos = plot_result['test'].index(max(plot_result['test']))-1 pos = plot_result['test'].index(min(plot_result['test']))-1
final_confusion_mat = final_confusion_mat+confusion_mats[pos] final_confusion_mat = final_confusion_mat+confusion_mats[pos]
recall,acc,error = statistics.result(confusion_mats[pos]) util.writelog('fold:'+str(fold+1)+' recall,acc,sp,err,k: '+str(statistics.result(confusion_mats[pos])),True)
print('\nfold:',fold+1,'finished',' avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error,'\n') print('------------------')
util.writelog('fold:'+str(fold+1)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)))
util.writelog('confusion_mat:\n'+str(confusion_mat)) util.writelog('confusion_mat:\n'+str(confusion_mat))
recall,acc,error = statistics.result(final_confusion_mat) util.writelog('final: '+'recall,acc,sp,err,k: '+str(statistics.result(final_confusion_mat)),True)
#print('all finished!\n',final_confusion_mat) util.writelog('confusion_mat:\n'+str(final_confusion_mat),True)
#print('avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error) statistics.stagefrommat(final_confusion_mat)
util.writelog('final:'+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)),True)
util.writelog('confusion_mat:\n'+str(confusion_mat),True)
statistics.stagefrommat(confusion_mat)
heatmap.draw(final_confusion_mat,name = 'final_test') heatmap.draw(final_confusion_mat,name = 'final_test')
\ No newline at end of file
...@@ -15,6 +15,17 @@ def shuffledata(data,target): ...@@ -15,6 +15,17 @@ def shuffledata(data,target):
np.random.shuffle(target) np.random.shuffle(target)
# return data,target # return data,target
def batch_generator_subject(data,target,batchsize,shuffle = True):
data_test = data[int(0.8*len(target)):]
data_train = data[0:int(0.8*len(target))]
target_test = target[int(0.8*len(target)):]
target_train = target[0:int(0.8*len(target))]
data_test,target_test = batch_generator(data_test, target_test, batchsize)
data_train,target_train = batch_generator(data_train, target_train, batchsize)
data = np.concatenate((data_train, data_test), axis=0)
target = np.concatenate((target_train, target_test), axis=0)
return data,target
def batch_generator(data,target,batchsize,shuffle = True): def batch_generator(data,target,batchsize,shuffle = True):
if shuffle: if shuffle:
shuffledata(data,target) shuffledata(data,target)
...@@ -26,6 +37,10 @@ def batch_generator(data,target,batchsize,shuffle = True): ...@@ -26,6 +37,10 @@ def batch_generator(data,target,batchsize,shuffle = True):
def k_fold_generator(length,fold_num): def k_fold_generator(length,fold_num):
sequence = np.linspace(0,length-1,length,dtype='int') sequence = np.linspace(0,length-1,length,dtype='int')
if fold_num == 1:
train_sequence = sequence[0:int(0.8*length)].reshape(1,-1)
test_sequence = sequence[int(0.8*length):].reshape(1,-1)
else:
train_length = int(length/fold_num*(fold_num-1)) train_length = int(length/fold_num*(fold_num-1))
test_length = int(length/fold_num) test_length = int(length/fold_num)
train_sequence = np.zeros((fold_num,train_length), dtype = 'int') train_sequence = np.zeros((fold_num,train_length), dtype = 'int')
...@@ -33,6 +48,7 @@ def k_fold_generator(length,fold_num): ...@@ -33,6 +48,7 @@ def k_fold_generator(length,fold_num):
for i in range(fold_num): for i in range(fold_num):
test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length] test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length] train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length]
return train_sequence,test_sequence return train_sequence,test_sequence
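
Reviewer note: the `fold_num == 1` branch added here turns `k_fold_generator` into a plain 80/20 hold-out split, while larger values keep the contiguous K-fold behaviour. The sketch below is one reading of that logic, not the repository's exact code. The name `k_fold_indices` is hypothetical, and the allocation of the test array is an assumption because that line falls outside the hunk shown.

```python
import numpy as np

def k_fold_indices(length, fold_num):
    """Return (train, test) index arrays, one row per fold (illustrative sketch)."""
    sequence = np.arange(length)
    if fold_num == 1:
        # Single "fold": first 80% for training, last 20% for testing.
        split = int(0.8 * length)
        return sequence[:split].reshape(1, -1), sequence[split:].reshape(1, -1)

    train_length = int(length / fold_num * (fold_num - 1))
    test_length = int(length / fold_num)
    train = np.zeros((fold_num, train_length), dtype=int)
    # Assumption: the test array is allocated the same way as the train array.
    test = np.zeros((fold_num, test_length), dtype=int)
    for i in range(fold_num):
        # The i-th contiguous block is held out; everything else is training data.
        test[i] = sequence[test_length * i:test_length * (i + 1)]
        rest = np.concatenate((sequence[:test_length * i],
                               sequence[test_length * (i + 1):]))
        train[i] = rest[:train_length]  # truncate so every row has equal length
    return train, test

train_idx, test_idx = k_fold_indices(10, 5)
holdout_train, holdout_test = k_fold_indices(10, 1)
```

Because the indices are contiguous and unshuffled, whether a fold separates subjects depends on how the records are ordered before this function is called.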
...@@ -52,14 +68,19 @@ def Normalize(data,maxmin,avg,sigma): ...@@ -52,14 +68,19 @@ def Normalize(data,maxmin,avg,sigma):
data = np.clip(data, -maxmin, maxmin) data = np.clip(data, -maxmin, maxmin)
return (data-avg)/sigma return (data-avg)/sigma
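
Reviewer note: later in this diff the 1-D networks switch from `sigma=1000` to `sigma=50` while keeping `maxmin=1000`. A small worked example of what that does to the value range is sketched below. The function is a lowercase stand-in (`normalize`) for the two lines visible in this hunk, and the sample values are made up for illustration.

```python
import numpy as np

def normalize(data, maxmin, avg, sigma):
    # Clip outliers to [-maxmin, maxmin], then shift and scale.
    data = np.clip(data, -maxmin, maxmin)
    return (data - avg) / sigma

sample = np.array([-2000.0, 0.0, 50.0, 1500.0])  # illustrative values only

old_scale = normalize(sample, maxmin=1000, avg=0, sigma=1000)  # range [-1, 1]
new_scale = normalize(sample, maxmin=1000, avg=0, sigma=50)    # range [-20, 20]
```

With the new setting the clipped signal spans [-20, 20] rather than [-1, 1], so inputs reach the network at a 20 times larger scale.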
def ToTensor(data,target,no_cuda = False): def ToTensor(data,target=None,no_cuda = False):
if target is not None:
data = torch.from_numpy(data).float() data = torch.from_numpy(data).float()
target = torch.from_numpy(target).long() target = torch.from_numpy(target).long()
if not no_cuda: if not no_cuda:
data = data.cuda() data = data.cuda()
target = target.cuda() target = target.cuda()
return data,target return data,target
else:
data = torch.from_numpy(data).float()
if not no_cuda:
data = data.cuda()
return data
def random_transform_1d(data,finesize,test_flag): def random_transform_1d(data,finesize,test_flag):
length = len(data) length = len(data)
...@@ -76,7 +97,7 @@ def random_transform_1d(data,finesize,test_flag): ...@@ -76,7 +97,7 @@ def random_transform_1d(data,finesize,test_flag):
result = result[::-1] result = result[::-1]
#random amp #random amp
result = result*random.uniform(0.95,1.05) result = result*random.uniform(0.8,1.2)
return result return result
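
Reviewer note: this hunk widens the random amplitude scaling from `uniform(0.95,1.05)` to `uniform(0.8,1.2)`. Only the flip and the amplitude step are visible in the diff, so the sketch below is an assumption-heavy illustration of a 1-D augmentation of this kind. The name `augment_1d`, the centre crop at test time, the random crop during training, and the 0.5 flip probability are all assumptions rather than the repository's confirmed behaviour.

```python
import random
import numpy as np

def augment_1d(signal, finesize, test_flag, rng=random):
    """Crop a 1-D signal to `finesize`; randomly flip and rescale when training."""
    length = len(signal)
    if test_flag:
        # Assumption: deterministic centre crop at test time.
        start = (length - finesize) // 2
        return signal[start:start + finesize]
    # Assumption: random crop position during training.
    start = int((length - finesize) * rng.random())
    result = signal[start:start + finesize]
    # Assumption: time reversal with probability 0.5.
    if rng.random() < 0.5:
        result = result[::-1]
    # Random amplitude scaling, using the range introduced by this commit.
    return result * rng.uniform(0.8, 1.2)

epoch = np.arange(3000, dtype=np.float32)  # stand-in for one EEG epoch
test_crop = augment_1d(epoch, 2700, test_flag=True)
train_crop = augment_1d(epoch, 2700, test_flag=False, rng=random.Random(0))
```

Passing a seeded `random.Random` instance keeps the augmentation reproducible in experiments, and the default falls back to the module-level generator.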
...@@ -100,26 +121,8 @@ def random_transform_2d(img,finesize = (224,122),test_flag = True): ...@@ -100,26 +121,8 @@ def random_transform_2d(img,finesize = (224,122),test_flag = True):
return result return result
# def random_transform_2d(img,finesize,test_flag): def ToInputShape(data,net_name,BID = 'None',norm = True,test_flag = False):
# h,w = img.shape[:2]
# if test_flag:
# h_move = 2
# w_move = int((w-finesize)*0.5)
# result = img[h_move:h_move+finesize,w_move:w_move+finesize]
# else:
# #random crop
# h_move = int(5*random.random()) #do not loss low freq signal infos
# w_move = int((w-finesize)*random.random())
# result = img[h_move:h_move+finesize,w_move:w_move+finesize]
# #random flip
# if random.random()<0.5:
# result = result[:,::-1]
# #random amp
# result = result*random.uniform(0.95,1.05)+random.uniform(-0.02,0.02)
# return result
def ToInputShape(data,net_name,norm=True,test_flag = False):
data = data.astype(np.float32) data = data.astype(np.float32)
batchsize=data.shape[0] batchsize=data.shape[0]
if net_name=='lstm': if net_name=='lstm':
...@@ -127,16 +130,17 @@ def ToInputShape(data,net_name,norm=True,test_flag = False): ...@@ -127,16 +130,17 @@ def ToInputShape(data,net_name,norm=True,test_flag = False):
for i in range(0,batchsize): for i in range(0,batchsize):
randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag) randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag)
result.append(dsp.getfeature(randomdata)) result.append(dsp.getfeature(randomdata))
if norm and BID != '5_95_th':
result = Normalize(result,maxmin = 1000,avg=0,sigma=50)
result = np.array(result).reshape(batchsize,2700*5) result = np.array(result).reshape(batchsize,2700*5)
elif net_name in['cnn_1d','resnet18_1d','multi_scale_resnet_1d']: elif net_name in['cnn_1d','resnet18_1d','multi_scale_resnet_1d']:
result =[] result =[]
for i in range(0,batchsize): for i in range(0,batchsize):
randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag) randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag)
# result.append(dsp.getfeature(randomdata,ch_num = 6))
result.append(randomdata) result.append(randomdata)
result = np.array(result) result = np.array(result)
if norm: if norm and BID != '5_95_th':
result = Normalize(result,maxmin = 1000,avg=0,sigma=1000) result = Normalize(result,maxmin = 1000,avg=0,sigma=50)
result = result.reshape(batchsize,1,2700) result = result.reshape(batchsize,1,2700)
elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']: elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
......