提交 7ecfcb78 编写于 作者: H hypox64

Allow separated data

上级 f495e3a1
......@@ -5,6 +5,7 @@ import numpy as np
import random
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import warnings
from util import util,transformer,dataloader,statistics,plot,options
......@@ -19,10 +20,10 @@ opt.k_fold = 0
opt.save_dir = './datasets/server/tmp'
util.makedirs(opt.save_dir)
'''load ori data'''
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
# use separated mode
signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
'''def network'''
core = core.Core(opt)
core.network_init(printflag=True)
......@@ -34,54 +35,103 @@ os.system('unzip ./datasets/server/data.zip -d ./datasets/server/')
categorys = os.listdir('./datasets/server/data')
categorys.sort()
print('categorys:',categorys)
receive_category = len(categorys)
received_signals = []
received_labels = []
for i in range(receive_category):
samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
category_num = len(categorys)
# received_signals_train = [];received_labels_train = []
# received_signals_eval = [];received_labels_eval = []
# sample_num = 1000
# eval_num = 1
# for i in range(category_num):
# samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
# for j in range(len(samples)):
# txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
# #print(os.path.join('./datasets/server/data',categorys[i],sample))
# txt_split = txt.split()
# signal_ori = np.zeros(len(txt_split))
# for point in range(len(txt_split)):
# signal_ori[point] = float(txt_split[point])
# for x in range(sample_num//len(samples)):
# ran = random.randint(1000, len(signal_ori)-2000-1)
# this_signal = signal_ori[ran:ran+2000]
# this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
# # if i ==0:
# # plt.plot(this_signal)
# # plt.show()
# if j < (len(samples)-eval_num):
# received_signals_train.append(this_signal)
# received_labels_train.append(i)
# else:
# received_signals_eval.append(this_signal)
# received_labels_eval.append(i)
# received_signals_train = np.array(received_signals_train).reshape(-1,opt.input_nc,opt.loadsize)
# received_labels_train = np.array(received_labels_train).reshape(-1,1)
# received_signals_eval = np.array(received_signals_eval).reshape(-1,opt.input_nc,opt.loadsize)
# received_labels_eval = np.array(received_labels_eval).reshape(-1,1)
#print(received_signals_train.shape,received_signals_eval.shape)
received_signals = [];received_labels = []
for sample in samples:
txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],sample))
sample_num = 1000
eval_num = 1
for i in range(category_num):
samples = os.listdir(os.path.join('./datasets/server/data',categorys[i]))
random.shuffle(samples)
for j in range(len(samples)):
txt = util.loadtxt(os.path.join('./datasets/server/data',categorys[i],samples[j]))
#print(os.path.join('./datasets/server/data',categorys[i],sample))
txt_split = txt.split()
signal_ori = np.zeros(len(txt_split))
for point in range(len(txt_split)):
signal_ori[point] = float(txt_split[point])
# #just cut
# for j in range(1,len(signal_ori)//opt.loadsize-1):
# this_signal = signal_ori[j*opt.loadsize:(j+1)*opt.loadsize]
# this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
# received_signals.append(this_signal)
# received_labels.append(i)
#random cut
for j in range(500//len(samples)-1):
for x in range(sample_num//len(samples)):
ran = random.randint(1000, len(signal_ori)-2000-1)
this_signal = signal_ori[ran:ran+2000]
this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
received_signals.append(this_signal)
received_labels.append(i)
received_signals = np.array(received_signals).reshape(-1,opt.input_nc,opt.loadsize)
received_labels = np.array(received_labels).reshape(-1,1)
received_signals_train,received_labels_train,received_signals_eval,received_labels_eval=\
dataloader.segment_dataset(received_signals, received_labels, 0.8,random=False)
print(received_signals_train.shape,received_signals_eval.shape)
# print(labels)
'''merge data'''
signals = signals[receive_category*500:]
labels = labels[receive_category*500:]
signals = np.concatenate((signals, received_signals))
labels = np.concatenate((labels, received_labels))
transformer.shuffledata(signals,labels)
signals_train,labels_train = dataloader.del_labels(signals_train,labels_train, np.linspace(0, category_num-1,category_num,dtype=np.int64))
signals_eval,labels_eval = dataloader.del_labels(signals_eval,labels_eval, np.linspace(0, category_num-1,category_num,dtype=np.int64))
signals_train = np.concatenate((signals_train, received_signals_train))
labels_train = np.concatenate((labels_train, received_labels_train))
signals_eval = np.concatenate((signals_eval, received_signals_eval))
labels_eval = np.concatenate((labels_eval, received_labels_eval))
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
train_sequences= transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
eval_sequences= transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
train_sequences,test_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
for epoch in range(opt.epochs):
t1 = time.time()
core.train(signals,labels,train_sequences[0])
core.eval(signals,labels,test_sequences[0])
if opt.separated:
#print(signals_train.shape,labels_train.shape)
core.train(signals_train,labels_train,train_sequences)
core.eval(signals_eval,labels_eval,eval_sequences)
else:
core.train(signals,labels,train_sequences[fold])
core.eval(signals,labels,eval_sequences[fold])
t2=time.time()
if epoch+1==1:
util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
plot.draw_heatmap(core.confusion_mats[-1],opt,name = 'final')
core.save_traced_net()
......@@ -24,11 +24,21 @@ signals = np.zeros((10,1,10),dtype='np.float64')
labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
* step2: input ```--dataset_dir your_dataset_dir``` when running code.
'''
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
#----------------------------Load Data----------------------------
if opt.separated:
signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
train_sequences= transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
eval_sequences= transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
else:
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
......@@ -40,17 +50,19 @@ fold_final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
for fold in range(opt.k_fold):
if opt.k_fold != 1:util.writelog('------------------------------ k-fold:'+str(fold+1)+' ------------------------------',opt,True)
core.network_init()
final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
# confusion_mats = []
for epoch in range(opt.epochs):
for epoch in range(opt.epochs):
t1 = time.time()
core.train(signals,labels,train_sequences[fold])
core.eval(signals,labels,eval_sequences[fold])
# confusion_mats.append(confusion_mat_eval)
if opt.separated:
#print(signals_train.shape,labels_train.shape)
core.train(signals_train,labels_train,train_sequences)
core.eval(signals_eval,labels_eval,eval_sequences)
else:
core.train(signals,labels,train_sequences[fold])
core.eval(signals,labels,eval_sequences[fold])
core.save()
t2=time.time()
if epoch+1==1:
util.writelog('>>> per epoch cost time:'+str(round((t2-t1),2))+'s',opt,True)
......
......@@ -6,16 +6,55 @@ import scipy.io as sio
import numpy as np
from . import dsp,transformer,statistics
# import dsp
# import transformer
# import statistics
def trimdata(data,num):
    """Cut ``data`` down to the largest length that is a whole multiple of ``num``.

    Args:
        data: any sliceable sequence (list, string, NumPy array, ...).
        num: chunk size; the result length is ``num * (len(data) // num)``.

    Returns:
        The leading slice of ``data``; trailing elements that do not fill a
        complete chunk of ``num`` are dropped.
    """
    # Floor division keeps the arithmetic in exact integers instead of
    # round-tripping through a float as int(len(data)/num) did.
    whole_chunks = len(data) // num
    return data[:num * whole_chunks]
def del_labels(signals,labels,dels):
    """Remove every sample whose label appears in ``dels``.

    Args:
        signals: array-like of samples, indexed along axis 0 in step with
            ``labels``.
        labels: array-like of labels, one entry (or one row, e.g. shape
            ``(N, 1)``) per sample.
        dels: iterable of label values to drop.

    Returns:
        ``(signals, labels)`` as new NumPy arrays with the matching rows
        removed. When nothing matches, both are returned in full.
    """
    signals = np.asarray(signals)
    labels = np.asarray(labels)
    # One vectorised membership test instead of a Python-level loop.
    drop = np.isin(labels, np.asarray(list(dels)))
    if drop.ndim > 1:
        # Row-shaped labels: a row is dropped when any of its entries matches.
        drop = drop.any(axis=tuple(range(1, drop.ndim)))
    # A boolean mask also covers the "no match" case, where the previous
    # empty (float) index array is rejected by np.delete.
    keep = ~drop
    return signals[keep], labels[keep]
# def sortbylabel(signals,labels):
# signals
def segment_dataset(signals,labels,a=0.8,random=True):
    """Split a dataset into a training part and an evaluation part.

    Args:
        signals: samples, sliced along axis 0.
        labels: labels aligned with ``signals`` along axis 0.
        a: fraction of the data assigned to the training part.
        random: when true, the data is shuffled with
            ``transformer.shuffledata`` and cut once at ``int(a*len(labels))``.
            When false, every class is cut separately so that each one
            contributes the leading ``a`` fraction of its samples to training.

    Returns:
        ``(signals_train, labels_train, signals_eval, labels_eval)``.

    Note:
        The per-class branch walks the data in consecutive runs of
        ``label_cnt[i]`` samples, so it presumably expects the samples to be
        grouped by class in the order used by ``statistics.label_statistics``
        -- TODO confirm against that helper.
    """
    if random:
        transformer.shuffledata(signals, labels)
        split = int(a*len(labels))
        return signals[:split], labels[:split], signals[split:], labels[split:]

    label_cnt, _, label_num = statistics.label_statistics(labels)
    train_signals = []; train_labels = []
    eval_signals = []; eval_labels = []
    start = 0
    for i in range(label_num):
        # Honour the requested ratio (this used to be hard-coded to 0.8).
        split = start + int(label_cnt[i]*a)
        stop = start + label_cnt[i]
        train_signals.append(signals[start:split])
        train_labels.append(labels[start:split])
        eval_signals.append(signals[split:stop])
        eval_labels.append(labels[split:stop])
        start = stop

    if not train_signals:
        # No classes reported: hand back empty slices rather than failing.
        return signals[:0], labels[:0], signals[:0], labels[:0]
    # Concatenate once at the end instead of once per class.
    return (np.concatenate(train_signals), np.concatenate(train_labels),
            np.concatenate(eval_signals), np.concatenate(eval_labels))
def reducesample(data,mult):
    """Downsample ``data`` by keeping every ``mult``-th element, starting at index 0."""
    stride = slice(None, None, mult)
    return data[stride]
def balance_label(signals,labels):
......@@ -51,192 +90,22 @@ def balance_label(signals,labels):
return new_signals,new_labels
# delete useless label
def del_UND(signals,stages):
    """Drop every sample whose stage label is 5 (the undefined stage).

    Args:
        signals: array-like of samples, aligned with ``stages`` along axis 0.
        stages: array-like of integer stage labels, one per sample.

    Returns:
        ``(signals, stages)`` as NumPy arrays without the rows labelled 5.
    """
    signals = np.asarray(signals)
    stages = np.asarray(stages)
    undefined = stages == 5
    if undefined.ndim > 1:
        # Row-shaped labels such as (N, 1): reduce to one flag per sample.
        undefined = undefined.any(axis=tuple(range(1, undefined.ndim)))
    # A single mask replaces one np.delete call (a full copy) per removed row.
    keep = ~undefined
    return signals[keep], stages[keep]
def connectdata(signal,stage,signals=None,stages=None):
    """Append one subject's data to the running dataset.

    Args:
        signal: NumPy array of new samples.
        stage: NumPy array of labels for ``signal``.
        signals: samples collected so far; ``None`` or any empty sequence
            (such as the ``[]`` callers start with) means nothing has been
            collected yet.
        stages: labels collected so far, aligned with ``signals``.

    Returns:
        ``(signals, stages)``: copies of ``signal``/``stage`` on the first
        call, otherwise the accumulated arrays extended along axis 0.
    """
    # len() gives a plain boolean for lists and arrays alike; comparing an
    # ndarray with [] is an elementwise operation and is not a usable
    # emptiness test.
    if signals is None or len(signals) == 0:
        return signal.copy(), stage.copy()
    merged_signals = np.concatenate((signals, signal), axis=0)
    merged_stages = np.concatenate((stages, stage), axis=0)
    return merged_signals, merged_stages
#load one subject data from cc2018
def loaddata_cc2018(filedir,filename,signal_name,BID,filter = True):
    """Load one subject's signal and sleep-stage labels from the cc2018 dataset.

    Args:
        filedir: dataset root directory.
        filename: name of the subject's sub-directory inside ``filedir``; the
            ``.hea``, ``.mat`` and ``-arousal.mat`` files share that name.
        signal_name: channel to extract, as listed in the ``.hea`` header.
        BID: mode passed to ``transformer.Balance_individualized_differences``.
        filter: when true, the raw signal is passed through ``dsp.BPF`` first.

    Returns:
        ``(signals, stages)``: ``signals`` as float16 with 3000 samples per row
        and ``stages`` as int16 with one label per row; rows labelled
        undefined (5) are removed.
    """
    # NOTE(review): h5py is imported locally, as loaddataset does; the
    # module-level imports are not visible from this chunk.
    import h5py

    dirpath = os.path.join(filedir,filename)
    subject = os.path.basename(dirpath)

    #load signal
    hea_path = os.path.join(dirpath,subject+'.hea')
    signal_path = os.path.join(dirpath,subject+'.mat')
    # The first header line is skipped; on every other line the 9th
    # whitespace-separated field is taken as the channel name.
    with open(hea_path) as hea_file:
        signal_names = [line.split()[8] for i,line in enumerate(hea_file) if i!=0]
    mat = sio.loadmat(signal_path)
    signals = mat['val'][signal_names.index(signal_name)]
    if filter:
        signals = dsp.BPF(signals,200,0.2,50,mod = 'fir')

    #load stage
    stagepath = os.path.join(dirpath,subject+'-arousal.mat')
    # Read every stage row inside the context manager so the file is closed
    # as soon as the data has been pulled into memory.
    with h5py.File(stagepath,'r') as stage_file:
        sleep_stages = stage_file['data']['sleep_stages']
        # N3(S4+S3)->0  N2->1  N1->2  REM->3  W->4  UND->5
        N3 = sleep_stages['nonrem3'][0]
        N2 = sleep_stages['nonrem2'][0]
        N1 = sleep_stages['nonrem1'][0]
        REM = sleep_stages['rem'][0]
        W = sleep_stages['wake'][0]
        UND = sleep_stages['undefined'][0]
    stages = N3*0 + N2*1 + N1*2 + REM*3 + W*4 + UND*5

    #resample
    signals = reducesample(signals,2)
    stages = reducesample(stages,2)
    #trim
    signals = trimdata(signals,3000)
    stages = trimdata(stages,3000)
    #30s per label
    signals = signals.reshape(-1,3000)
    stages = stages[::3000]
    #Balance individualized differences
    signals = transformer.Balance_individualized_differences(signals, BID)
    #del UND
    signals,stages = del_UND(signals, stages)
    return signals.astype(np.float16),stages.astype(np.int16)
#load one subject data from sleep-edfx
def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
    """Load one recording and its sleep-stage labels from the sleep-edfx dataset.

    Args:
        filedir: directory holding the ``*PSG*`` and ``*Hypnogram*`` files.
        filename: name of the recording; characters 2..5 are used as the
            number that identifies the matching PSG/Hypnogram pair.
        signal_name: channel to pick from the PSG recording.
        BID: mode passed to ``transformer.Balance_individualized_differences``.
        select_sleep_time: when true and the recording is an ``SC`` file, the
            epochs are restricted to a window derived from the first and the
            second-to-last annotation, padded by 60 epochs on each side.

    Returns:
        ``(signals, stages)``: ``signals`` as float16 with one 3000-sample
        epoch per row, ``stages`` as int16; epochs labelled 5 are removed.

    Raises:
        FileNotFoundError: if no PSG or no Hypnogram file for the recording
            number exists in ``filedir``.
    """
    # NOTE(review): mne is imported locally, as loaddataset does; the
    # module-level imports are not visible from this chunk.
    import mne

    filenum = filename[2:6]
    f_stage_name = None
    f_signal_name = None
    # A separate loop variable keeps the ``filename`` argument intact.
    for candidate in os.listdir(filedir):
        if str(filenum) in candidate and 'Hypnogram' in candidate:
            f_stage_name = candidate
        if str(filenum) in candidate and 'PSG' in candidate:
            f_signal_name = candidate
    if f_stage_name is None or f_signal_name is None:
        raise FileNotFoundError(
            'no PSG/Hypnogram pair for recording '+str(filenum)+' in '+str(filedir))

    raw_data= mne.io.read_raw_edf(os.path.join(filedir,f_signal_name),preload=True)
    raw_annot = mne.read_annotations(os.path.join(filedir,f_stage_name))
    eeg = raw_data.pick_channels([signal_name]).to_data_frame().values.T
    eeg = eeg.reshape(-1)
    raw_data.set_annotations(raw_annot, emit_warning=False)

    #N3(S4+S3)->0  N2->1  N1->2  REM->3  W->4  other->UND->5
    event_id = {'Sleep stage 4': 0,
                'Sleep stage 3': 0,
                'Sleep stage 2': 1,
                'Sleep stage 1': 2,
                'Sleep stage R': 3,
                'Sleep stage W': 4,
                'Sleep stage ?': 5,
                'Movement time': 5}
    events, _ = mne.events_from_annotations(
        raw_data, event_id=event_id, chunk_duration=30.)

    # The last event is left out; each remaining event contributes its label
    # (column 2) and the 3000 samples that follow its onset (column 0).
    stages = [];signals =[]
    for event in events[:-1]:
        stages.append(event[2])
        signals.append(eeg[event[0]:event[0]+3000])
    stages=np.array(stages)
    signals=np.array(signals)

    #select sleep time
    if select_sleep_time:
        if 'SC' in f_signal_name:
            first = np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999)
            last = int(raw_annot[-2]['onset'])//30+60
            signals = signals[first:last]
            stages = stages[first:last]

    signals,stages = del_UND(signals, stages)
    print('shape:',signals.shape,stages.shape)
    signals = transformer.Balance_individualized_differences(signals, BID)
    return signals.astype(np.float16),stages.astype(np.int16)
#load all data in datasets
def loaddataset(opt,shuffle = False):
filedir=opt.dataset_dir
dataset_name = opt.dataset_name
signal_name = opt.signal_name
num = opt.sample_num
BID = opt.BID
select_sleep_time = opt.select_sleep_time
print('load dataset, please wait...')
signals_train=[];labels_train=[];signals_test=[];labels_test=[]
if dataset_name == 'cc2018':
import h5py
filenames = os.listdir(filedir)
if not opt.no_shuffle:
random.shuffle(filenames)
else:
filenames.sort()
if num > len(filenames):
num = len(filenames)
print('num of dataset is:',num)
for cnt,filename in enumerate(filenames[:num],0):
signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
if cnt < round(num*0.8) :
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
else:
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2))
elif dataset_name == 'sleep-edfx':
import mne
if num > 197:
num = 197
filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 
'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']
for filename in filenames_sc_train[:round(num*153/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
for filename in filenames_st_train[:round(num*44/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,labels_train = connectdata(signal,stage,signals_train,labels_train)
for filename in filenames_sc_test[:round(num*153/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
for filename in filenames_st_test[:round(num*44/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,labels_test = connectdata(signal,stage,signals_test,labels_test)
print('---------Each subject has two sample---------',
'\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8),
'\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2))
elif dataset_name == 'preload':
if opt.dataset_name == 'preload':
if opt.separated:
signals_train = np.load(filedir+'/signals_train.npy')
labels_train = np.load(filedir+'/labels_train.npy')
signals_test = np.load(filedir+'/signals_test.npy')
labels_test = np.load(filedir+'/labels_test.npy')
signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
else:
signals = np.load(filedir+'/signals.npy')
labels = np.load(filedir+'/labels.npy')
signals = np.load(opt.dataset_dir+'/signals.npy')
labels = np.load(opt.dataset_dir+'/labels.npy')
if not opt.no_shuffle:
transformer.shuffledata(signals,labels)
if opt.separated:
return signals_train,labels_train,signals_test,labels_test
return signals_train,labels_train,signals_eval,labels_eval
else:
return signals,labels
\ No newline at end of file
......@@ -85,6 +85,9 @@ class Options():
if self.opt.k_fold == 0 :
self.opt.k_fold = 1
if self.opt.separated:
self.opt.k_fold = 1
self.opt.mergelabel = eval(self.opt.mergelabel)
if self.opt.mergelabel_name != 'None':
self.opt.mergelabel_name = self.opt.mergelabel_name.replace(" ", "").split(",")
......
......@@ -6,9 +6,6 @@ from . import dsp
from . import array_operation as arr
# import dsp
def trimdata(data,num):
    """Cut ``data`` down to the largest length that is a whole multiple of ``num``.

    Args:
        data: any sliceable sequence (list, string, NumPy array, ...).
        num: chunk size; the result length is ``num * (len(data) // num)``.

    Returns:
        The leading slice of ``data``; trailing elements that do not fill a
        complete chunk of ``num`` are dropped.
    """
    # Floor division keeps the arithmetic in exact integers instead of
    # round-tripping through a float as int(len(data)/num) did.
    whole_chunks = len(data) // num
    return data[:num * whole_chunks]
def shuffledata(data,target):
state = np.random.get_state()
np.random.shuffle(data)
......@@ -16,20 +13,24 @@ def shuffledata(data,target):
np.random.shuffle(target)
# return data,target
def k_fold_generator(length,fold_num,separated=False):
    """Build the sample-index sequences used for training and evaluation.

    Args:
        length: number of samples to index.
        fold_num: number of folds. ``0`` or ``1`` selects a single hold-out
            split: the first ``int(length*0.8)`` indices for training and all
            remaining indices for evaluation.
        separated: when true the data is already split, so a single flat
            sequence ``0 .. length-1`` is returned and ``fold_num`` is ignored.

    Returns:
        If ``separated``: one 1-D integer array of all indices.
        Otherwise ``(train_sequence, test_sequence)``, two 2-D integer arrays
        with one row per fold (a single row for the hold-out split).
    """
    if separated:
        return np.arange(length, dtype='int')

    if fold_num == 0 or fold_num == 1:
        split = int(length*0.8)
        train_sequence = np.arange(split, dtype='int')[None]
        # Every index after the split goes to evaluation. Spreading
        # int(length*0.2) points with linspace skipped or dropped indices
        # whenever length was not a multiple of 5.
        test_sequence = np.arange(split, length, dtype='int')[None]
        return train_sequence,test_sequence

    sequence = np.arange(length, dtype='int')
    train_length = int(length/fold_num*(fold_num-1))
    test_length = int(length/fold_num)
    train_sequence = np.zeros((fold_num,train_length), dtype = 'int')
    test_sequence = np.zeros((fold_num,test_length), dtype = 'int')
    for i in range(fold_num):
        begin = test_length*i
        end = test_length*(i+1)
        # Fold i is evaluated on its own contiguous slice and trained on the
        # rest, truncated so that every row has the same length.
        test_sequence[i] = sequence[begin:end]
        train_sequence[i] = np.concatenate((sequence[:begin],sequence[end:]),axis=0)[:train_length]
    return train_sequence,test_sequence
def batch_generator(data,target,sequence,shuffle = True):
batchsize = len(sequence)
......@@ -48,22 +49,7 @@ def Normalize(data,maxmin,avg,sigma,is_01=False):
else:
return (data-avg)/sigma #(-1,1)
def Balance_individualized_differences(signals,BID):
    """Rescale ``signals`` with ``Normalize`` according to the ``BID`` mode.

    Args:
        signals: NumPy array of signal data.
        BID: ``'median'`` first scales the data by ``8 / median(|signals|)``
            and then normalizes with ``sigma=30``; ``'5_95_th'`` normalizes
            with ``sigma`` set to the negated value found 5% of the way
            through the sorted, flattened data; any other value normalizes
            with the fixed ``sigma=30``.

    Returns:
        The result of ``Normalize(signals, maxmin=10e3, avg=0, sigma=..., is_01=True)``.
    """
    # Fixed divisor used unless the 5_95_th mode derives one from the data.
    #dataset                 5_95_th(-1,1)   median
    #CC2018                  24.75           7.438
    #sleep edfx              37.4            9.71
    #sleep edfx sleeptime    39.03           10.125
    divisor = 30
    if BID == 'median':
        signals = signals*8/(np.median(abs(signals)))
    elif BID == '5_95_th':
        ordered = np.sort(signals.reshape(-1))
        divisor = -ordered[int(0.05*len(ordered))]
    return Normalize(signals,maxmin=10e3,avg=0,sigma=divisor,is_01=True)
def ToTensor(data,target=None,gpu_id=0):
if target is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册