提交 d09afc3c 编写于 作者: H hypox64

Updata README.md and loss weight

上级 7ac8cf6f
......@@ -58,7 +58,7 @@
| Network | Parameters | Top1.err. | Avg. Acc. | Avg. Re. | Need to extract feature |
| --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- |
| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes |
| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No |
| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No |
| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No |
\ No newline at end of file
| lstm | 1.25M | 26.32% | 89.47% | 68.57% | Yes |
| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 89.87% | 72.61% | No |
| resnet18_1d | 3.85M | 24.21% | 90.31% | 72.87% | No |
| multi_scale_resnet_1d | 8.42M | 24.01% | 90.40% | 72.37% | No |
\ No newline at end of file
## Prerequisites
- Linux, Windows,mac
- CPU or NVIDIA GPU + CUDA CuDNN
- Python 3.5+
- Pytroch 1.0
- Python 3
- Pytroch 1.0+
## Dependencies
This code depends on torchvision, numpy, scipy, h5py, matplotlib, mne , requests, hashlib, available via pip install.<br>
......@@ -29,14 +29,14 @@ python3 download_dataset.py
```
* Input your options and run
```bash
python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject
python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 20 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.001 --BID 5_95_th --select_sleep_time
```
* Notes<br>
If want to use cpu to train or test, please input --no_cuda
### Simple Test
* Download pretrained model & simple test data [[Google Drive]](https://drive.google.com/open?id=1pup2_tZFGQQwB-hoXRjpMxiD4Vmpn0Lf) [[百度云,dh88]](https://pan.baidu.com/s/1dGobTMVa_4u2HLky6bmbsA)
* Download pretrained model & simple test data [[Google Drive]](https://drive.google.com/open?id=1NTtLmT02jqlc81lhtzQ7GlPK8epuHfU5) [[百度云,y4ks]](https://pan.baidu.com/s/1WKWZL91SekrSlhOoEC1bQA)
* Input your options and run
```bash
python3 simple_test.py
python3 simple_test.py --model_name lstm
```
\ No newline at end of file
......@@ -15,7 +15,7 @@ class Options():
self.parser.add_argument('--no_cuda', action='store_true', help='if input, do not use gpu')
self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
self.parser.add_argument('--lr', type=float, default=0.0005,help='learning rate')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None')
self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
......
......@@ -17,7 +17,7 @@ https://drive.google.com/open?id=1pup2_tZFGQQwB-hoXRjpMxiD4Vmpn0Lf
'''
opt = Options().getparse()
#choose and creat model
opt.model_name = 'micro_multi_scale_resnet_1d'
opt.model_name = 'lstm'
net=CreatNet(opt.model_name)
if not opt.no_cuda:
......
......@@ -47,14 +47,14 @@ t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
net=CreatNet(opt.model_name)
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'.pth')
util.show_paramsnumber(net)
weight = np.array([1,1,1,1,1])
if opt.weight_mod == 'avg_best':
weight = np.log(1/stage_cnt_per)
weight[2] = weight[2]+1
weight = np.clip(weight,1,3)
weight = weight/np.median(weight)
weight = np.clip(weight, 0.8, 2)
print('Loss_weight:',weight)
weight = torch.from_numpy(weight).float()
# print(net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册