提交 5436233f 编写于 作者: H hypox64

Modify file structure

上级 0f3b076f
......@@ -40,7 +40,7 @@ cd candock
python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5
# if you want to use cpu to train, please input --no_cuda
```
* More [options](./options.py).
* More [options](./util/options.py).
#### Use your own data to train
* step1: Generate signals.npy and labels.npy in the following format.
```python
......@@ -56,4 +56,5 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# if you want to use cpu to test, please input --no_cuda
```
\ No newline at end of file
```
from torch import nn
from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet
from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
from . import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet, \
multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
# from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet
# from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
def CreatNet(opt):
name = opt.model_name
......
......@@ -3,19 +3,16 @@ import numpy as np
import torch
import matplotlib.pyplot as plt
import util
import transformer
import dataloader
from options import Options
from creatnet import CreatNet
from util import util,transformer,dataloader,statistics,heatmap,options
from models import creatnet
'''
--------------------------------preload data--------------------------------
@hypox64
2020/04/03
'''
opt = Options().getparse()
net = CreatNet(opt)
opt = options.Options().getparse()
net = creatnet.CreatNet(opt)
#load data
signals = np.load('./datasets/simple_test/signals.npy')
......
......@@ -7,15 +7,10 @@ from torch import nn, optim
import warnings
warnings.filterwarnings("ignore")
import util
import transformer
import dataloader
import statistics
import heatmap
from creatnet import CreatNet
from options import Options
opt = Options().getparse()
from util import util,transformer,dataloader,statistics,heatmap,options
from models import creatnet
opt = options.Options().getparse()
torch.cuda.set_device(opt.gpu_id)
t1 = time.time()
......@@ -37,7 +32,7 @@ train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_
t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
net=CreatNet(opt)
net=creatnet.CreatNet(opt)
util.writelog('network:\n'+str(net),opt,True)
util.show_paramsnumber(net,opt)
......
......@@ -5,9 +5,10 @@ import random
import scipy.io as sio
import numpy as np
import dsp
import transformer
import statistics
from . import dsp,transformer,statistics
# import dsp
# import transformer
# import statistics
def trimdata(data,num):
......
import argparse
import os
import time
import util
from . import util
class Options():
def __init__(self):
......@@ -25,7 +25,7 @@ class Options():
self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
self.parser.add_argument('--weight_mod', type=str, default='normal',help='Choose weight mode: auto | normal')
self.parser.add_argument('--weight_mod', type=str, default='auto',help='Choose weight mode: auto | normal')
self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.if 0 or 1,no k-fold')
......
......@@ -2,7 +2,7 @@ import numpy as np
import matplotlib.pyplot as plt
import util
import os
import heatmap
from . import heatmap
def label_statistics(labels):
#for sleep label: N3->0 N2->1 N1->2 REM->3 W->4
......
......@@ -2,7 +2,8 @@ import os
import random
import numpy as np
import torch
import dsp
from . import dsp
# import dsp
def trimdata(data,num):
return data[:num*int(len(data)/num)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册