提交 5436233f 编写于 作者: H hypox64

Modify file structure

上级 0f3b076f
...@@ -40,7 +40,7 @@ cd candock ...@@ -40,7 +40,7 @@ cd candock
python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5 python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5
# if you want to use cpu to train, please input --no_cuda # if you want to use cpu to train, please input --no_cuda
``` ```
* More [options](./options.py). * More [options](./util/options.py).
#### Use your own data to train #### Use your own data to train
* step1: Generate signals.npy and labels.npy in the following format. * step1: Generate signals.npy and labels.npy in the following format.
```python ```python
...@@ -56,4 +56,5 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1 ...@@ -56,4 +56,5 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
```bash ```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0 python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# if you want to use cpu to test, please input --no_cuda # if you want to use cpu to test, please input --no_cuda
``` ```
\ No newline at end of file
from torch import nn from torch import nn
from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet from . import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet, \
from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
# from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet
# from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
def CreatNet(opt): def CreatNet(opt):
name = opt.model_name name = opt.model_name
......
...@@ -3,19 +3,16 @@ import numpy as np ...@@ -3,19 +3,16 @@ import numpy as np
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import util from util import util,transformer,dataloader,statistics,heatmap,options
import transformer from models import creatnet
import dataloader
from options import Options
from creatnet import CreatNet
''' '''
--------------------------------preload data-------------------------------- --------------------------------preload data--------------------------------
@hypox64 @hypox64
2020/04/03 2020/04/03
''' '''
opt = Options().getparse() opt = options.Options().getparse()
net = CreatNet(opt) net = creatnet.CreatNet(opt)
#load data #load data
signals = np.load('./datasets/simple_test/signals.npy') signals = np.load('./datasets/simple_test/signals.npy')
......
...@@ -7,15 +7,10 @@ from torch import nn, optim ...@@ -7,15 +7,10 @@ from torch import nn, optim
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import util from util import util,transformer,dataloader,statistics,heatmap,options
import transformer from models import creatnet
import dataloader
import statistics opt = options.Options().getparse()
import heatmap
from creatnet import CreatNet
from options import Options
opt = Options().getparse()
torch.cuda.set_device(opt.gpu_id) torch.cuda.set_device(opt.gpu_id)
t1 = time.time() t1 = time.time()
...@@ -37,7 +32,7 @@ train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_ ...@@ -37,7 +32,7 @@ train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_
t2 = time.time() t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s') print('load data cost time: %.2f'% (t2-t1),'s')
net=CreatNet(opt) net=creatnet.CreatNet(opt)
util.writelog('network:\n'+str(net),opt,True) util.writelog('network:\n'+str(net),opt,True)
util.show_paramsnumber(net,opt) util.show_paramsnumber(net,opt)
......
...@@ -5,9 +5,10 @@ import random ...@@ -5,9 +5,10 @@ import random
import scipy.io as sio import scipy.io as sio
import numpy as np import numpy as np
import dsp from . import dsp,transformer,statistics
import transformer # import dsp
import statistics # import transformer
# import statistics
def trimdata(data,num): def trimdata(data,num):
......
import argparse import argparse
import os import os
import time import time
import util from . import util
class Options(): class Options():
def __init__(self): def __init__(self):
...@@ -25,7 +25,7 @@ class Options(): ...@@ -25,7 +25,7 @@ class Options():
self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train') self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize') self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
self.parser.add_argument('--weight_mod', type=str, default='normal',help='Choose weight mode: auto | normal') self.parser.add_argument('--weight_mod', type=str, default='auto',help='Choose weight mode: auto | normal')
self.parser.add_argument('--epochs', type=int, default=20,help='end epoch') self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network') self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.if 0 or 1,no k-fold') self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.if 0 or 1,no k-fold')
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import util import util
import os import os
import heatmap from . import heatmap
def label_statistics(labels): def label_statistics(labels):
#for sleep label: N3->0 N2->1 N1->2 REM->3 W->4 #for sleep label: N3->0 N2->1 N1->2 REM->3 W->4
......
...@@ -2,7 +2,8 @@ import os ...@@ -2,7 +2,8 @@ import os
import random import random
import numpy as np import numpy as np
import torch import torch
import dsp from . import dsp
# import dsp
def trimdata(data,num): def trimdata(data,num):
return data[:num*int(len(data)/num)] return data[:num*int(len(data)/num)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册