Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Hypo
candock
提交
5436233f
C
candock
项目概览
Hypo
/
candock
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
candock
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
5436233f
编写于
5月 06, 2020
作者:
H
hypox64
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify file structure
上级
0f3b076f
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
25 addition
and
28 deletion
+25
-28
README.md
README.md
+3
-2
docs/confusion_mat
docs/confusion_mat
+0
-0
models/creatnet.py
models/creatnet.py
+4
-2
simple_test.py
simple_test.py
+4
-7
train.py
train.py
+5
-10
util/__init__.py
util/__init__.py
+0
-0
util/dataloader.py
util/dataloader.py
+4
-3
util/dsp.py
util/dsp.py
+0
-0
util/heatmap.py
util/heatmap.py
+0
-0
util/options.py
util/options.py
+2
-2
util/statistics.py
util/statistics.py
+1
-1
util/transformer.py
util/transformer.py
+2
-1
util/util.py
util/util.py
+0
-0
未找到文件。
README.md
浏览文件 @
5436233f
...
...
@@ -40,7 +40,7 @@ cd candock
python3 train.py
--label
50
--input_nc
1
--dataset_dir
./datasets/simple_test
--save_dir
./checkpoints/simple_test
--model_name
micro_multi_scale_resnet_1d
--gpu_id
0
--batchsize
64
--k_fold
5
# if you want to use cpu to train, please input --no_cuda
```
*
More
[
options
](
./options.py
)
.
*
More
[
options
](
./
util/
options.py
)
.
#### Use your own data to train
*
step1: Generate signals.npy and labels.npy in the following format.
```
python
...
...
@@ -56,4 +56,5 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
```
bash
python3 simple_test.py
--label
50
--input_nc
1
--model_name
micro_multi_scale_resnet_1d
--gpu_id
0
# if you want to use cpu to test, please input --no_cuda
```
\ No newline at end of file
```
confusion_mat
→
docs/
confusion_mat
浏览文件 @
5436233f
文件已移动
creatnet.py
→
models/
creatnet.py
浏览文件 @
5436233f
from
torch
import
nn
from
models
import
cnn_1d
,
densenet
,
dfcnn
,
lstm
,
mobilenet
,
resnet
,
resnet_1d
,
squeezenet
from
models
import
multi_scale_resnet
,
multi_scale_resnet_1d
,
micro_multi_scale_resnet_1d
from
.
import
cnn_1d
,
densenet
,
dfcnn
,
lstm
,
mobilenet
,
resnet
,
resnet_1d
,
squeezenet
,
\
multi_scale_resnet
,
multi_scale_resnet_1d
,
micro_multi_scale_resnet_1d
# from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet
# from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d
def
CreatNet
(
opt
):
name
=
opt
.
model_name
...
...
simple_test.py
浏览文件 @
5436233f
...
...
@@ -3,19 +3,16 @@ import numpy as np
import
torch
import
matplotlib.pyplot
as
plt
import
util
import
transformer
import
dataloader
from
options
import
Options
from
creatnet
import
CreatNet
from
util
import
util
,
transformer
,
dataloader
,
statistics
,
heatmap
,
options
from
models
import
creatnet
'''
--------------------------------preload data--------------------------------
@hypox64
2020/04/03
'''
opt
=
Options
().
getparse
()
net
=
CreatNet
(
opt
)
opt
=
options
.
Options
().
getparse
()
net
=
creatnet
.
CreatNet
(
opt
)
#load data
signals
=
np
.
load
(
'./datasets/simple_test/signals.npy'
)
...
...
train.py
浏览文件 @
5436233f
...
...
@@ -7,15 +7,10 @@ from torch import nn, optim
import
warnings
warnings
.
filterwarnings
(
"ignore"
)
import
util
import
transformer
import
dataloader
import
statistics
import
heatmap
from
creatnet
import
CreatNet
from
options
import
Options
opt
=
Options
().
getparse
()
from
util
import
util
,
transformer
,
dataloader
,
statistics
,
heatmap
,
options
from
models
import
creatnet
opt
=
options
.
Options
().
getparse
()
torch
.
cuda
.
set_device
(
opt
.
gpu_id
)
t1
=
time
.
time
()
...
...
@@ -37,7 +32,7 @@ train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_
t2
=
time
.
time
()
print
(
'load data cost time: %.2f'
%
(
t2
-
t1
),
's'
)
net
=
CreatNet
(
opt
)
net
=
creatnet
.
CreatNet
(
opt
)
util
.
writelog
(
'network:
\n
'
+
str
(
net
),
opt
,
True
)
util
.
show_paramsnumber
(
net
,
opt
)
...
...
util/__init__.py
0 → 100644
浏览文件 @
5436233f
dataloader.py
→
util/
dataloader.py
浏览文件 @
5436233f
...
...
@@ -5,9 +5,10 @@ import random
import
scipy.io
as
sio
import
numpy
as
np
import
dsp
import
transformer
import
statistics
from
.
import
dsp
,
transformer
,
statistics
# import dsp
# import transformer
# import statistics
def
trimdata
(
data
,
num
):
...
...
dsp.py
→
util/
dsp.py
浏览文件 @
5436233f
文件已移动
heatmap.py
→
util/
heatmap.py
浏览文件 @
5436233f
文件已移动
options.py
→
util/
options.py
浏览文件 @
5436233f
import
argparse
import
os
import
time
import
util
from
.
import
util
class
Options
():
def
__init__
(
self
):
...
...
@@ -25,7 +25,7 @@ class Options():
self
.
parser
.
add_argument
(
'--continue_train'
,
action
=
'store_true'
,
help
=
'if specified, continue train'
)
self
.
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.001
,
help
=
'learning rate'
)
self
.
parser
.
add_argument
(
'--batchsize'
,
type
=
int
,
default
=
64
,
help
=
'batchsize'
)
self
.
parser
.
add_argument
(
'--weight_mod'
,
type
=
str
,
default
=
'
normal
'
,
help
=
'Choose weight mode: auto | normal'
)
self
.
parser
.
add_argument
(
'--weight_mod'
,
type
=
str
,
default
=
'
auto
'
,
help
=
'Choose weight mode: auto | normal'
)
self
.
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
default
=
20
,
help
=
'end epoch'
)
self
.
parser
.
add_argument
(
'--network_save_freq'
,
type
=
int
,
default
=
5
,
help
=
'the freq to save network'
)
self
.
parser
.
add_argument
(
'--k_fold'
,
type
=
int
,
default
=
0
,
help
=
'fold_num of k-fold.if 0 or 1,no k-fold'
)
...
...
statistics.py
→
util/
statistics.py
浏览文件 @
5436233f
...
...
@@ -2,7 +2,7 @@ import numpy as np
import
matplotlib.pyplot
as
plt
import
util
import
os
import
heatmap
from
.
import
heatmap
def
label_statistics
(
labels
):
#for sleep label: N3->0 N2->1 N1->2 REM->3 W->4
...
...
transformer.py
→
util/
transformer.py
浏览文件 @
5436233f
...
...
@@ -2,7 +2,8 @@ import os
import
random
import
numpy
as
np
import
torch
import
dsp
from
.
import
dsp
# import dsp
def
trimdata
(
data
,
num
):
return
data
[:
num
*
int
(
len
(
data
)
/
num
)]
...
...
util.py
→
util
/util
.py
浏览文件 @
5436233f
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录