提交 7ecfcb78 编写于 作者: H hypox64

Allow separated data

上级 f495e3a1
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import random import random
import torch import torch
from torch import nn, optim from torch import nn, optim
import matplotlib.pyplot as plt
import warnings import warnings
from util import util,transformer,dataloader,statistics,plot,options from util import util,transformer,dataloader,statistics,plot,options
...@@ -19,10 +20,10 @@ opt.k_fold = 0 ...@@ -19,10 +20,10 @@ opt.k_fold = 0
opt.save_dir = './datasets/server/tmp' opt.save_dir = './datasets/server/tmp'
util.makedirs(opt.save_dir) util.makedirs(opt.save_dir)
'''load ori data''' '''load ori data'''
signals,labels = dataloader.loaddataset(opt) # use separated mode
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels) signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape) label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
'''def network''' '''def network'''
core = core.Core(opt) core = core.Core(opt)
core.network_init(printflag=True) core.network_init(printflag=True)
...@@ -34,54 +35,103 @@ os.system('unzip ./datasets/server/data.zip -d ./datasets/server/') ...@@ -34,54 +35,103 @@ os.system('unzip ./datasets/server/data.zip -d ./datasets/server/')
categorys = os.listdir('./datasets/server/data') categorys = os.listdir('./datasets/server/data')
categorys.sort() categorys.sort()
print('categorys:',categorys) print('categorys:',categorys)
receive_category = len(categorys) category_num = len(categorys)
received_signals = [] # received_signals_train = [];received_labels_train = []
received_labels = [] # received_signals_eval = [];received_labels_eval = []
for i in range(receive_category):
samples = os.listdir(os.path.join('./datasets/server/data',categorys[i])) # sample_num = 1000
# eval_num = 1
# for i in range(category_num):
# samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
# for j in range(len(samples)):
# txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
# #print(os.path.join('./datasets/server/data',categorys[i],sample))
# txt_split = txt.split()
# signal_ori = np.zeros(len(txt_split))
# for point in range(len(txt_split)):
# signal_ori[point] = float(txt_split[point])
# for x in range(sample_num//len(samples)):
# ran = random.randint(1000, len(signal_ori)-2000-1)
# this_signal = signal_ori[ran:ran+2000]
# this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
# # if i ==0:
# # plt.plot(this_signal)
# # plt.show()
# if j < (len(samples)-eval_num):
# received_signals_train.append(this_signal)
# received_labels_train.append(i)
# else:
# received_signals_eval.append(this_signal)
# received_labels_eval.append(i)
# received_signals_train = np.array(received_signals_train).reshape(-1,opt.input_nc,opt.loadsize)
# received_labels_train = np.array(received_labels_train).reshape(-1,1)
# received_signals_eval = np.array(received_signals_eval).reshape(-1,opt.input_nc,opt.loadsize)
# received_labels_eval = np.array(received_labels_eval).reshape(-1,1)
#print(received_signals_train.shape,received_signals_eval.shape)
received_signals = [];received_labels = []
for sample in samples: sample_num = 1000
txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],sample)) eval_num = 1
for i in range(category_num):
samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
random.shuffle(samples)
for j in range(len(samples)):
txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
#print(os.path.join('./datasets/server/data',categorys[i],sample)) #print(os.path.join('./datasets/server/data',categorys[i],sample))
txt_split = txt.split() txt_split = txt.split()
signal_ori = np.zeros(len(txt_split)) signal_ori = np.zeros(len(txt_split))
for point in range(len(txt_split)): for point in range(len(txt_split)):
signal_ori[point] = float(txt_split[point]) signal_ori[point] = float(txt_split[point])
# #just cut
# for j in range(1,len(signal_ori)//opt.loadsize-1): for x in range(sample_num//len(samples)):
# this_signal = signal_ori[j*opt.loadsize:(j+1)*opt.loadsize]
# this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
# received_signals.append(this_signal)
# received_labels.append(i)
#random cut
for j in range(500//len(samples)-1):
ran = random.randint(1000, len(signal_ori)-2000-1) ran = random.randint(1000, len(signal_ori)-2000-1)
this_signal = signal_ori[ran:ran+2000] this_signal = signal_ori[ran:ran+2000]
this_signal = arr.normliaze(this_signal,'5_95',truncated=4) this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
received_signals.append(this_signal) received_signals.append(this_signal)
received_labels.append(i) received_labels.append(i)
received_signals = np.array(received_signals).reshape(-1,opt.input_nc,opt.loadsize) received_signals = np.array(received_signals).reshape(-1,opt.input_nc,opt.loadsize)
received_labels = np.array(received_labels).reshape(-1,1) received_labels = np.array(received_labels).reshape(-1,1)
received_signals_train,received_labels_train,received_signals_eval,received_labels_eval=\
dataloader.segment_dataset(received_signals, received_labels, 0.8,random=False)
print(received_signals_train.shape,received_signals_eval.shape)
# print(labels) # print(labels)
'''merge data''' '''merge data'''
signals = signals[receive_category*500:] signals_train,labels_train = dataloader.del_labels(signals_train,labels_train, np.linspace(0, category_num-1,category_num,dtype=np.int64))
labels = labels[receive_category*500:] signals_eval,labels_eval = dataloader.del_labels(signals_eval,labels_eval, np.linspace(0, category_num-1,category_num,dtype=np.int64))
signals = np.concatenate((signals, received_signals))
labels = np.concatenate((labels, received_labels))
transformer.shuffledata(signals,labels) signals_train = np.concatenate((signals_train, received_signals_train))
labels_train = np.concatenate((labels_train, received_labels_train))
signals_eval = np.concatenate((signals_eval, received_signals_eval))
labels_eval = np.concatenate((labels_eval, received_labels_eval))
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
train_sequences= transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
eval_sequences= transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
for epoch in range(opt.epochs): for epoch in range(opt.epochs):
t1 = time.time() t1 = time.time()
core.train(signals,labels,train_sequences[0]) if opt.separated:
core.eval(signals,labels,test_sequences[0]) #print(signals_train.shape,labels_train.shape)
core.train(signals_train,labels_train,train_sequences)
core.eval(signals_eval,labels_eval,eval_sequences)
else:
core.train(signals,labels,train_sequences[fold])
core.eval(signals,labels,eval_sequences[fold])
t2=time.time() t2=time.time()
if epoch+1==1: if epoch+1==1:
util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True) util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
plot.draw_heatmap(core.confusion_mats[-1],opt,name = 'final')
core.save_traced_net() core.save_traced_net()
...@@ -24,11 +24,21 @@ signals = np.zeros((10,1,10),dtype='np.float64') ...@@ -24,11 +24,21 @@ signals = np.zeros((10,1,10),dtype='np.float64')
labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1 labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
* step2: input ```--dataset_dir your_dataset_dir``` when running code. * step2: input ```--dataset_dir your_dataset_dir``` when running code.
''' '''
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels) #----------------------------Load Data----------------------------
util.writelog('label statistics: '+str(label_cnt),opt,True) if opt.separated:
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape) signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold) label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
train_sequences= transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
eval_sequences= transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
else:
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
t2 = time.time() t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s') print('load data cost time: %.2f'% (t2-t1),'s')
...@@ -40,17 +50,19 @@ fold_final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int) ...@@ -40,17 +50,19 @@ fold_final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
for fold in range(opt.k_fold): for fold in range(opt.k_fold):
if opt.k_fold != 1:util.writelog('------------------------------ k-fold:'+str(fold+1)+' ------------------------------',opt,True) if opt.k_fold != 1:util.writelog('------------------------------ k-fold:'+str(fold+1)+' ------------------------------',opt,True)
core.network_init() core.network_init()
final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int) final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
# confusion_mats = [] for epoch in range(opt.epochs):
for epoch in range(opt.epochs):
t1 = time.time() t1 = time.time()
core.train(signals,labels,train_sequences[fold]) if opt.separated:
core.eval(signals,labels,eval_sequences[fold]) #print(signals_train.shape,labels_train.shape)
# confusion_mats.append(confusion_mat_eval) core.train(signals_train,labels_train,train_sequences)
core.eval(signals_eval,labels_eval,eval_sequences)
else:
core.train(signals,labels,train_sequences[fold])
core.eval(signals,labels,eval_sequences[fold])
core.save() core.save()
t2=time.time() t2=time.time()
if epoch+1==1: if epoch+1==1:
util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True) util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
......
...@@ -6,16 +6,55 @@ import scipy.io as sio ...@@ -6,16 +6,55 @@ import scipy.io as sio
import numpy as np import numpy as np
from . import dsp,transformer,statistics from . import dsp,transformer,statistics
# import dsp
# import transformer
# import statistics
def trimdata(data,num): def del_labels(signals,labels,dels):
return data[:num*int(len(data)/num)] del_index = []
for i in range(len(labels)):
if labels[i] in dels:
del_index.append(i)
del_index = np.array(del_index)
signals = np.delete(signals,del_index, axis = 0)
labels = np.delete(labels,del_index,axis = 0)
return signals,labels
# def sortbylabel(signals,labels):
# signals
def segment_dataset(signals,labels,a=0.8,random=True):
length = len(labels)
if random:
transformer.shuffledata(signals, labels)
signals_train = signals[:int(a*length)]
labels_train = labels[:int(a*length)]
signals_eval = signals[int(a*length):]
labels_eval = labels[int(a*length):]
else:
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
#signals_train=[];labels_train=[];signals_eval=[];labels_eval=[]
# cnt_ori = 0
# signals_tmp=np.zeros_like(signals)
# labels_tmp=np.zeros_like(labels)
cnt = 0
for i in range(label_num):
if i ==0:
signals_train = signals[cnt:cnt+int(label_cnt[i]*0.8)]
labels_train = labels[cnt:cnt+int(label_cnt[i]*0.8)]
signals_eval = signals[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]
labels_eval = labels[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]
else:
signals_train = np.concatenate((signals_train, signals[cnt:cnt+int(label_cnt[i]*0.8)]))
labels_train = np.concatenate((labels_train, labels[cnt:cnt+int(label_cnt[i]*0.8)]))
signals_eval = np.concatenate((signals_eval, signals[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]))
labels_eval = np.concatenate((labels_eval, labels[cnt+int(label_cnt[i]*0.8):cnt+label_cnt[i]]))
cnt += label_cnt[i]
return signals_train,labels_train,signals_eval,labels_eval
def reducesample(data,mult):
return data[::mult]
def balance_label(signals,labels): def balance_label(signals,labels):
...@@ -51,192 +90,22 @@ def balance_label(signals,labels): ...@@ -51,192 +90,22 @@ def balance_label(signals,labels):
return new_signals,new_labels return new_signals,new_labels
# delete uesless label
def del_UND(signals,stages):
stages_copy = stages.copy()
cnt = 0
for i in range(len(stages_copy)):
if stages_copy[i] == 5 :
signals = np.delete(signals,i-cnt,axis =0)
stages = np.delete(stages,i-cnt,axis =0)
cnt += 1
return signals,stages
def connectdata(signal,stage,signals=[],stages=[]):
if signals == []:
signals =signal.copy()
stages =stage.copy()
else:
signals=np.concatenate((signals, signal), axis=0)
stages=np.concatenate((stages, stage), axis=0)
return signals,stages
#load one subject data form cc2018
def loaddata_cc2018(filedir,filename,signal_name,BID,filter = True):
dirpath = os.path.join(filedir,filename)
#load signal
hea_path = os.path.join(dirpath,os.path.basename(dirpath)+'.hea')
signal_path = os.path.join(dirpath,os.path.basename(dirpath)+'.mat')
signal_names = []
for i,line in enumerate(open(hea_path),0):
if i!=0:
line=line.strip()
signal_names.append(line.split()[8])
mat = sio.loadmat(signal_path)
signals = mat['val'][signal_names.index(signal_name)]
if filter:
signals = dsp.BPF(signals,200,0.2,50,mod = 'fir')
#load stage
stagepath = os.path.join(dirpath,os.path.basename(dirpath)+'-arousal.mat')
mat=h5py.File(stagepath,'r')
# N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4 UND->5
N3 = mat['data']['sleep_stages']['nonrem3'][0]
N2 = mat['data']['sleep_stages']['nonrem2'][0]
N1 = mat['data']['sleep_stages']['nonrem1'][0]
REM = mat['data']['sleep_stages']['rem'][0]
W = mat['data']['sleep_stages']['wake'][0]
UND = mat['data']['sleep_stages']['undefined'][0]
stages = N3*0 + N2*1 + N1*2 + REM*3 + W*4 + UND*5
#resample
signals = reducesample(signals,2)
stages = reducesample(stages,2)
#trim
signals = trimdata(signals,3000)
stages = trimdata(stages,3000)
#30s per label
signals = signals.reshape(-1,3000)
stages = stages[::3000]
#Balance individualized differences
signals = transformer.Balance_individualized_differences(signals, BID)
#del UND
signals,stages = del_UND(signals, stages)
return signals.astype(np.float16),stages.astype(np.int16)
#load one subject data form sleep-edfx
def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
filenum = filename[2:6]
filenames = os.listdir(filedir)
for filename in filenames:
if str(filenum) in filename and 'Hypnogram' in filename:
f_stage_name = filename
if str(filenum) in filename and 'PSG' in filename:
f_signal_name = filename
raw_data= mne.io.read_raw_edf(os.path.join(filedir,f_signal_name),preload=True)
raw_annot = mne.read_annotations(os.path.join(filedir,f_stage_name))
eeg = raw_data.pick_channels([signal_name]).to_data_frame().values.T
eeg = eeg.reshape(-1)
raw_data.set_annotations(raw_annot, emit_warning=False)
#N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4 other->UND->5
event_id = {'Sleep stage 4': 0,
'Sleep stage 3': 0,
'Sleep stage 2': 1,
'Sleep stage 1': 2,
'Sleep stage R': 3,
'Sleep stage W': 4,
'Sleep stage ?': 5,
'Movement time': 5}
events, _ = mne.events_from_annotations(
raw_data, event_id=event_id, chunk_duration=30.)
stages = [];signals =[]
for i in range(len(events)-1):
stages.append(events[i][2])
signals.append(eeg[events[i][0]:events[i][0]+3000])
stages=np.array(stages)
signals=np.array(signals)
# #select sleep time
if select_sleep_time:
if 'SC' in f_signal_name:
signals = signals[np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999):int(raw_annot[-2]['onset'])//30+60]
stages = stages[np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999):int(raw_annot[-2]['onset'])//30+60]
signals,stages = del_UND(signals, stages)
print('shape:',signals.shape,stages.shape)
signals = transformer.Balance_individualized_differences(signals, BID)
return signals.astype(np.float16),stages.astype(np.int16)
#load all data in datasets #load all data in datasets
def loaddataset(opt,shuffle = False): def loaddataset(opt,shuffle = False):
filedir=opt.dataset_dir
dataset_name = opt.dataset_name
signal_name = opt.signal_name
num = opt.sample_num
BID = opt.BID
select_sleep_time = opt.select_sleep_time
print('load dataset, please wait...')
signals_train=[];labels_train=[];signals_test=[];labels_test=[]
if dataset_name == 'cc2018':
import h5py
filenames = os.listdir(filedir)
if not opt.no_shuffle:
random.shuffle(filenames)
else:
filenames.sort()
if num > len(filenames):
num = len(filenames)
print('num of dataset is:',num)
for cnt,filename in enumerate(filenames[:num],0):
signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
if cnt < round(num*0.8) :
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
else:
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2))
elif dataset_name == 'sleep-edfx':
import mne
if num > 197:
num = 197
filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']
for filename in filenames_sc_train[:round(num*153/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
for filename in filenames_st_train[:round(num*44/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
for filename in filenames_sc_test[:round(num*153/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
for filename in filenames_st_test[:round(num*44/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
print('---------Each subject has two sample---------',
'\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8),
'\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2))
elif dataset_name == 'preload': if opt.dataset_name == 'preload':
if opt.separated: if opt.separated:
signals_train = np.load(filedir+'/signals_train.npy') signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
labels_train = np.load(filedir+'/labels_train.npy') labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
signals_test = np.load(filedir+'/signals_test.npy') signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
labels_test = np.load(filedir+'/labels_test.npy') labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
else: else:
signals = np.load(filedir+'/signals.npy') signals = np.load(opt.dataset_dir+'/signals.npy')
labels = np.load(filedir+'/labels.npy') labels = np.load(opt.dataset_dir+'/labels.npy')
if not opt.no_shuffle: if not opt.no_shuffle:
transformer.shuffledata(signals,labels) transformer.shuffledata(signals,labels)
if opt.separated: if opt.separated:
return signals_train,labels_train,signals_test,labels_test return signals_train,labels_train,signals_eval,labels_eval
else: else:
return signals,labels return signals,labels
\ No newline at end of file
...@@ -85,6 +85,9 @@ class Options(): ...@@ -85,6 +85,9 @@ class Options():
if self.opt.k_fold == 0 : if self.opt.k_fold == 0 :
self.opt.k_fold = 1 self.opt.k_fold = 1
if self.opt.separated:
self.opt.k_fold = 1
self.opt.mergelabel = eval(self.opt.mergelabel) self.opt.mergelabel = eval(self.opt.mergelabel)
if self.opt.mergelabel_name != 'None': if self.opt.mergelabel_name != 'None':
self.opt.mergelabel_name = self.opt.mergelabel_name.replace(" ", "").split(",") self.opt.mergelabel_name = self.opt.mergelabel_name.replace(" ", "").split(",")
......
...@@ -6,9 +6,6 @@ from . import dsp ...@@ -6,9 +6,6 @@ from . import dsp
from . import array_operation as arr from . import array_operation as arr
# import dsp # import dsp
def trimdata(data,num):
return data[:num*int(len(data)/num)]
def shuffledata(data,target): def shuffledata(data,target):
state = np.random.get_state() state = np.random.get_state()
np.random.shuffle(data) np.random.shuffle(data)
...@@ -16,20 +13,24 @@ def shuffledata(data,target): ...@@ -16,20 +13,24 @@ def shuffledata(data,target):
np.random.shuffle(target) np.random.shuffle(target)
# return data,target # return data,target
def k_fold_generator(length,fold_num): def k_fold_generator(length,fold_num,separated=False):
if fold_num == 0 or fold_num == 1: if separated:
train_sequence = np.linspace(0,int(length*0.8)-1,int(length*0.8),dtype='int')[None] sequence = np.linspace(0, length-1,num = length,dtype='int')
test_sequence = np.linspace(int(length*0.8),length-1,int(length*0.2),dtype='int')[None] return sequence
else: else:
sequence = np.linspace(0,length-1,length,dtype='int') if fold_num == 0 or fold_num == 1:
train_length = int(length/fold_num*(fold_num-1)) train_sequence = np.linspace(0,int(length*0.8)-1,int(length*0.8),dtype='int')[None]
test_length = int(length/fold_num) test_sequence = np.linspace(int(length*0.8),length-1,int(length*0.2),dtype='int')[None]
train_sequence = np.zeros((fold_num,train_length), dtype = 'int') else:
test_sequence = np.zeros((fold_num,test_length), dtype = 'int') sequence = np.linspace(0,length-1,length,dtype='int')
for i in range(fold_num): train_length = int(length/fold_num*(fold_num-1))
test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length] test_length = int(length/fold_num)
train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length] train_sequence = np.zeros((fold_num,train_length), dtype = 'int')
return train_sequence,test_sequence test_sequence = np.zeros((fold_num,test_length), dtype = 'int')
for i in range(fold_num):
test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length]
return train_sequence,test_sequence
def batch_generator(data,target,sequence,shuffle = True): def batch_generator(data,target,sequence,shuffle = True):
batchsize = len(sequence) batchsize = len(sequence)
...@@ -48,22 +49,7 @@ def Normalize(data,maxmin,avg,sigma,is_01=False): ...@@ -48,22 +49,7 @@ def Normalize(data,maxmin,avg,sigma,is_01=False):
else: else:
return (data-avg)/sigma #(-1,1) return (data-avg)/sigma #(-1,1)
def Balance_individualized_differences(signals,BID):
if BID == 'median':
signals = (signals*8/(np.median(abs(signals))))
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
elif BID == '5_95_th':
tmp = np.sort(signals.reshape(-1))
th_5 = -tmp[int(0.05*len(tmp))]
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5,is_01=True)
else:
#dataser 5_95_th(-1,1) median
#CC2018 24.75 7.438
#sleep edfx 37.4 9.71
#sleep edfx sleeptime 39.03 10.125
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
return signals
def ToTensor(data,target=None,gpu_id=0): def ToTensor(data,target=None,gpu_id=0):
if target is not None: if target is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册