diff --git a/README.md b/README.md index 945587e43b3fb1271a3b49c34972f4f5f04553ea..7d30cc9360207019cd245941c6749191f732bcc0 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ cd candock python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5 # if you want to use cpu to train, please input --no_cuda ``` -* More [options](./options.py). +* More [options](./util/options.py). #### Use your own data to train * step1: Generate signals.npy and labels.npy in the following format. ```python @@ -56,4 +56,5 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1 ```bash python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0 # if you want to use cpu to test, please input --no_cuda -``` \ No newline at end of file +``` + diff --git a/confusion_mat b/docs/confusion_mat similarity index 100% rename from confusion_mat rename to docs/confusion_mat diff --git a/creatnet.py b/models/creatnet.py similarity index 85% rename from creatnet.py rename to models/creatnet.py index f0ba6f890fe9e65088154cca9913308363e0b10a..eb16fbf8ec7a468044fb0e6931534a6324ffb830 100644 --- a/creatnet.py +++ b/models/creatnet.py @@ -1,6 +1,8 @@ from torch import nn -from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet -from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d +from . import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet, \ +multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d +# from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet +# from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d def CreatNet(opt): name = opt.model_name diff --git a/simple_test.py b/simple_test.py index efe48793bdb140a7c7b2c7cae09eae037c49f46b..de0cf2746b09942eb2745a80e64d310b90e21123 100644 --- a/simple_test.py +++ b/simple_test.py @@ -3,19 +3,16 @@ import numpy as np import torch import matplotlib.pyplot as plt -import util -import transformer -import dataloader -from options import Options -from creatnet import CreatNet +from util import util,transformer,dataloader,statistics,heatmap,options +from models import creatnet ''' --------------------------------preload data-------------------------------- @hypox64 2020/04/03 ''' -opt = Options().getparse() -net = CreatNet(opt) +opt = options.Options().getparse() +net = creatnet.CreatNet(opt) #load data signals = np.load('./datasets/simple_test/signals.npy') diff --git a/train.py b/train.py index 5df13170e375213421790583ff7eed2062d9ad30..be7bd12fb03948c619e10829f5e5bda8e2a11d56 100644 --- a/train.py +++ b/train.py @@ -7,15 +7,10 @@ from torch import nn, optim import warnings warnings.filterwarnings("ignore") -import util -import transformer -import dataloader -import statistics -import heatmap -from creatnet import CreatNet -from options import Options - -opt = Options().getparse() +from util import util,transformer,dataloader,statistics,heatmap,options +from models import creatnet + +opt = options.Options().getparse() torch.cuda.set_device(opt.gpu_id) t1 = time.time() @@ -37,7 +32,7 @@ train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_ t2 = time.time() print('load data cost time: %.2f'% (t2-t1),'s') -net=CreatNet(opt) +net=creatnet.CreatNet(opt) util.writelog('network:\n'+str(net),opt,True) util.show_paramsnumber(net,opt) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataloader.py b/util/dataloader.py similarity index 99% rename from dataloader.py rename to util/dataloader.py index 6906e436db4e87150386cd16ee767953e41fc38c..6866f1611af224000d3c0c159319ccb6e7216e77 100644 --- a/dataloader.py +++ b/util/dataloader.py @@ -5,9 +5,10 @@ import random import scipy.io as sio import numpy as np -import dsp -import transformer -import statistics +from . import dsp,transformer,statistics +# import dsp +# import transformer +# import statistics def trimdata(data,num): diff --git a/dsp.py b/util/dsp.py similarity index 100% rename from dsp.py rename to util/dsp.py diff --git a/heatmap.py b/util/heatmap.py similarity index 100% rename from heatmap.py rename to util/heatmap.py diff --git a/options.py b/util/options.py similarity index 98% rename from options.py rename to util/options.py index 9905cf6b68c10cb774322d047d53b683adfb17de..44559ba33f1e775d65b1cfc992692c304ab0622a 100644 --- a/options.py +++ b/util/options.py @@ -1,7 +1,7 @@ import argparse import os import time -import util +from . import util class Options(): def __init__(self): @@ -25,7 +25,7 @@ class Options(): self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train') self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize') - self.parser.add_argument('--weight_mod', type=str, default='normal',help='Choose weight mode: auto | normal') + self.parser.add_argument('--weight_mod', type=str, default='auto',help='Choose weight mode: auto | normal') self.parser.add_argument('--epochs', type=int, default=20,help='end epoch') self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network') self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.if 0 or 1,no k-fold') diff --git a/statistics.py b/util/statistics.py similarity index 99% rename from statistics.py rename to util/statistics.py index ebabbd1cc62719a8b568ed7993f8ce7be04c5b50..11c0c2b87ac518af72030739c241fb327def3680 100644 --- a/statistics.py +++ b/util/statistics.py @@ -2,7 +2,7 @@ import numpy as np import matplotlib.pyplot as plt import util import os -import heatmap +from . import heatmap def label_statistics(labels): #for sleep label: N3->0 N2->1 N1->2 REM->3 W->4 diff --git a/transformer.py b/util/transformer.py similarity index 99% rename from transformer.py rename to util/transformer.py index 18b860214a41c36eb30e2191f7ab15119b4901ac..7098ad579f857d7849a4ebcb7c885ce66be87178 100644 --- a/transformer.py +++ b/util/transformer.py @@ -2,7 +2,8 @@ import os import random import numpy as np import torch -import dsp +from . import dsp +# import dsp def trimdata(data,num): return data[:num*int(len(data)/num)] diff --git a/util.py b/util/util.py similarity index 100% rename from util.py rename to util/util.py