Hypo / candock
Commit 7ecfcb78
Authored May 29, 2020 by hypox64

Allow separated data

Parent: f495e3a1
Showing 5 changed files with 177 additions and 257 deletions.
server.py            +79  −29
train.py             +24  −12
util/dataloader.py   +54  −185
util/options.py      +3   −0
util/transformer.py  +17  −31
server.py

```diff
@@ -5,6 +5,7 @@ import numpy as np
 import random
 import torch
 from torch import nn,optim
 import matplotlib.pyplot as plt
+import warnings
 from util import util,transformer,dataloader,statistics,plot,options
 
@@ -19,10 +20,10 @@ opt.k_fold = 0
 opt.save_dir = './datasets/server/tmp'
 util.makedirs(opt.save_dir)
 
-'''load ori data'''
-signals,labels = dataloader.loaddataset(opt)
-label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
-opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
+# use separated mode
+signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
+label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
+opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
 
 '''def network'''
 core = core.Core(opt)
 core.network_init(printflag=True)
@@ -34,54 +35,103 @@ os.system('unzip ./datasets/server/data.zip -d ./datasets/server/')
 categorys = os.listdir('./datasets/server/data')
 categorys.sort()
 print('categorys:',categorys)
-receive_category = len(categorys)
-received_signals = []
-received_labels = []
-for i in range(receive_category):
-    samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
-    for sample in samples:
-        txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],sample))
+category_num = len(categorys)
+# received_signals_train = [];received_labels_train = []
+# received_signals_eval = [];received_labels_eval = []
+# sample_num = 1000
+# eval_num = 1
+# for i in range(category_num):
+#     samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
+#     for j in range(len(samples)):
+#         txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
+#         #print(os.path.join('./datasets/server/data',categorys[i],sample))
+#         txt_split = txt.split()
+#         signal_ori = np.zeros(len(txt_split))
+#         for point in range(len(txt_split)):
+#             signal_ori[point] = float(txt_split[point])
+#         for x in range(sample_num//len(samples)):
+#             ran = random.randint(1000, len(signal_ori)-2000-1)
+#             this_signal = signal_ori[ran:ran+2000]
+#             this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
+#             # if i ==0:
+#             #     plt.plot(this_signal)
+#             #     plt.show()
+#             if j < (len(samples)-eval_num):
+#                 received_signals_train.append(this_signal)
+#                 received_labels_train.append(i)
+#             else:
+#                 received_signals_eval.append(this_signal)
+#                 received_labels_eval.append(i)
+# received_signals_train = np.array(received_signals_train).reshape(-1,opt.input_nc,opt.loadsize)
+# received_labels_train = np.array(received_labels_train).reshape(-1,1)
+# received_signals_eval = np.array(received_signals_eval).reshape(-1,opt.input_nc,opt.loadsize)
+# received_labels_eval = np.array(received_labels_eval).reshape(-1,1)
+#print(received_signals_train.shape,received_signals_eval.shape)
+
+received_signals = [];received_labels = []
+sample_num = 1000
+eval_num = 1
+for i in range(category_num):
+    samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
+    random.shuffle(samples)
+    for j in range(len(samples)):
+        txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
         #print(os.path.join('./datasets/server/data',categorys[i],sample))
         txt_split = txt.split()
         signal_ori = np.zeros(len(txt_split))
         for point in range(len(txt_split)):
             signal_ori[point] = float(txt_split[point])
-        # #just cut
-        # for j in range(1,len(signal_ori)//opt.loadsize-1):
-        #     this_signal = signal_ori[j*opt.loadsize:(j+1)*opt.loadsize]
-        #     this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
-        #     received_signals.append(this_signal)
-        #     received_labels.append(i)
-        #random cut
-        for j in range(500//len(samples)-1):
+        for x in range(sample_num//len(samples)):
             ran = random.randint(1000, len(signal_ori)-2000-1)
             this_signal = signal_ori[ran:ran+2000]
             this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
             received_signals.append(this_signal)
             received_labels.append(i)
 
 received_signals = np.array(received_signals).reshape(-1,opt.input_nc,opt.loadsize)
 received_labels = np.array(received_labels).reshape(-1,1)
+received_signals_train,received_labels_train,received_signals_eval,received_labels_eval = \
+    dataloader.segment_dataset(received_signals,received_labels,0.8,random=False)
+print(received_signals_train.shape,received_signals_eval.shape)
 # print(labels)
 
 '''merge data'''
-signals = signals[receive_category*500:]
-labels = labels[receive_category*500:]
-signals = np.concatenate((signals,received_signals))
-labels = np.concatenate((labels,received_labels))
-transformer.shuffledata(signals,labels)
-
-label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
-opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
-train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
+signals_train,labels_train = dataloader.del_labels(signals_train,labels_train,np.linspace(0,category_num-1,category_num,dtype=np.int64))
+signals_eval,labels_eval = dataloader.del_labels(signals_eval,labels_eval,np.linspace(0,category_num-1,category_num,dtype=np.int64))
+signals_train = np.concatenate((signals_train,received_signals_train))
+labels_train = np.concatenate((labels_train,received_labels_train))
+signals_eval = np.concatenate((signals_eval,received_signals_eval))
+labels_eval = np.concatenate((labels_eval,received_labels_eval))
+
+label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
+opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
+train_sequences = transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
+eval_sequences = transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
 
 for epoch in range(opt.epochs):
     t1 = time.time()
-    core.train(signals,labels,train_sequences[0])
-    core.eval(signals,labels,test_sequences[0])
+    if opt.separated:
+        #print(signals_train.shape,labels_train.shape)
+        core.train(signals_train,labels_train,train_sequences)
+        core.eval(signals_eval,labels_eval,eval_sequences)
+    else:
+        core.train(signals,labels,train_sequences[fold])
+        core.eval(signals,labels,eval_sequences[fold])
    t2 = time.time()
     if epoch+1==1:
         util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
 
 plot.draw_heatmap(core.confusion_mats[-1],opt,name='final')
 core.save_traced_net()
```
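Two details worth noting in this file: the merge step now uses `dataloader.del_labels` to strip the locally collected categories out of the pre-split dataset before concatenating the freshly received windows, and the `else` branch of the epoch loop references `fold`, which server.py never defines (there is no k-fold loop here), so only the `opt.separated` path is actually runnable. For reference, a minimal, self-contained sketch of the random-cut windowing the loop performs; plain NumPy percentile scaling stands in for `arr.normliaze(this_signal,'5_95',truncated=4)`, so the normalization here is an assumption rather than the project's exact helper:

```python
import random
import numpy as np

def random_cut(signal_ori, n_windows, win_len=2000, margin=1000):
    """Cut n_windows windows of win_len points at random offsets,
    skipping the first `margin` points, as the loop above does.
    Requires len(signal_ori) >= margin + win_len + 1."""
    windows = []
    for _ in range(n_windows):
        ran = random.randint(margin, len(signal_ori) - win_len - 1)
        win = signal_ori[ran:ran + win_len].astype(np.float64)
        # stand-in for arr.normliaze(win,'5_95',truncated=4):
        # scale by the 5th..95th percentile range, then clip outliers
        lo, hi = np.percentile(win, 5), np.percentile(win, 95)
        win = np.clip((win - np.median(win)) / (hi - lo + 1e-8), -4, 4)
        windows.append(win)
    return np.array(windows)
```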
train.py

```diff
@@ -24,11 +24,21 @@ signals = np.zeros((10,1,10),dtype='np.float64')
 labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
 * step2: input ```--dataset_dir your_dataset_dir``` when running code.
 '''
-signals,labels = dataloader.loaddataset(opt)
-label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
-util.writelog('label statistics: '+str(label_cnt),opt,True)
-opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
-train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
+#----------------------------Load Data----------------------------
+if opt.separated:
+    signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
+    label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
+    util.writelog('label statistics: '+str(label_cnt),opt,True)
+    opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
+    train_sequences = transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
+    eval_sequences = transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
+else:
+    signals,labels = dataloader.loaddataset(opt)
+    label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
+    util.writelog('label statistics: '+str(label_cnt),opt,True)
+    opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
+    train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
 
 t2 = time.time()
 print('load data cost time: %.2f'%(t2-t1),'s')
@@ -40,17 +50,19 @@ fold_final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
 for fold in range(opt.k_fold):
     if opt.k_fold != 1:
         util.writelog('------------------------------ k-fold:'+str(fold+1)+' ------------------------------',opt,True)
     core.network_init()
     final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
-    # confusion_mats = []
     for epoch in range(opt.epochs):
         t1 = time.time()
-        core.train(signals,labels,train_sequences[fold])
-        core.eval(signals,labels,eval_sequences[fold])
-        # confusion_mats.append(confusion_mat_eval)
+        if opt.separated:
+            #print(signals_train.shape,labels_train.shape)
+            core.train(signals_train,labels_train,train_sequences)
+            core.eval(signals_eval,labels_eval,eval_sequences)
+        else:
+            core.train(signals,labels,train_sequences[fold])
+            core.eval(signals,labels,eval_sequences[fold])
+        core.save()
         t2 = time.time()
         if epoch+1==1:
             util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
```
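For a pre-split dataset, the separated branch expects four `.npy` files side by side. A sketch of preparing them (the path `./datasets/mydata` is illustrative; the file names come from the `preload` branch of util/dataloader.py below, and the array shapes follow the docstring example above):

```python
import os
import numpy as np

os.makedirs('./datasets/mydata', exist_ok=True)  # illustrative path

# Shapes follow the docstring example: signals (N, input_nc, loadsize)
signals = np.random.rand(10, 1, 10).astype(np.float64)
labels = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

# A fixed 8:2 train/eval split; with opt.separated and dataset_name
# 'preload', loaddataset(opt) returns these four arrays as-is.
np.save('./datasets/mydata/signals_train.npy', signals[:8])
np.save('./datasets/mydata/labels_train.npy', labels[:8])
np.save('./datasets/mydata/signals_eval.npy', signals[8:])
np.save('./datasets/mydata/labels_eval.npy', labels[8:])
```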
util/dataloader.py

```diff
@@ -6,16 +6,55 @@ import scipy.io as sio
 import numpy as np
 from . import dsp,transformer,statistics
 # import dsp
 # import transformer
 # import statistics
 
 def trimdata(data,num):
     return data[:num*int(len(data)/num)]
 
+def del_labels(signals,labels,dels):
+    del_index = []
+    for i in range(len(labels)):
+        if labels[i] in dels:
+            del_index.append(i)
+    del_index = np.array(del_index)
+    signals = np.delete(signals,del_index,axis=0)
+    labels = np.delete(labels,del_index,axis=0)
+    return signals,labels
+
+# def sortbylabel(signals,labels):
+#     signals
+
+def segment_dataset(signals,labels,a=0.8,random=True):
+    length = len(labels)
+    if random:
+        transformer.shuffledata(signals,labels)
+        signals_train = signals[:int(a*length)]
+        labels_train = labels[:int(a*length)]
+        signals_eval = signals[int(a*length):]
+        labels_eval = labels[int(a*length):]
+    else:
+        label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
+        #signals_train=[];labels_train=[];signals_eval=[];labels_eval=[]
+        # cnt_ori = 0
+        # signals_tmp=np.zeros_like(signals)
+        # labels_tmp=np.zeros_like(labels)
+        cnt = 0
+        for i in range(label_num):
+            if i == 0:
+                signals_train = signals[cnt:cnt+int(label_cnt[i]*0.8)]
+                labels_train = labels[cnt:cnt+int(label_cnt[i]*0.8)]
+                signals_eval = signals[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]
+                labels_eval = labels[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]
+            else:
+                signals_train = np.concatenate((signals_train,signals[cnt:cnt+int(label_cnt[i]*0.8)]))
+                labels_train = np.concatenate((labels_train,labels[cnt:cnt+int(label_cnt[i]*0.8)]))
+                signals_eval = np.concatenate((signals_eval,signals[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]))
+                labels_eval = np.concatenate((labels_eval,labels[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]))
+            cnt += label_cnt[i]
+    return signals_train,labels_train,signals_eval,labels_eval
+
 def reducesample(data,mult):
     return data[::mult]
 
 def balance_label(signals,labels):
@@ -51,192 +90,22 @@ def balance_label(signals,labels):
     return new_signals,new_labels
 
-# delete uesless label
-def del_UND(signals,stages):
-    stages_copy = stages.copy()
-    cnt = 0
-    for i in range(len(stages_copy)):
-        if stages_copy[i] == 5:
-            signals = np.delete(signals,i-cnt,axis=0)
-            stages = np.delete(stages,i-cnt,axis=0)
-            cnt += 1
-    return signals,stages
-
-def connectdata(signal,stage,signals=[],stages=[]):
-    if signals == []:
-        signals = signal.copy()
-        stages = stage.copy()
-    else:
-        signals = np.concatenate((signals,signal),axis=0)
-        stages = np.concatenate((stages,stage),axis=0)
-    return signals,stages
-
-#load one subject data form cc2018
-def loaddata_cc2018(filedir,filename,signal_name,BID,filter=True):
-    dirpath = os.path.join(filedir,filename)
-    #load signal
-    hea_path = os.path.join(dirpath,os.path.basename(dirpath)+'.hea')
-    signal_path = os.path.join(dirpath,os.path.basename(dirpath)+'.mat')
-    signal_names = []
-    for i,line in enumerate(open(hea_path),0):
-        if i != 0:
-            line = line.strip()
-            signal_names.append(line.split()[8])
-    mat = sio.loadmat(signal_path)
-    signals = mat['val'][signal_names.index(signal_name)]
-    if filter:
-        signals = dsp.BPF(signals,200,0.2,50,mod='fir')
-    #load stage
-    stagepath = os.path.join(dirpath,os.path.basename(dirpath)+'-arousal.mat')
-    mat = h5py.File(stagepath,'r')
-    # N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4 UND->5
-    N3 = mat['data']['sleep_stages']['nonrem3'][0]
-    N2 = mat['data']['sleep_stages']['nonrem2'][0]
-    N1 = mat['data']['sleep_stages']['nonrem1'][0]
-    REM = mat['data']['sleep_stages']['rem'][0]
-    W = mat['data']['sleep_stages']['wake'][0]
-    UND = mat['data']['sleep_stages']['undefined'][0]
-    stages = N3*0+N2*1+N1*2+REM*3+W*4+UND*5
-    #resample
-    signals = reducesample(signals,2)
-    stages = reducesample(stages,2)
-    #trim
-    signals = trimdata(signals,3000)
-    stages = trimdata(stages,3000)
-    #30s per label
-    signals = signals.reshape(-1,3000)
-    stages = stages[::3000]
-    #Balance individualized differences
-    signals = transformer.Balance_individualized_differences(signals,BID)
-    #del UND
-    signals,stages = del_UND(signals,stages)
-    return signals.astype(np.float16),stages.astype(np.int16)
-
-#load one subject data form sleep-edfx
-def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
-    filenum = filename[2:6]
-    filenames = os.listdir(filedir)
-    for filename in filenames:
-        if str(filenum) in filename and 'Hypnogram' in filename:
-            f_stage_name = filename
-        if str(filenum) in filename and 'PSG' in filename:
-            f_signal_name = filename
-    raw_data = mne.io.read_raw_edf(os.path.join(filedir,f_signal_name),preload=True)
-    raw_annot = mne.read_annotations(os.path.join(filedir,f_stage_name))
-    eeg = raw_data.pick_channels([signal_name]).to_data_frame().values.T
-    eeg = eeg.reshape(-1)
-    raw_data.set_annotations(raw_annot, emit_warning=False)
-    #N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4 other->UND->5
-    event_id = {'Sleep stage 4': 0, 'Sleep stage 3': 0, 'Sleep stage 2': 1, 'Sleep stage 1': 2, 'Sleep stage R': 3, 'Sleep stage W': 4, 'Sleep stage ?': 5, 'Movement time': 5}
-    events, _ = mne.events_from_annotations(raw_data, event_id=event_id, chunk_duration=30.)
-    stages = [];signals = []
-    for i in range(len(events)-1):
-        stages.append(events[i][2])
-        signals.append(eeg[events[i][0]:events[i][0]+3000])
-    stages = np.array(stages)
-    signals = np.array(signals)
-    # #select sleep time
-    if select_sleep_time:
-        if 'SC' in f_signal_name:
-            signals = signals[np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999):int(raw_annot[-2]['onset'])//30+60]
-            stages = stages[np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999):int(raw_annot[-2]['onset'])//30+60]
-    signals,stages = del_UND(signals,stages)
-    print('shape:',signals.shape,stages.shape)
-    signals = transformer.Balance_individualized_differences(signals,BID)
-    return signals.astype(np.float16),stages.astype(np.int16)
-
 #load all data in datasets
 def loaddataset(opt,shuffle=False):
-    filedir = opt.dataset_dir
-    dataset_name = opt.dataset_name
-    signal_name = opt.signal_name
-    num = opt.sample_num
-    BID = opt.BID
-    select_sleep_time = opt.select_sleep_time
     print('load dataset, please wait...')
-    signals_train=[];labels_train=[];signals_test=[];labels_test=[]
-    if dataset_name == 'cc2018':
-        import h5py
-        filenames = os.listdir(filedir)
-        if not opt.no_shuffle:
-            random.shuffle(filenames)
-        else:
-            filenames.sort()
-        if num > len(filenames):
-            num = len(filenames)
-        print('num of dataset is:',num)
-        for cnt,filename in enumerate(filenames[:num],0):
-            signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID=BID)
-            if cnt < round(num*0.8):
-                signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
-            else:
-                signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
-        print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2))
-    elif dataset_name == 'sleep-edfx':
-        import mne
-        if num > 197:
-            num = 197
-        filenames_sc_train = ['SC4001E0-PSG.edf','SC4002E0-PSG.edf','SC4011E0-PSG.edf','SC4012E0-PSG.edf','SC4021E0-PSG.edf','SC4022E0-PSG.edf','SC4031E0-PSG.edf','SC4032E0-PSG.edf','SC4041E0-PSG.edf','SC4042E0-PSG.edf','SC4051E0-PSG.edf','SC4052E0-PSG.edf','SC4061E0-PSG.edf','SC4062E0-PSG.edf','SC4071E0-PSG.edf','SC4072E0-PSG.edf','SC4081E0-PSG.edf','SC4082E0-PSG.edf','SC4091E0-PSG.edf','SC4092E0-PSG.edf','SC4101E0-PSG.edf','SC4102E0-PSG.edf','SC4111E0-PSG.edf','SC4112E0-PSG.edf','SC4121E0-PSG.edf','SC4122E0-PSG.edf','SC4131E0-PSG.edf','SC4141E0-PSG.edf','SC4142E0-PSG.edf','SC4151E0-PSG.edf','SC4152E0-PSG.edf','SC4161E0-PSG.edf','SC4162E0-PSG.edf','SC4171E0-PSG.edf','SC4172E0-PSG.edf','SC4181E0-PSG.edf','SC4182E0-PSG.edf','SC4191E0-PSG.edf','SC4192E0-PSG.edf','SC4201E0-PSG.edf','SC4202E0-PSG.edf','SC4211E0-PSG.edf','SC4212E0-PSG.edf','SC4221E0-PSG.edf','SC4222E0-PSG.edf','SC4231E0-PSG.edf','SC4232E0-PSG.edf','SC4241E0-PSG.edf','SC4242E0-PSG.edf','SC4251E0-PSG.edf','SC4252E0-PSG.edf','SC4261F0-PSG.edf','SC4262F0-PSG.edf','SC4271F0-PSG.edf','SC4272F0-PSG.edf','SC4281G0-PSG.edf','SC4282G0-PSG.edf','SC4291G0-PSG.edf','SC4292G0-PSG.edf','SC4301E0-PSG.edf','SC4302E0-PSG.edf','SC4311E0-PSG.edf','SC4312E0-PSG.edf','SC4321E0-PSG.edf','SC4322E0-PSG.edf','SC4331F0-PSG.edf','SC4332F0-PSG.edf','SC4341F0-PSG.edf','SC4342F0-PSG.edf','SC4351F0-PSG.edf','SC4352F0-PSG.edf','SC4362F0-PSG.edf','SC4371F0-PSG.edf','SC4372F0-PSG.edf','SC4381F0-PSG.edf','SC4382F0-PSG.edf','SC4401E0-PSG.edf','SC4402E0-PSG.edf','SC4411E0-PSG.edf','SC4412E0-PSG.edf','SC4421E0-PSG.edf','SC4422E0-PSG.edf','SC4431E0-PSG.edf','SC4432E0-PSG.edf','SC4441E0-PSG.edf','SC4442E0-PSG.edf','SC4451F0-PSG.edf','SC4452F0-PSG.edf','SC4461F0-PSG.edf','SC4462F0-PSG.edf','SC4471F0-PSG.edf','SC4472F0-PSG.edf','SC4481F0-PSG.edf','SC4482F0-PSG.edf','SC4491G0-PSG.edf','SC4492G0-PSG.edf','SC4501E0-PSG.edf','SC4502E0-PSG.edf','SC4511E0-PSG.edf','SC4512E0-PSG.edf','SC4522E0-PSG.edf','SC4531E0-PSG.edf','SC4532E0-PSG.edf','SC4541F0-PSG.edf','SC4542F0-PSG.edf','SC4551F0-PSG.edf','SC4552F0-PSG.edf','SC4561F0-PSG.edf','SC4562F0-PSG.edf','SC4571F0-PSG.edf','SC4572F0-PSG.edf','SC4581G0-PSG.edf','SC4582G0-PSG.edf','SC4591G0-PSG.edf','SC4592G0-PSG.edf','SC4601E0-PSG.edf','SC4602E0-PSG.edf','SC4611E0-PSG.edf','SC4612E0-PSG.edf','SC4621E0-PSG.edf','SC4622E0-PSG.edf','SC4631E0-PSG.edf','SC4632E0-PSG.edf']
-        filenames_sc_test = ['SC4641E0-PSG.edf','SC4642E0-PSG.edf','SC4651E0-PSG.edf','SC4652E0-PSG.edf','SC4661E0-PSG.edf','SC4662E0-PSG.edf','SC4671G0-PSG.edf','SC4672G0-PSG.edf','SC4701E0-PSG.edf','SC4702E0-PSG.edf','SC4711E0-PSG.edf','SC4712E0-PSG.edf','SC4721E0-PSG.edf','SC4722E0-PSG.edf','SC4731E0-PSG.edf','SC4732E0-PSG.edf','SC4741E0-PSG.edf','SC4742E0-PSG.edf','SC4751E0-PSG.edf','SC4752E0-PSG.edf','SC4761E0-PSG.edf','SC4762E0-PSG.edf','SC4771G0-PSG.edf','SC4772G0-PSG.edf','SC4801G0-PSG.edf','SC4802G0-PSG.edf','SC4811G0-PSG.edf','SC4812G0-PSG.edf','SC4821G0-PSG.edf','SC4822G0-PSG.edf']
-        filenames_st_train = ['ST7011J0-PSG.edf','ST7012J0-PSG.edf','ST7021J0-PSG.edf','ST7022J0-PSG.edf','ST7041J0-PSG.edf','ST7042J0-PSG.edf','ST7051J0-PSG.edf','ST7052J0-PSG.edf','ST7061J0-PSG.edf','ST7062J0-PSG.edf','ST7071J0-PSG.edf','ST7072J0-PSG.edf','ST7081J0-PSG.edf','ST7082J0-PSG.edf','ST7091J0-PSG.edf','ST7092J0-PSG.edf','ST7101J0-PSG.edf','ST7102J0-PSG.edf','ST7111J0-PSG.edf','ST7112J0-PSG.edf','ST7121J0-PSG.edf','ST7122J0-PSG.edf','ST7131J0-PSG.edf','ST7132J0-PSG.edf','ST7141J0-PSG.edf','ST7142J0-PSG.edf','ST7151J0-PSG.edf','ST7152J0-PSG.edf','ST7161J0-PSG.edf','ST7162J0-PSG.edf','ST7171J0-PSG.edf','ST7172J0-PSG.edf','ST7181J0-PSG.edf','ST7182J0-PSG.edf','ST7191J0-PSG.edf','ST7192J0-PSG.edf']
-        filenames_st_test = ['ST7201J0-PSG.edf','ST7202J0-PSG.edf','ST7211J0-PSG.edf','ST7212J0-PSG.edf','ST7221J0-PSG.edf','ST7222J0-PSG.edf','ST7241J0-PSG.edf','ST7242J0-PSG.edf']
-        for filename in filenames_sc_train[:round(num*153/197*0.8)]:
-            signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
-            signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
-        for filename in filenames_st_train[:round(num*44/197*0.8)]:
-            signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
-            signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
-        for filename in filenames_sc_test[:round(num*153/197*0.2)]:
-            signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
-            signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
-        for filename in filenames_st_test[:round(num*44/197*0.2)]:
-            signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
-            signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
-        print('---------Each subject has two sample---------','\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8),'\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2))
-    elif dataset_name == 'preload':
+    if opt.dataset_name == 'preload':
         if opt.separated:
-            signals_train = np.load(filedir+'/signals_train.npy')
-            labels_train = np.load(filedir+'/labels_train.npy')
-            signals_test = np.load(filedir+'/signals_test.npy')
-            labels_test = np.load(filedir+'/labels_test.npy')
+            signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
+            labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
+            signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
+            labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
         else:
-            signals = np.load(filedir+'/signals.npy')
-            labels = np.load(filedir+'/labels.npy')
+            signals = np.load(opt.dataset_dir+'/signals.npy')
+            labels = np.load(opt.dataset_dir+'/labels.npy')
             if not opt.no_shuffle:
                 transformer.shuffledata(signals,labels)
     if opt.separated:
-        return signals_train,labels_train,signals_test,labels_test
+        return signals_train,labels_train,signals_eval,labels_eval
     else:
        return signals,labels
\ No newline at end of file
```
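A toy run of the new `del_labels` helper, with invented values for illustration (run from the repo root so the `util` package imports resolve):

```python
import numpy as np
from util import dataloader

signals = np.arange(12).reshape(6, 1, 2)
labels = np.array([0, 1, 2, 0, 1, 2])

# drop every window whose label appears in dels
signals, labels = dataloader.del_labels(signals, labels, dels=[0, 1])
print(signals.shape, labels)  # (2, 1, 2) [2 2]
```

Note that the `random=False` branch of `segment_dataset` slices contiguous runs of `label_cnt[i]` windows per class, so it assumes the input is already grouped by label in order, which is exactly how server.py builds `received_signals` above.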
util/options.py

```diff
@@ -85,6 +85,9 @@ class Options():
         if self.opt.k_fold == 0:
             self.opt.k_fold = 1
 
+        if self.opt.separated:
+            self.opt.k_fold = 1
+
         self.opt.mergelabel = eval(self.opt.mergelabel)
         if self.opt.mergelabel_name != 'None':
             self.opt.mergelabel_name = self.opt.mergelabel_name.replace(" ","").split(",")
```
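Forcing `k_fold = 1` whenever `separated` is set means a user-supplied train/eval split bypasses cross-validation entirely; `k_fold_generator` (next file) then returns a single flat index sequence per split instead of per-fold index matrices.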
util/transformer.py

```diff
@@ -6,9 +6,6 @@ from . import dsp
 from . import array_operation as arr
 # import dsp
 
-def trimdata(data,num):
-    return data[:num*int(len(data)/num)]
-
 def shuffledata(data,target):
     state = np.random.get_state()
     np.random.shuffle(data)
@@ -16,20 +13,24 @@ def shuffledata(data,target):
     np.random.shuffle(target)
     # return data,target
 
-def k_fold_generator(length,fold_num):
-    if fold_num == 0 or fold_num == 1:
-        train_sequence = np.linspace(0,int(length*0.8)-1,int(length*0.8),dtype='int')[None]
-        test_sequence = np.linspace(int(length*0.8),length-1,int(length*0.2),dtype='int')[None]
-    else:
-        sequence = np.linspace(0,length-1,length,dtype='int')
-        train_length = int(length/fold_num*(fold_num-1))
-        test_length = int(length/fold_num)
-        train_sequence = np.zeros((fold_num,train_length),dtype='int')
-        test_sequence = np.zeros((fold_num,test_length),dtype='int')
-        for i in range(fold_num):
-            test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
-            train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length]
-    return train_sequence,test_sequence
+def k_fold_generator(length,fold_num,separated=False):
+    if separated:
+        sequence = np.linspace(0,length-1,num=length,dtype='int')
+        return sequence
+    else:
+        if fold_num == 0 or fold_num == 1:
+            train_sequence = np.linspace(0,int(length*0.8)-1,int(length*0.8),dtype='int')[None]
+            test_sequence = np.linspace(int(length*0.8),length-1,int(length*0.2),dtype='int')[None]
+        else:
+            sequence = np.linspace(0,length-1,length,dtype='int')
+            train_length = int(length/fold_num*(fold_num-1))
+            test_length = int(length/fold_num)
+            train_sequence = np.zeros((fold_num,train_length),dtype='int')
+            test_sequence = np.zeros((fold_num,test_length),dtype='int')
+            for i in range(fold_num):
+                test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
+                train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length]
+        return train_sequence,test_sequence
 
 def batch_generator(data,target,sequence,shuffle=True):
     batchsize = len(sequence)
@@ -48,22 +49,7 @@ def Normalize(data,maxmin,avg,sigma,is_01=False):
     else:
         return (data-avg)/sigma #(-1,1)
 
-def Balance_individualized_differences(signals,BID):
-    if BID == 'median':
-        signals = (signals*8/(np.median(abs(signals))))
-        signals = Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
-    elif BID == '5_95_th':
-        tmp = np.sort(signals.reshape(-1))
-        th_5 = -tmp[int(0.05*len(tmp))]
-        signals = Normalize(signals,maxmin=10e3,avg=0,sigma=th_5,is_01=True)
-    else:
-        #dataser 5_95_th(-1,1) median
-        #CC2018 24.75 7.438
-        #sleep edfx 37.4 9.71
-        #sleep edfx sleeptime 39.03 10.125
-        signals = Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
-    return signals
 
 def ToTensor(data,target=None,gpu_id=0):
     if target is not None:
```
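A quick check of the two return conventions, with shapes following directly from the code above (run from the repo root so the `util` package imports resolve):

```python
from util import transformer

# separated=True: one flat index sequence covering the whole split
seq = transformer.k_fold_generator(100, 1, separated=True)
print(seq.shape)  # (100,)

# separated=False with 5 folds: one row of indices per fold
train_seq, test_seq = transformer.k_fold_generator(100, 5)
print(train_seq.shape, test_seq.shape)  # (5, 80) (5, 20)
```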