diff --git a/README.md b/README.md index 7826fabf120a423568fbb11e377f0ad8019ada12..7bdae109e3ddd40c8545419454587a35970ff052 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,6 @@ | Network | Label average recall | Label average accuracy | error rate | | :------------- | :------------------- | ---------------------- | ---------- | - | lstm | 0.7257 | 0.9266 | 0.1836 | + | lstm | 0.8342 | 0.9611 | 0.0974 | | resnet18_1d | 0.8434 | 0.9627 | 0.093 | | DFCNN+resnet18 | 0.8567 | 0.9663 | 0.0842 | diff --git a/data.py b/data.py index 3da21d857f3264fe49353f604bf68fb79bedd2be..6b8f15b8b45c43998b7aa42a1403a101de6555a5 100644 --- a/data.py +++ b/data.py @@ -49,8 +49,8 @@ def random_transform_1d(data,finesize,test_flag): result = data[move:move+finesize] #random flip - if random.random()<0.5: - result = result[::-1] + # if random.random()<0.5: + # result = result[::-1] #random amp result = result*random.uniform(0.95,1.05) @@ -106,7 +106,7 @@ def ToInputShape(data,net_name,norm=True,test_flag = False): if norm: #sleep_def : std,mean,median = 0.4157 0.3688 0.2473 #challge 2018 : std,mean,median,max= 0.2972 0.3008 0.2006 2.0830 - result=Normalize(result,2,0.3,1) + result=Normalize(result,3,0.3,1) result = result.reshape(batchsize,1,224,224) # print(result.shape) diff --git a/dataloader.py b/dataloader.py index 844eb469460f0e21809d126f4ae74ede28fe252e..d11c30200da0a57e4e65f80449aec64816ddb2d1 100644 --- a/dataloader.py +++ b/dataloader.py @@ -101,7 +101,7 @@ def loaddata_sleep_edf(filedir,filenum,signal_name,BID = 'median',filter = True) signals = signals[events[0][0]:events[-1][0]] events = np.array(events) signals = signals.reshape(-1,3000) - # signals = signals*13/np.median(np.abs(signals)) + signals = signals*13/np.median(np.abs(signals)) stages = events[:,2] stages = stages[:len(signals)] @@ -170,11 +170,13 @@ def loaddataset(filedir,dataset_name = 'CinC_Challenge_2018',signal_name = 'C4-M except Exception as e: print(filename,e) elif dataset_name in ['sleep-edfx','sleep-edf']: - cnt = 0 + if num > 197: + num = 197 if dataset_name == 'sleep-edf': filenames = ['SC4002E0-PSG.edf','SC4012E0-PSG.edf','SC4102E0-PSG.edf','SC4112E0-PSG.edf', 'ST7022J0-PSG.edf','ST7052J0-PSG.edf','ST7121J0-PSG.edf','ST7132J0-PSG.edf'] - + + cnt = 0 for filename in filenames: if 'PSG' in filename: signal,stage = loaddata_sleep_edf(filedir,filename[2:6],signal_name = 'EEG Fpz-Cz') diff --git a/image/LSTM_heatmap_sleep-edf.png b/image/LSTM_heatmap_sleep-edf.png new file mode 100644 index 0000000000000000000000000000000000000000..27c9edab372fab1d94b410c0cc1eb2c3d2509bfe Binary files /dev/null and b/image/LSTM_heatmap_sleep-edf.png differ diff --git a/image/LSTM_running_recall_sleep-edf.png b/image/LSTM_running_recall_sleep-edf.png new file mode 100644 index 0000000000000000000000000000000000000000..d0d20a5580a1b43dbd8697b275a5558147230114 Binary files /dev/null and b/image/LSTM_running_recall_sleep-edf.png differ diff --git a/image/confusion_mat b/image/confusion_mat index 89594b5c5d187d08806050f9cfff92664d4b8203..cafb090b974c3d8da83a4a3baa3b3603ed70752f 100644 --- a/image/confusion_mat +++ b/image/confusion_mat @@ -8,8 +8,20 @@ True S3 REM WAKE -DFresnet18_heatmap_sleep-edf -epoch:10 test avg_recall:0.8567 avg_acc:0.9663 error:0.0842 +-------------------sleep-edf------------------- + +LSTM +avg_recall:0.8342 avg_acc:0.9611 error:0.0974 +confusion_mat: +[[ 232 34 0 1 2] + [ 41 609 17 25 2] + [ 1 9 77 26 3] + [ 1 28 32 268 1] + [ 3 1 51 18 1558]] + + +DFresnet18 +avg_recall:0.8567 avg_acc:0.9663 error:0.0842 confusion_mat: [[ 238 9 2 1 0] [ 59 587 29 41 3] @@ -18,8 +30,7 @@ confusion_mat: [ 1 1 27 7 1621]] -resnet18_1d_heatmap_sleep-edf -epoch:10 +resnet18_1d_sleep-edf avg_recall:0.8434 avg_acc:0.9627 error:0.0930 confusion_mat: [[ 225 37 1 0 1] @@ -27,3 +38,5 @@ confusion_mat: [ 0 4 85 28 1] [ 0 19 43 261 3] [ 1 3 44 9 1533]] + +-------------------sleep-edfx------------------- diff --git a/models.py b/models.py index 10149f358e1f0ce0e6a7a08011fe36f77cc2682b..e3dac8a7e158b72ee7edfd2421f92933df14cfdd 100644 --- a/models.py +++ b/models.py @@ -15,12 +15,15 @@ def CreatNet(name): elif name in ['resnet101','resnet50','resnet18']: if name =='resnet101': net = torchvision.models.resnet101(pretrained=False) + net.fc = nn.Linear(2048, 5) elif name =='resnet50': net = torchvision.models.resnet50(pretrained=False) + net.fc = nn.Linear(2048, 5) elif name =='resnet18': net = torchvision.models.resnet18(pretrained=False) + net.fc = nn.Linear(512, 5) net.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False) - net.fc = nn.Linear(512, 5) + return net elif 'densenet' in name: @@ -56,7 +59,7 @@ class LSTM(nn.Module): self.out = nn.Linear(Hidden_size, CLASS_NUM) def forward(self, x): - x=self.bn(x) + # x=self.bn(x) x=x.view(-1, self.TIME_STEP, self.INPUT_SIZE) r_out, (h_n, h_c) = self.lstm(x, None) # None represents zero initial hidden state x=r_out[:, -1, :] diff --git a/options.py b/options.py index 92ec6e6d2d37639a0c3c07a29fe17f917a925f9a..f208dc44a605f6809bea6d1d4589c6aa762c689e 100644 --- a/options.py +++ b/options.py @@ -12,10 +12,11 @@ class Options(): self.initialized = False def initialize(self): - self.parser.add_argument('--no_cuda', action='store_true', help='if true, do not use gpu') + self.parser.add_argument('--no_cuda', action='store_true', help='if input, do not use gpu') + self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models') self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') self.parser.add_argument('--batchsize', type=int, default=16,help='batchsize') - self.parser.add_argument('--dataset_dir', type=str, default='./sleep-edfx/sleep-cassette', + self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/', help='your dataset path') self.parser.add_argument('--dataset_name', type=str, default='sleep-edf',help='Choose dataset') self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1|EEG Fpz-Cz') @@ -23,6 +24,8 @@ class Options(): self.parser.add_argument('--model_name', type=str, default='resnet18',help='Choose model') self.parser.add_argument('--epochs', type=int, default=20,help='end epoch') self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal') + self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network') + self.initialized = True @@ -30,6 +33,7 @@ class Options(): if not self.initialized: self.initialize() self.opt = self.parser.parse_args() + if self.opt.dataset_name == 'sleep-edf': self.opt.sample_num = 8 @@ -40,9 +44,9 @@ class Options(): if self.opt.dataset_name == 'CinC_Challenge_2018': weight = np.log(1/np.array([0.15,0.3,0.08,0.13,0.18])) elif self.opt.dataset_name == 'sleep-edfx': - weight = np.log(1/np.array([0.08,0.30,0.05,0.15,0.35])) + weight = np.log(1/np.array([0.04,0.20,0.04,0.08,0.63])) elif self.opt.dataset_name == 'sleep-edf': - weight = np.log(1/np.array([0.08,0.23,0.02,0.10,0.53])) + weight = np.log(1/np.array([0.08,0.23,0.01,0.10,0.53])) self.opt.weight = weight diff --git a/train.py b/train.py index 7c55c4ccb2d48b34dd3233c6ca654b5b45f43c01..ccc4ae9d03a206ca80806472fc7765f20d922bec 100644 --- a/train.py +++ b/train.py @@ -18,7 +18,7 @@ warnings.filterwarnings("ignore") opt = Options().getparse() localtime = time.asctime(time.localtime(time.time())) -statistics.writelog('\n'+str(localtime)+'\n'+str(opt)) +statistics.writelog('\n\n'+str(localtime)+'\n'+str(opt)) t1 = time.time() signals,stages = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=True,BID='median') @@ -34,16 +34,20 @@ t2 = time.time() print('load data cost time:',t2-t1) net=models.CreatNet(opt.model_name) +# print(net) +if opt.pretrained: + net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.model_name+'.pth')) + weight = torch.from_numpy(opt.weight).float() # print(net) if not opt.no_cuda: net.cuda() weight = weight.cuda() cudnn.benchmark = True - +# print(weight) # time.sleep(2000) optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr) -scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) criterion = nn.CrossEntropyLoss(weight) @@ -76,7 +80,7 @@ def evalnet(net,signals,stages,epoch,plot_result={},mode = 'part'): heatmap.draw(confusion_mat,name = 'test') print('test avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error) statistics.writelog('epoch:'+str(epoch)+' test avg_recall:'+str(round(recall,4))+' avg_acc:'+str(round(acc,4))+' error:'+str(round(error,4))) - if epoch%5==0: + if epoch%1==0: statistics.writelog('confusion_mat:\n'+str(confusion_mat)) # torch.cuda.empty_cache() return plot_result @@ -120,5 +124,11 @@ for epoch in range(opt.epochs): evalnet(net,signals_eval,stages_eval,epoch+1,plot_result,mode = 'all') scheduler.step() + if (epoch+1)%opt.network_save_freq == 0: + torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth') + print('network saved.') + if not opt.no_cuda: + net.cuda() + t2=time.time() print('cost time: %.2f' % (t2-t1)) \ No newline at end of file