diff --git a/README.md b/README.md index 90c2f94a71597696baaf1efc60058b48ec25f16c..e2027b2e27bad7512c9cb29d720f15d2ae609733 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ | Network | Parameters | Top1.err. | Avg. Acc. | Avg. Re. | Need to extract feature | | --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- | -| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes | -| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No | -| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No | -| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No | \ No newline at end of file +| lstm | 1.25M | 26.32% | 89.47% | 68.57% | Yes | +| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 89.87% | 72.61% | No | +| resnet18_1d | 3.85M | 24.21% | 90.31% | 72.87% | No | +| multi_scale_resnet_1d | 8.42M | 24.01% | 90.40% | 72.37% | No | \ No newline at end of file diff --git a/how_to_run.md b/how_to_run.md index bf11b6094a255f3508c9a8a281230814f57aa207..9eb2035c650e68226e85a0578c02e7d0134a725d 100644 --- a/how_to_run.md +++ b/how_to_run.md @@ -1,8 +1,8 @@ ## Prerequisites - Linux, Windows,mac - CPU or NVIDIA GPU + CUDA CuDNN -- Python 3.5+ -- Pytroch 1.0 +- Python 3 +- Pytroch 1.0+ ## Dependencies This code depends on torchvision, numpy, scipy, h5py, matplotlib, mne , requests, hashlib, available via pip install.
@@ -29,14 +29,14 @@ python3 download_dataset.py ``` * Input your options and run ```bash -python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject +python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 20 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.001 --BID 5_95_th --select_sleep_time ``` * Notes
If want to use cpu to train or test, please input --no_cuda ### Simple Test -* Download pretrained model & simple test data [[Google Drive]](https://drive.google.com/open?id=1pup2_tZFGQQwB-hoXRjpMxiD4Vmpn0Lf) [[百度云,dh88]](https://pan.baidu.com/s/1dGobTMVa_4u2HLky6bmbsA) +* Download pretrained model & simple test data [[Google Drive]](https://drive.google.com/open?id=1NTtLmT02jqlc81lhtzQ7GlPK8epuHfU5) [[百度云,y4ks]](https://pan.baidu.com/s/1WKWZL91SekrSlhOoEC1bQA) * Input your options and run ```bash -python3 simple_test.py +python3 simple_test.py --model_name lstm ``` \ No newline at end of file diff --git a/options.py b/options.py index f34dd46ee4c617f6520d58b3bf9da0275533529b..6d81e9be8a6aabfa130d429d3e281db0b92e8583 100644 --- a/options.py +++ b/options.py @@ -15,7 +15,7 @@ class Options(): self.parser.add_argument('--no_cuda', action='store_true', help='if input, do not use gpu') self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn') self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models') - self.parser.add_argument('--lr', type=float, default=0.0005,help='learning rate') + self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None') self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize') self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/', diff --git a/simple_test.py b/simple_test.py index 77489c3949c72fdac0718cc3f7a9846e5909993c..82ca2b82ef53f4eb90ef5a8528461913787b6b51 100644 --- a/simple_test.py +++ b/simple_test.py @@ -17,7 +17,7 @@ https://drive.google.com/open?id=1pup2_tZFGQQwB-hoXRjpMxiD4Vmpn0Lf ''' opt = Options().getparse() #choose and creat model -opt.model_name = 'micro_multi_scale_resnet_1d' +opt.model_name = 'lstm' net=CreatNet(opt.model_name) if not opt.no_cuda: diff --git a/train.py b/train.py index 07d890eb7daf8bddc75a0ae470c806397d6e1053..0f4a3b5b2bc926105563e72c3a7255323b4464ec 100644 --- a/train.py +++ b/train.py @@ -47,14 +47,14 @@ t2 = time.time() print('load data cost time: %.2f'% (t2-t1),'s') net=CreatNet(opt.model_name) -torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'.pth') util.show_paramsnumber(net) weight = np.array([1,1,1,1,1]) if opt.weight_mod == 'avg_best': weight = np.log(1/stage_cnt_per) weight[2] = weight[2]+1 - weight = np.clip(weight,1,3) + weight = weight/np.median(weight) + weight = np.clip(weight, 0.8, 2) print('Loss_weight:',weight) weight = torch.from_numpy(weight).float() # print(net)