提交 c7de34fb 编写于 作者: H hypox64

Update to Pytorch 1.0

上级 9f82c3a6
......@@ -18,11 +18,11 @@
DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用图像分类网络进行分类.<br>
* EEG频谱图<br>
这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM<br>
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_wake.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N1.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N2.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N3.png)
这里展示5个睡眠阶段对应的频谱图,它们依次是Wake, Stage 1, Stage 2, Stage 3, REM<br>
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_Wake.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_Stage1.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_Stage2.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_Stage3.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)
* 关于代码<br>
......@@ -51,4 +51,5 @@
* CinC Challenge 2018<br>
## 心路历程
* 2019/04/01 DFCNN的运算量也忒大了,提升还不明显,还容易过拟合......真是食之无味,弃之可惜...
\ No newline at end of file
* 2019/04/01 DFCNN的运算量也忒大了,提升还不明显,还容易过拟合......真是食之无味,弃之可惜...
* 2019/04/03 花了一天更新到pytorch 1.0, 然后尝试了一下缩小输入频谱图的尺寸从而减小运算量...
......@@ -5,7 +5,7 @@ import os
import time
import torch
import random
import DSP
import dsp
# import pyedflib
import mne
......@@ -45,7 +45,7 @@ def loaddata(dirpath,signal_name,BID = 'median',filter = True):
#load
signals = loadsignals(dirpath,signal_name)
if filter:
signals = DSP.BPF(signals,200,0.2,50,mod = 'fir')
signals = dsp.BPF(signals,200,0.2,50,mod = 'fir')
stages = loadstages(dirpath)
#resample
signals = reducesample(signals,2)
......
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, bias=False)),
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, bias=False)),
self.drop_rate = drop_rate
def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
self.add_module('denselayer%d' % (i + 1), layer)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
r"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
growth_rate (int) - how many filters to add each layer (`k` in paper)
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
# Each denseblock
num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
def forward(self, x):
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
out = self.classifier(out)
return out
def _load_state_dict(model, model_url):
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_url)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet121'])
return model
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet169'])
return model
def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet201'])
return model
def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet161'])
return model
import scipy.signal
import scipy.fftpack as fftpack
import numpy as np
b1 = scipy.signal.firwin(31, [0.5, 4], pass_zero=False,fs=100)
b2 = scipy.signal.firwin(31, [4,8], pass_zero=False,fs=100)
b3 = scipy.signal.firwin(31, [8,12], pass_zero=False,fs=100)
b4 = scipy.signal.firwin(31, [12,16], pass_zero=False,fs=100)
b5 = scipy.signal.firwin(31, [16,45], pass_zero=False,fs=100)
def getfir_b(fc1,fc2,fs):
if fc1==0.5 and fc2==4 and fs==100:
b=b1
elif fc1==4 and fc2==8 and fs==100:
b=b2
elif fc1==8 and fc2==12 and fs==100:
b=b3
elif fc1==12 and fc2==16 and fs==100:
b=b4
elif fc1==16 and fc2==45 and fs==100:
b=b5
else:
b=scipy.signal.firwin(51, [fc1, fc2], pass_zero=False,fs=fs)
return b
def BPF(signal,fs,fc1,fc2,mod = 'fir'):
if mod == 'fft':
length=len(signal)#get N
k1=int(fc1*length/fs)#get k1=Nw1/fs
k2=int(fc2*length/fs)#get k1=Nw1/fs
#FFT
signal_fft=fftpack.fft(signal)
#Frequency truncation
signal_fft[0:k1]=0+0j
signal_fft[k2:length-k2]=0+0j
signal_fft[length-k1:length]=0+0j
#IFFT
signal_ifft=fftpack.ifft(signal_fft)
result = signal_ifft.real
else:
b=getfir_b(fc1,fc2,fs)
result = scipy.signal.lfilter(b, 1, signal)
return result
def getfeature(signal,mod = 'fir',ch_num = 5):
result=[]
signal =signal - np.mean(signal)
eeg=signal
beta=BPF(eeg,100,16,45,mod) # β
theta=BPF(eeg,100,4,8,mod)
sigma=BPF(eeg,100,12,16,mod) #σ spindle
alpha=BPF(eeg,100,8,12,mod)
delta=BPF(eeg,100,0.5,4,mod)
result.append(beta)
result.append(theta)
result.append(sigma)
result.append(alpha)
result.append(delta)
if ch_num == 6:
fft = abs(fftpack.fft(eeg))
fft = fft - np.median(fft)
result.append(fft)
result=np.array(result)
result=result.reshape(ch_num*len(signal),)
return result
# def signal2spectrum(signal):
# spectrum =np.zeros((224,224))
# spectrum_y = np.zeros((224))
# for i in range(224):
# signal_window=signal[i*9:i*9+896]
# signal_window_fft=np.abs(np.fft.fft(signal_window))[0:448]
# spectrum_y[0:112]=signal_window_fft[0:112]
# spectrum_y[112:168]=signal_window_fft[112:224][::2]
# spectrum_y[168:224]=signal_window_fft[224:448][::4]
# spectrum[:,i] = spectrum_y
# # spectrum = np.log(spectrum+1)/11
# return spectrum
# def signal2spectrum(data):
# # window : ('tukey',0.5) hann
# zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-12, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
# zxx =np.abs(zxx)[:512]
# spectrum=np.zeros((256,251))
# spectrum[0:128]=zxx[0:128]
# spectrum[128:192]=zxx[128:256][::2]
# spectrum[192:256]=zxx[256:512][::4]
# spectrum = np.log(spectrum+1)
# return spectrum
def signal2spectrum(data):
# window : ('tukey',0.5) hann
zxx = scipy.signal.stft(data, fs=100, window=('tukey',0.5), nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
zxx =np.abs(zxx)[:512]
spectrum=np.zeros((256,126))
spectrum[0:128]=zxx[0:128]
spectrum[128:192]=zxx[128:256][::2]
spectrum[192:256]=zxx[256:512][::4]
spectrum = np.log(spectrum+1)
return spectrum
\ No newline at end of file
......@@ -66,3 +66,12 @@ confusion_mat:
[ 54 671 3371 399 468]
[ 29 379 1454 4759 151]
[ 40 50 1272 124 12881]]
resnet18_1d
avg_recall:0.7846 avg_acc:0.9203 error:0.1993
confusion_mat:
[[ 3263 576 14 0 0]
[ 1100 14074 1728 917 45]
[ 30 876 2958 821 307]
[ 13 396 652 5728 68]
[ 13 77 1585 312 12255]]
image/spectrum_REM.png

24.8 KB | W: | H:

image/spectrum_REM.png

21.2 KB | W: | H:

image/spectrum_REM.png
image/spectrum_REM.png
image/spectrum_REM.png
image/spectrum_REM.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -4,6 +4,7 @@ from torch import nn, optim
import torch.nn.functional as F
import torchvision
from collections import OrderedDict
import densenet
def CreatNet(name):
if name =='LSTM':
......@@ -30,16 +31,16 @@ def CreatNet(name):
elif 'densenet' in name:
if name =='densenet121':
net = torchvision.models.densenet121(pretrained=False)
net = densenet.densenet121(pretrained=False,num_classes=5)
elif name == 'densenet201':
net = torchvision.models.densenet201(pretrained=False)
net.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(64)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
net.classifier = nn.Linear(4096, 5)
net = densenet.densenet201(pretrained=False,num_classes=5)
# net.features = nn.Sequential(OrderedDict([
# ('conv0', nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)),
# ('norm0', nn.BatchNorm2d(64)),
# ('relu0', nn.ReLU(inplace=True)),
# ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
# ]))
# net.classifier = nn.Linear(64, 5)
return net
......
import data
import numpy as np
import time
import util
import os
import time
import data
import transformer
import dataloader
import models
import torch
......@@ -24,7 +23,7 @@ t1 = time.time()
signals,stages = dataloader.loaddataset(opt,opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,shuffle=True,BID='median')
stage_cnt_per = statistics.stage(stages)[1]
print('stage_cnt_per:',stage_cnt_per,'\nlength of dataset:',len(stages))
signals_train,stages_train,signals_eval,stages_eval, = data.batch_generator(signals,stages,opt.batchsize,shuffle = True)
signals_train,stages_train,signals_eval,stages_eval, = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
batch_length = len(signals_train)+len(signals_eval)
print('length of batch:',batch_length)
......@@ -60,15 +59,15 @@ criterion = nn.CrossEntropyLoss(weight)
def evalnet(net,signals,stages,epoch,plot_result={},mode = 'part'):
net.eval()
if mode =='part':
data.shuffledata(signals,stages)
transformer.shuffledata(signals,stages)
signals=signals[0:int(len(stages)/2)]
stages=stages[0:int(len(stages)/2)]
confusion_mat = np.zeros((5,5), dtype=int)
for i, (signal, stage) in enumerate(zip(signals,stages), 1):
signal=data.ToInputShape(signal,opt.model_name,test_flag =True)
signal,stage = data.ToTensor(signal,stage,no_cuda =opt.no_cuda)
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =True)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
pred = torch.max(out, 1)[1]
......@@ -102,8 +101,8 @@ for epoch in range(opt.epochs):
net.train()
for i, (signal, stage) in enumerate(zip(signals_train,stages_train), 1):
signal=data.ToInputShape(signal,opt.model_name,test_flag =False)
signal,stage = data.ToTensor(signal,stage,no_cuda =opt.no_cuda)
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =False)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
......
import numpy as np
import os
import torch
import random
import dsp
#python3 train.py --dataset_name sleep-edfx --model_name resnet18_1d --batchsize 16 --epochs 50 --lr 0.001 --select_sleep_time --sample_num 197
def trimdata(data,num):
return data[:num*int(len(data)/num)]
def shuffledata(data,target):
state = np.random.get_state()
np.random.shuffle(data)
np.random.set_state(state)
np.random.shuffle(target)
# return data,target
def batch_generator(data,target,batchsize,shuffle = True):
if shuffle:
shuffledata(data,target)
data = trimdata(data,batchsize)
target = trimdata(target,batchsize)
data = data.reshape(-1,batchsize,3000)
target = target.reshape(-1,batchsize)
return data[0:int(0.8*len(target))],target[0:int(0.8*len(target))],data[int(0.8*len(target)):],target[int(0.8*len(target)):]
def Normalize(data,maxmin,avg,sigma):
data = np.clip(data, -maxmin, maxmin)
return (data-avg)/sigma
def ToTensor(data,target,no_cuda = False):
data = torch.from_numpy(data).float()
target = torch.from_numpy(target).long()
if not no_cuda:
data = data.cuda()
target = target.cuda()
return data,target
def random_transform_1d(data,finesize,test_flag):
length = len(data)
if test_flag:
move = int((length-finesize)*0.5)
result = data[move:move+finesize]
else:
#random crop
move = int((length-finesize)*random.random())
result = data[move:move+finesize]
#random flip
if random.random()<0.5:
result = result[::-1]
#random amp
result = result*random.uniform(0.95,1.05)
return result
def random_transform_2d(img,finesize = (224,122),test_flag = True):
h,w = img.shape[:2]
if test_flag:
h_move = 2
w_move = int((w-finesize[1])*0.5)
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
else:
#random crop
h_move = int(5*random.random()) #do not loss low freq signal infos
w_move = int((w-finesize[1])*random.random())
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
#random flip
if random.random()<0.5:
result = result[:,::-1]
#random amp
result = result*random.uniform(0.95,1.05)+random.uniform(-0.02,0.02)
return result
# def random_transform_2d(img,finesize,test_flag):
# h,w = img.shape[:2]
# if test_flag:
# h_move = 2
# w_move = int((w-finesize)*0.5)
# result = img[h_move:h_move+finesize,w_move:w_move+finesize]
# else:
# #random crop
# h_move = int(5*random.random()) #do not loss low freq signal infos
# w_move = int((w-finesize)*random.random())
# result = img[h_move:h_move+finesize,w_move:w_move+finesize]
# #random flip
# if random.random()<0.5:
# result = result[:,::-1]
# #random amp
# result = result*random.uniform(0.95,1.05)+random.uniform(-0.02,0.02)
# return result
def ToInputShape(data,net_name,norm=True,test_flag = False):
data = data.astype(np.float32)
batchsize=data.shape[0]
if net_name=='LSTM':
result =[]
for i in range(0,batchsize):
randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag)
result.append(dsp.getfeature(randomdata))
result = np.array(result).reshape(batchsize,2700*5)
elif net_name=='CNN' or net_name=='resnet18_1d':
result =[]
for i in range(0,batchsize):
randomdata=random_transform_1d(data[i],finesize = 2700,test_flag=test_flag)
# result.append(dsp.getfeature(randomdata,ch_num = 6))
result.append(randomdata)
result = np.array(result)
if norm:
result = Normalize(result,maxmin = 1000,avg=0,sigma=1000)
result = result.reshape(batchsize,1,2700)
elif net_name in ['dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
result =[]
for i in range(0,batchsize):
spectrum = dsp.signal2spectrum(data[i])
spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
result.append(spectrum)
result = np.array(result)
if norm:
#sleep_def : std,mean,median = 0.4157 0.3688 0.2473
#challge 2018 : std,mean,median,max= 0.2972 0.3008 0.2006 2.0830
result=Normalize(result,3,0.3,1)
result = result.reshape(batchsize,1,224,122)
# print(result.shape)
return result
# datasetpath='/media/hypo/Hypo/training'
# dir = '/media/hypo/Hypo/training/tr03-0005'
def main():
dir = '/media/hypo/Hypo/physionet_org_train/tr03-0052'
t1=time.time()
stages=loadstages(dir)
for i in range(len(stages)):
if stages[i]!=5:
print(i+1)
break
print(stages.shape)
t2=time.time()
print(t2-t1)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册