Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Hypo
candock
提交
085dde45
C
candock
项目概览
Hypo
/
candock
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
candock
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
085dde45
编写于
7月 25, 2020
作者:
H
HypoX64
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add mlp
上级
fdbfced6
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
72 addition
and
14 deletion
+72
-14
.gitignore
.gitignore
+1
-0
models/core.py
models/core.py
+7
-5
models/creatnet.py
models/creatnet.py
+4
-1
models/net_1d/mlp.py
models/net_1d/mlp.py
+22
-0
tools/server.py
tools/server.py
+1
-1
util/dsp.py
util/dsp.py
+27
-4
util/options.py
util/options.py
+2
-2
util/util.py
util/util.py
+8
-1
未找到文件。
.gitignore
浏览文件 @
085dde45
...
...
@@ -138,6 +138,7 @@ checkpoints/
/train_backup.py
/tools/client_data
/tools/server_data
/trainscript.py
*.pth
*.edf
*log*
\ No newline at end of file
models/core.py
浏览文件 @
085dde45
...
...
@@ -42,7 +42,7 @@ class Core(object):
self
.
test_flag
=
True
if
printflag
:
util
.
writelog
(
'network:
\n
'
+
str
(
self
.
net
),
self
.
opt
,
True
)
#
util.writelog('network:\n'+str(self.net),self.opt,True)
show_paramsnumber
(
self
.
net
,
self
.
opt
)
if
self
.
opt
.
pretrained
!=
''
:
...
...
@@ -85,7 +85,8 @@ class Core(object):
self
.
queue
=
Queue
(
self
.
opt
.
load_thread
*
2
)
process_batch_num
=
len
(
sequences
)
//
self
.
opt
.
batchsize
//
self
.
opt
.
load_thread
if
process_batch_num
==
0
:
print
(
'
\033
[1;33m'
+
'Warning: too much load thread'
+
'
\033
[0m'
)
if
self
.
epoch
==
1
:
print
(
'
\033
[1;33m'
+
'Warning: too much load thread'
+
'
\033
[0m'
)
self
.
start_process
(
signals
,
labels
,
sequences
)
else
:
for
i
in
range
(
self
.
opt
.
load_thread
):
...
...
@@ -130,8 +131,8 @@ class Core(object):
loss
.
backward
()
self
.
optimizer
.
step
()
self
.
plot_result
[
'train'
].
append
(
epoch_loss
/
i
)
plot
.
draw_loss
(
self
.
plot_result
,
self
.
epoch
+
i
/
(
sequences
.
shape
[
0
]
/
self
.
opt
.
batchsize
),
self
.
opt
)
self
.
plot_result
[
'train'
].
append
(
epoch_loss
/
(
i
+
1
)
)
plot
.
draw_loss
(
self
.
plot_result
,
self
.
epoch
+
(
i
+
1
)
/
(
sequences
.
shape
[
0
]
/
self
.
opt
.
batchsize
),
self
.
opt
)
# if self.opt.model_name != 'autoencoder':
# plot.draw_heatmap(confusion_mat,self.opt,name = 'current_train')
...
...
@@ -142,6 +143,7 @@ class Core(object):
epoch_loss
=
0
confusion_mat
=
np
.
zeros
((
self
.
opt
.
label
,
self
.
opt
.
label
),
dtype
=
int
)
np
.
random
.
shuffle
(
sequences
)
self
.
process_pool_init
(
signals
,
labels
,
sequences
)
for
i
in
range
(
len
(
sequences
)
//
self
.
opt
.
batchsize
):
signal
,
label
=
self
.
queue
.
get
()
...
...
@@ -160,7 +162,7 @@ class Core(object):
print
(
'epoch:'
+
str
(
self
.
epoch
),
' macro-prec,reca,F1,err,kappa: '
+
str
(
statistics
.
report
(
confusion_mat
)))
self
.
plot_result
[
'F1'
].
append
(
statistics
.
report
(
confusion_mat
)[
2
])
self
.
plot_result
[
'eval'
].
append
(
epoch_loss
/
i
)
self
.
plot_result
[
'eval'
].
append
(
epoch_loss
/
(
i
+
1
)
)
self
.
epoch
+=
1
self
.
confusion_mats
.
append
(
confusion_mat
)
...
...
models/creatnet.py
浏览文件 @
085dde45
from
torch
import
nn
from
.net_1d
import
cnn_1d
,
lstm
,
resnet_1d
,
multi_scale_resnet_1d
,
micro_multi_scale_resnet_1d
,
autoencoder
from
.net_1d
import
cnn_1d
,
lstm
,
resnet_1d
,
multi_scale_resnet_1d
,
micro_multi_scale_resnet_1d
,
autoencoder
,
mlp
from
.net_2d
import
densenet
,
dfcnn
,
mobilenet
,
resnet
,
squeezenet
,
multi_scale_resnet
...
...
@@ -9,6 +9,9 @@ def creatnet(opt):
#encoder
if
name
==
'autoencoder'
:
net
=
autoencoder
.
Autoencoder
(
opt
.
input_nc
,
opt
.
feature
,
opt
.
label
,
opt
.
finesize
)
#mlp
if
name
==
'mlp'
:
net
=
mlp
.
mlp
(
opt
.
input_nc
,
opt
.
label
,
opt
.
finesize
)
#lstm
elif
name
==
'lstm'
:
net
=
lstm
.
lstm
(
opt
.
lstm_inputsize
,
opt
.
lstm_timestep
,
input_nc
=
opt
.
input_nc
,
num_classes
=
opt
.
label
)
...
...
models/net_1d/mlp.py
0 → 100644
浏览文件 @
085dde45
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
class
mlp
(
nn
.
Module
):
def
__init__
(
self
,
input_nc
,
num_classes
,
datasize
):
super
(
mlp
,
self
).
__init__
()
self
.
net
=
nn
.
Sequential
(
nn
.
Linear
(
datasize
*
input_nc
,
128
),
nn
.
Tanh
(),
nn
.
Linear
(
128
,
64
),
nn
.
Tanh
(),
nn
.
Linear
(
64
,
64
),
nn
.
Tanh
(),
nn
.
Linear
(
64
,
num_classes
),
)
def
forward
(
self
,
x
):
x
=
x
.
view
(
x
.
size
(
0
),
-
1
)
x
=
self
.
net
(
x
)
return
x
\ No newline at end of file
tools/server.py
浏览文件 @
085dde45
...
...
@@ -132,4 +132,4 @@ def handlepost():
return
{
'return'
:
'error'
}
app
.
run
(
"0.0.0.0"
,
port
=
4000
,
debug
=
Tru
e
)
app
.
run
(
"0.0.0.0"
,
port
=
4000
,
debug
=
Fals
e
)
util/dsp.py
浏览文件 @
085dde45
import
scipy.signal
import
scipy.fftpack
as
fftpack
import
numpy
as
np
import
pywt
def
sin
(
f
,
fs
,
time
):
x
=
np
.
linspace
(
0
,
2
*
np
.
pi
*
f
*
time
,
fs
*
time
)
...
...
@@ -23,10 +24,32 @@ def medfilt(signal,x):
def
cleanoffset
(
signal
):
return
signal
-
np
.
mean
(
signal
)
def
bpf_fir
(
signal
,
fs
,
fc1
,
fc2
,
numtaps
=
101
):
b
=
scipy
.
signal
.
firwin
(
numtaps
,
[
fc1
,
fc2
],
pass_zero
=
False
,
fs
=
fs
)
result
=
scipy
.
signal
.
lfilter
(
b
,
1
,
signal
)
return
result
def
showfreq
(
signal
,
fs
,
fc
=
0
):
"""
return f,fft
"""
if
fc
==
0
:
kc
=
int
(
len
(
signal
)
/
2
)
else
:
kc
=
int
(
len
(
signal
)
/
fs
*
fc
)
signal_fft
=
np
.
abs
(
scipy
.
fftpack
.
fft
(
signal
))
f
=
np
.
linspace
(
0
,
fs
/
2
,
num
=
int
(
len
(
signal_fft
)
/
2
))
return
f
[:
kc
],
signal_fft
[
0
:
int
(
len
(
signal_fft
)
/
2
)][:
kc
]
def
bpf
(
signal
,
fs
,
fc1
,
fc2
,
numtaps
=
3
,
mode
=
'iir'
):
if
mode
==
'iir'
:
b
,
a
=
scipy
.
signal
.
iirfilter
(
numtaps
,
[
fc1
,
fc2
],
fs
=
fs
)
elif
mode
==
'fir'
:
b
=
scipy
.
signal
.
firwin
(
numtaps
,
[
fc1
,
fc2
],
pass_zero
=
False
,
fs
=
fs
)
a
=
1
return
scipy
.
signal
.
lfilter
(
b
,
a
,
signal
)
def
wave_filter
(
signal
,
wave
,
level
,
usedcoeffs
):
coeffs
=
pywt
.
wavedec
(
signal
,
wave
,
level
=
level
)
for
i
in
range
(
len
(
usedcoeffs
)):
if
usedcoeffs
[
i
]
==
0
:
coeffs
[
i
]
=
np
.
zeros_like
(
coeffs
[
i
])
return
pywt
.
waverec
(
coeffs
,
wave
,
mode
=
'symmetric'
,
axis
=-
1
)
def
fft_filter
(
signal
,
fs
,
fc
=
[],
type
=
'bandpass'
):
'''
...
...
util/options.py
浏览文件 @
085dde45
...
...
@@ -48,7 +48,7 @@ class Options():
# ------------------------Network------------------------
"""Available Network
1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
micro_multi_scale_resnet_1d,autoencoder
micro_multi_scale_resnet_1d,autoencoder
,mlp
2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
densenet121, densenet201, squeezenet
"""
...
...
@@ -100,7 +100,7 @@ class Options():
if
self
.
opt
.
model_type
==
'auto'
:
if
self
.
opt
.
model_name
in
[
'lstm'
,
'cnn_1d'
,
'resnet18_1d'
,
'resnet34_1d'
,
'multi_scale_resnet_1d'
,
'micro_multi_scale_resnet_1d'
,
'autoencoder'
]:
'multi_scale_resnet_1d'
,
'micro_multi_scale_resnet_1d'
,
'autoencoder'
,
'mlp'
]:
self
.
opt
.
model_type
=
'1d'
elif
self
.
opt
.
model_name
in
[
'dfcnn'
,
'multi_scale_resnet'
,
'resnet18'
,
'resnet50'
,
'resnet101'
,
'densenet121'
,
'densenet201'
,
'squeezenet'
,
'mobilenet'
]:
...
...
util/util.py
浏览文件 @
085dde45
import
os
import
string
import
random
import
shutil
def
randomstr
(
num
):
return
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
num
))
...
...
@@ -38,4 +39,10 @@ def loadfile(path):
def
savefile
(
file
,
path
):
wf
=
open
(
path
,
'wb'
)
wf
.
write
(
file
)
wf
.
close
()
\ No newline at end of file
wf
.
close
()
def
copyfile
(
src
,
dst
):
try
:
shutil
.
copyfile
(
src
,
dst
)
except
Exception
as
e
:
print
(
e
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录