提交 f817cab2 编写于 作者: H hypox64

add k-fold

上级 72406b56
......@@ -105,7 +105,8 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T
signals.append(eeg[events[i][0]:events[i][0]+3000])
stages=np.array(stages)
signals=np.array(signals)
signals = signals*13/np.median(np.abs(signals))
if BID == 'median':
signals = signals*13/np.median(np.abs(signals))
# #select sleep time
if opt.select_sleep_time:
......@@ -122,35 +123,6 @@ def loaddata_sleep_edf(opt,filedir,filenum,signal_name,BID = 'median',filter = T
cnt += 1
print('shape:',signals.shape,stages.shape)
'''
f_stage = pyedflib.EdfReader(os.path.join(filedir,f_stage_name))
annotations = f_stage.readAnnotations()
number_of_annotations = f_stage.annotations_in_file
end_duration = int(annotations[0][number_of_annotations-1])+int(annotations[1][number_of_annotations-1])
stages = np.zeros(end_duration//30, dtype=int)
# print(number_of_annotations)
for i in range(number_of_annotations):
stages[int(annotations[0][i])//30:(int(annotations[0][i])+int(annotations[1][i]))//30] = stage_str2int(annotations[2][i])
f_signal = pyedflib.EdfReader(os.path.join(filedir,f_signal_name))
signals = f_signal.readSignal(0)
signals=trimdata(signals,3000)
signals = signals.reshape(-1,3000)
stages = stages[0:signals.shape[0]]
# #select sleep time
# signals = signals[np.clip(int(annotations[1][0])//30-60,0,9999999):int(annotations[0][number_of_annotations-2])//30+60]
# stages = stages[np.clip(int(annotations[1][0])//30-60,0,9999999):int(annotations[0][number_of_annotations-2])//30+60]
#del UND
stages_copy = stages.copy()
cnt = 0
for i in range(len(stages_copy)):
if stages_copy[i] == 5 :
signals = np.delete(signals,i-cnt,axis =0)
stages = np.delete(stages,i-cnt,axis =0)
cnt += 1
'''
return signals.astype(np.int16),stages.astype(np.int16)
......@@ -168,7 +140,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = '
for i,filename in enumerate(filenames[:num],0):
try:
signal,stage = loaddata(os.path.join(filedir,filename),signal_name,BID)
signal,stage = loaddata(os.path.join(filedir,filename),signal_name,BID = None)
if i == 0:
signals =signal.copy()
stages =stage.copy()
......@@ -187,7 +159,7 @@ def loaddataset(opt,filedir,dataset_name = 'CinC_Challenge_2018',signal_name = '
cnt = 0
for filename in filenames:
if 'PSG' in filename:
signal,stage = loaddata_sleep_edf(opt,filedir,filename[2:6],signal_name = 'EEG Fpz-Cz')
signal,stage = loaddata_sleep_edf(opt,filedir,filename[2:6],signal_name = signal_name)
if cnt == 0:
signals =signal.copy()
stages =stage.copy()
......
......@@ -2,6 +2,9 @@ import argparse
import os
import numpy as np
import torch
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 200 --model_name resnet18 --batchsize 32 --epochs 10 --fold_num 5 --pretrained
#python3 train_new.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name LSTM --batchsize 32 --network_save_freq 100 --epochs 10
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name resnet18 --batchsize 32
#filedir = '/media/hypo/Hypo/physionet_org_train'
# filedir ='E:\physionet_org_train'
#python3 train.py --dataset_name sleep-edf --model_name resnet50 --batchsize 4 --epochs 50 --pretrained
......@@ -13,17 +16,19 @@ class Options():
def initialize(self):
self.parser.add_argument('--no_cuda', action='store_true', help='if input, do not use gpu')
self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
self.parser.add_argument('--fold_num', type=int, default=5,help='k-fold')
self.parser.add_argument('--batchsize', type=int, default=16,help='batchsize')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
help='your dataset path')
self.parser.add_argument('--dataset_name', type=str, default='sleep-edf',help='Choose dataset')
self.parser.add_argument('--dataset_name', type=str, default='sleep-edf',help='Choose dataset sleep-edf|sleep-edf|CinC_Challenge_2018|')
self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train')
self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1|EEG Fpz-Cz')
self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load')
self.parser.add_argument('--model_name', type=str, default='resnet18',help='Choose model')
self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
self.parser.add_argument('--epochs', type=int, default=50,help='end epoch')
self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal')
self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
......
import numpy as np
import matplotlib.pyplot as plt
def writelog(log):
    """Append one message line to the run log file './log'.

    Fix: the original opened the file and never closed it, leaking a
    file handle per call; a context manager guarantees the handle is
    closed even if the write raises.
    """
    with open('./log', 'a+') as f:
        f.write(log + '\n')
import util
def stage(stages):
#N3->0 N2->1 N1->2 REM->3 W->4
......@@ -12,6 +8,8 @@ def stage(stages):
for i in range(len(stages)):
stage_cnt[stages[i]] += 1
stage_cnt_per = stage_cnt/len(stages)
util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt))
print('statistics of dataset [S3 S2 S1 R W]:\n',stage_cnt,'\n',stage_cnt_per)
return stage_cnt,stage_cnt_per
def result(mat):
......@@ -20,7 +18,10 @@ def result(mat):
sub_recall = np.zeros(wide)
err = 0
for i in range(wide):
sub_recall[i]=mat[i,i]/np.sum(mat[i])
if np.sum(mat[i]) == 0 :
sub_recall[i] = 0
else:
sub_recall[i]=mat[i,i]/np.sum(mat[i])
err += mat[i,i]
sub_acc[i] = (np.sum(mat)-((np.sum(mat[i])+np.sum(mat[:,i]))-2*mat[i,i]))/np.sum(mat)
avg_recall = np.mean(sub_recall)
......@@ -28,6 +29,14 @@ def result(mat):
err = 1-err/np.sum(mat)
return avg_recall,avg_acc,err
def stagefrommat(mat):
    """Log per-stage sample totals taken from a confusion matrix.

    Each row of the square matrix corresponds to one sleep stage in the
    order [S3 S2 S1 R W]; the row sum is that stage's sample count.
    """
    # Row sums in one vectorized call instead of an index loop.
    stage_totals = np.asarray(mat).sum(axis=1).astype('int')
    util.writelog('statistics of dataset [S3 S2 S1 R W]:\n' + str(stage_totals), True)
def show(plot_result,epoch):
train_recall = np.array(plot_result['train'])
......
......@@ -17,59 +17,56 @@ warnings.filterwarnings("ignore")
opt = Options().getparse()
localtime = time.asctime(time.localtime(time.time()))
statistics.writelog('\n\n'+str(localtime)+'\n'+str(opt))
util.writelog('\n\n'+str(localtime)+'\n'+str(opt))
t1 = time.time()
signals,stages = dataloader.loaddataset(opt,opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=True,BID='median')
stage_cnt_per = statistics.stage(stages)[1]
print('stage_cnt_per:',stage_cnt_per,'\nlength of dataset:',len(stages))
signals_train,stages_train,signals_eval,stages_eval, = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
batch_length = len(signals_train)+len(signals_eval)
signals,stages = dataloader.loaddataset(opt,opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=True,BID=None)
stage_cnt,stage_cnt_per = statistics.stage(stages)
signals,stages = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
batch_length = len(stages)
print('length of batch:',batch_length)
show_freq = int(len(signals_train)/5)
train_sequences,test_sequences = transformer.k_fold_generator(batch_length,opt.fold_num)
show_freq = int(len(train_sequences[0])/5)
util.show_menory()
t2 = time.time()
print('load data cost time:',t2-t1)
print('load data cost time: %.2f'% (t2-t1),'s')
net=models.CreatNet(opt.model_name)
# print(net)
if opt.pretrained:
net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth'))
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'.pth')
weight = np.array([1,1,1,1,1])
if opt.weight_mod == 'avg_best':
weight = np.log(1/stage_cnt_per)
weight[2] = weight[2]+1
print(weight)
weight = np.clip(weight,1,5)
print('Loss_weight:',weight)
weight = torch.from_numpy(weight).float()
# print(net)
if not opt.no_cuda:
net.cuda()
weight = weight.cuda()
# cudnn.benchmark = True
# print(weight)
# time.sleep(2000)
if not opt.no_cudnn:
cudnn.benchmark = True
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
criterion = nn.CrossEntropyLoss(weight)
def evalnet(net,signals,stages,epoch,plot_result={},mode = 'part'):
net.eval()
def evalnet(net,signals,stages,sequences,epoch,plot_result={},mode = 'part'):
# net.eval()
if mode =='part':
transformer.shuffledata(signals,stages)
signals=signals[0:int(len(stages)/2)]
stages=stages[0:int(len(stages)/2)]
confusion_mat = np.zeros((5,5), dtype=int)
for i, (signal, stage) in enumerate(zip(signals,stages), 1):
for i, sequence in enumerate(sequences, 1):
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =True)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =True)
signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
with torch.no_grad():
out = net(signal)
pred = torch.max(out, 1)[1]
pred=pred.data.cpu().numpy()
......@@ -83,56 +80,71 @@ def evalnet(net,signals,stages,epoch,plot_result={},mode = 'part'):
plot_result['test'].append(recall)
heatmap.draw(confusion_mat,name = 'test')
print('test avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error)
statistics.writelog('epoch:'+str(epoch)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)))
if epoch%1==0:
statistics.writelog('confusion_mat:\n'+str(confusion_mat))
# torch.cuda.empty_cache()
return plot_result
#util.writelog('epoch:'+str(epoch)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)))
return plot_result,confusion_mat
plot_result={}
plot_result['train']=[0]
plot_result['test']=[0]
print('begin to train ...')
for epoch in range(opt.epochs):
t1 = time.time()
confusion_mat = np.zeros((5,5), dtype=int)
print('epoch:',epoch+1)
net.train()
for i, (signal, stage) in enumerate(zip(signals_train,stages_train), 1):
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =False)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
pred = torch.max(out, 1)[1]
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred=pred.data.cpu().numpy()
stage=stage.data.cpu().numpy()
for x in range(len(pred)):
confusion_mat[stage[x]][pred[x]] += 1
if i%show_freq==0:
# torch.cuda.empty_cache()
plot_result['train'].append(statistics.result(confusion_mat)[0])
heatmap.draw(confusion_mat,name = 'train')
# plot_result=evalnet(net,signals_eval,stages_eval,plot_result,show_freq,mode = 'part')
statistics.show(plot_result,epoch+i/(batch_length*0.8))
confusion_mat[:]=0
# net.train()
# torch.cuda.empty_cache()
evalnet(net,signals_eval,stages_eval,epoch+1,plot_result,mode = 'all')
# scheduler.step()
if (epoch+1)%opt.network_save_freq == 0:
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
print('network saved.')
if not opt.no_cuda:
net.cuda()
t2=time.time()
print('cost time: %.2f' % (t2-t1))
\ No newline at end of file
final_confusion_mat = np.zeros((5,5), dtype=int)
for fold in range(opt.fold_num):
net.load_state_dict(torch.load('./checkpoints/'+opt.model_name+'.pth'))
if opt.pretrained:
net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth'))
if not opt.no_cuda:
net.cuda()
plot_result={'train':[0],'test':[0]}
confusion_mats = []
for epoch in range(opt.epochs):
t1 = time.time()
confusion_mat = np.zeros((5,5), dtype=int)
print('fold:',fold+1,'epoch:',epoch+1)
net.train()
for i, sequence in enumerate(train_sequences[fold], 1):
signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =False)
signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
pred = torch.max(out, 1)[1]
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred=pred.data.cpu().numpy()
stage=stage.data.cpu().numpy()
for x in range(len(pred)):
confusion_mat[stage[x]][pred[x]] += 1
if i%show_freq==0:
plot_result['train'].append(statistics.result(confusion_mat)[0])
heatmap.draw(confusion_mat,name = 'train')
# plot_result=evalnet(net,signals_eval,stages_eval,plot_result,show_freq,mode = 'part')
statistics.show(plot_result,epoch+i/(batch_length*0.8))
confusion_mat[:]=0
plot_result,confusion_mat = evalnet(net,signals,stages,test_sequences[fold],epoch+1,plot_result,mode = 'all')
confusion_mats.append(confusion_mat)
# scheduler.step()
if (epoch+1)%opt.network_save_freq == 0:
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
print('network saved.')
if not opt.no_cuda:
net.cuda()
t2=time.time()
print('cost time: %.2f' % (t2-t1),'s')
pos = plot_result['test'].index(max(plot_result['test']))-1
final_confusion_mat = final_confusion_mat+confusion_mats[pos]
recall,acc,error = statistics.result(confusion_mats[pos])
print('\nfold:',fold+1,'finished',' avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error,'\n')
util.writelog('fold:'+str(fold+1)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)))
util.writelog('confusion_mat:\n'+str(confusion_mat))
recall,acc,error = statistics.result(final_confusion_mat)
#print('all finished!\n',final_confusion_mat)
#print('avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error)
util.writelog('final:'+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4)),True)
util.writelog('confusion_mat:\n'+str(confusion_mat),True)
statistics.stagefrommat(confusion_mat)
heatmap.draw(final_confusion_mat,name = 'final_test')
\ No newline at end of file
......@@ -15,17 +15,39 @@ def shuffledata(data,target):
np.random.shuffle(target)
# return data,target
def batch_generator(data, target, batchsize, shuffle=True):
    """Group samples into fixed-size batches.

    Optionally shuffles data/target in unison (via the module's
    shuffledata), trims both so their length is a multiple of
    batchsize (via trimdata), and reshapes into
    (num_batches, batchsize, sample_len) / (num_batches, batchsize).

    Generalization: the per-sample length is read from the data itself
    instead of the previously hard-coded 3000, so epochs of any length
    work; behavior is unchanged for the existing (N, 3000) inputs.
    """
    if shuffle:
        shuffledata(data, target)
    data = trimdata(data, batchsize)
    target = trimdata(target, batchsize)
    # -1 lets numpy infer the batch count; sample length comes from the data.
    data = data.reshape(-1, batchsize, data.shape[-1])
    target = target.reshape(-1, batchsize)
    return data, target
def k_fold_generator(length, fold_num):
    """Split batch indices 0..length-1 into k contiguous train/test folds.

    Args:
        length: total number of batches.
        fold_num: number of folds (k).

    Returns:
        (train_sequence, test_sequence): int arrays of shape
        (fold_num, train_length) and (fold_num, test_length); fold i
        tests on the i-th contiguous slice and trains on the remainder
        (truncated to train_length when length is not divisible by k).
    """
    # np.arange is the exact, idiomatic way to build an integer index
    # sequence; linspace goes through floats before the int cast.
    sequence = np.arange(length, dtype='int')
    # Keep the original arithmetic so fold sizes match previous behavior
    # even when fold_num does not divide length.
    train_length = int(length / fold_num * (fold_num - 1))
    test_length = int(length / fold_num)
    train_sequence = np.zeros((fold_num, train_length), dtype='int')
    test_sequence = np.zeros((fold_num, test_length), dtype='int')
    for i in range(fold_num):
        test_sequence[i] = sequence[test_length * i:test_length * (i + 1)][:test_length]
        train_sequence[i] = np.concatenate(
            (sequence[0:test_length * i], sequence[test_length * (i + 1):]),
            axis=0)[:train_length]
    return train_sequence, test_sequence
'''
def batch_generator(data,target,batchsize,shuffle = True):
data = trimdata(data,batchsize)
target = trimdata(target,batchsize)
data = data.reshape(-1,batchsize,3000)
target = target.reshape(-1,batchsize)
return data[0:int(0.8*len(target))],target[0:int(0.8*len(target))],data[int(0.8*len(target)):],target[int(0.8*len(target)):]
signals_train,stages_train,signals_eval,stages_eval = data[0:int(0.8*len(target))],target[0:int(0.8*len(target))],data[int(0.8*len(target)):],target[int(0.8*len(target)):]
if shuffle:
shuffledata(signals_train,stages_train)
shuffledata(signals_eval,stages_eval)
return signals_train,stages_train,signals_eval,stages_eval
'''
def Normalize(data, maxmin, avg, sigma):
    """Standardize a signal: clip to [-maxmin, maxmin], then shift by avg and scale by sigma."""
    bounded = np.clip(data, -maxmin, maxmin)
    centered = bounded - avg
    return centered / sigma
......@@ -67,14 +89,14 @@ def random_transform_2d(img,finesize = (224,122),test_flag = True):
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
else:
#random crop
h_move = int(5*random.random()) #do not loss low freq signal infos
h_move = int(10*random.random()) #do not loss low freq signal infos
w_move = int((w-finesize[1])*random.random())
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
#random flip
if random.random()<0.5:
result = result[:,::-1]
#random amp
result = result*random.uniform(0.95,1.05)+random.uniform(-0.02,0.02)
result = result*random.uniform(0.9,1.1)+random.uniform(-0.05,0.05)
return result
......@@ -127,7 +149,7 @@ def ToInputShape(data,net_name,norm=True,test_flag = False):
if norm:
#sleep_def : std,mean,median = 0.4157 0.3688 0.2473
#challge 2018 : std,mean,median,max= 0.2972 0.3008 0.2006 2.0830
result=Normalize(result,3,0.3,1)
result=Normalize(result,2,0.3,1)
result = result.reshape(batchsize,1,224,122)
# print(result.shape)
......
......@@ -3,4 +3,10 @@ import memory_profiler
def show_menory():
    """Print and return the current process memory usage in MB.

    NOTE(review): relies on the third-party `memory_profiler` package
    (imported at module level). The name "menory" is an existing typo
    kept as-is because callers elsewhere use this spelling.
    """
    # memory_usage()[0] is the current RSS of this process, in MiB.
    usage=int(memory_profiler.memory_usage()[0])
    print('menory usage:',usage,'MB')
    return usage
\ No newline at end of file
return usage
def writelog(log, printflag=False):
    """Append one message line to the run log file './log'.

    Args:
        log: message to record (a newline is appended).
        printflag: if True, also echo the message to stdout.

    Fix: the original opened the file and never closed it, leaking a
    file handle per call; a context manager guarantees closure.
    """
    with open('./log', 'a+') as f:
        f.write(log + '\n')
    if printflag:
        print(log)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册