提交 21105adc 编写于 作者: H hypox64

Support 2d network

上级 7ecfcb78
......@@ -4,8 +4,8 @@ import time
import numpy as np
import torch
from torch import nn, optim
# from multiprocessing import Process, Queue
import torch.multiprocessing as mp
from multiprocessing import Process, Queue
# import torch.multiprocessing as mp
import warnings
warnings.filterwarnings("ignore")
......@@ -42,7 +42,7 @@ class Core(object):
self.test_flag = True
if printflag:
util.writelog('network:\n'+str(self.net),self.opt,True)
#util.writelog('network:\n'+str(self.net),self.opt,True)
show_paramsnumber(self.net,self.opt)
if self.opt.pretrained != '':
......@@ -76,23 +76,24 @@ class Core(object):
signal = transformer.ToInputShape(signal,self.opt,test_flag =self.test_flag)
self.queue.put([signal,label])
# def process_pool_init(self,signals,labels,sequences):
# self.queue = mp.Queue(self.opt.load_process*2)
# part_len = len(sequences)//self.opt.load_process//self.opt.batchsize*self.opt.batchsize
# for i in range(self.opt.load_process):
# if i == (self.opt.load_process -1):
# p = mp.Process(target=self.preprocessing,args=(signals,labels,sequences[i*part_len:]))
# else:
# p = mp.Process(target=self.preprocessing,args=(signals,labels,sequences[i*part_len:(i+1)*part_len]))
# p.daemon = True
# p.start()
def process_pool_init(self,signals,labels,sequences):
self.queue = mp.Queue()
p = mp.Process(target=self.preprocessing,args=(signals,labels,sequences))
def start_process(self,signals,labels,sequences):
p = Process(target=self.preprocessing,args=(signals,labels,sequences))
p.daemon = True
p.start()
def process_pool_init(self,signals,labels,sequences):
self.queue = Queue(self.opt.load_thread*2)
process_batch_num = len(sequences)//self.opt.batchsize//self.opt.load_thread
if process_batch_num == 0:
print('\033[1;33m'+'Warning: too much load thread'+'\033[0m')
self.start_process(signals,labels,sequences)
else:
for i in range(self.opt.load_thread):
if i != self.opt.load_thread-1:
self.start_process(signals,labels,sequences[i*self.opt.load_thread*self.opt.batchsize:(i+1)*self.opt.load_thread*self.opt.batchsize])
else:
self.start_process(signals,labels,sequences[i*self.opt.load_thread*self.opt.batchsize:])
def forward(self,signal,label,features,confusion_mat):
if self.opt.model_name == 'autoencoder':
out,feature = self.net(signal)
......
......@@ -11,7 +11,7 @@ def creatnet(opt):
net = autoencoder.Autoencoder(opt.input_nc, opt.feature, opt.label,opt.finesize)
#lstm
elif name =='lstm':
net = lstm.lstm(opt.input_size,opt.time_step,input_nc=opt.input_nc,num_classes=opt.label)
net = lstm.lstm(opt.lstm_inputsize,opt.lstm_timestep,input_nc=opt.input_nc,num_classes=opt.label)
#cnn
elif name == 'cnn_1d':
net = cnn_1d.cnn(opt.input_nc,num_classes=opt.label)
......@@ -27,30 +27,30 @@ def creatnet(opt):
net = multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
elif name == 'micro_multi_scale_resnet_1d':
net = micro_multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
elif name == 'multi_scale_resnet':
net = multi_scale_resnet.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
#---------------------------------2d---------------------------------
elif name == 'dfcnn':
net = dfcnn.dfcnn(num_classes = opt.label)
elif name == 'multi_scale_resnet':
net = multi_scale_resnet.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
elif name in ['resnet101','resnet50','resnet18']:
if name =='resnet101':
net = resnet.resnet101(pretrained=False)
net = resnet.resnet101(pretrained=True)
net.fc = nn.Linear(2048, opt.label)
elif name =='resnet50':
net = resnet.resnet50(pretrained=False)
net = resnet.resnet50(pretrained=True)
net.fc = nn.Linear(2048, opt.label)
elif name =='resnet18':
net = resnet.resnet18(pretrained=False)
net = resnet.resnet18(pretrained=True)
net.fc = nn.Linear(512, opt.label)
net.conv1 = nn.Conv2d(opt.input_nc, 64, 7, 2, 3, bias=False)
elif 'densenet' in name:
if name =='densenet121':
net = densenet.densenet121(pretrained=False,num_classes=opt.label)
net = densenet.densenet121(pretrained=True,num_classes=opt.label)
elif name == 'densenet201':
net = densenet.densenet201(pretrained=False,num_classes=opt.label)
net = densenet.densenet201(pretrained=True,num_classes=opt.label)
elif name =='squeezenet':
net = squeezenet.squeezenet1_1(pretrained=False,num_classes=opt.label,inchannel = 1)
net = squeezenet.squeezenet1_1(pretrained=True,num_classes=opt.label,inchannel = 1)
return net
\ No newline at end of file
import numpy as np
def interp(y,length):
def interp(y, length):
xp = np.linspace(0, len(y)-1,num = len(y))
fp = y
x = np.linspace(0, len(y)-1,num = length)
return np.interp(x, xp, fp)
def pad(data,padding,mod='zero'):
if mod == 'zero':
def pad(data, padding, mode = 'zero'):
if mode == 'zero':
pad_data = np.zeros(padding, dtype = data.dtype)
return np.append(data, pad_data)
elif mod == 'repeat':
elif mode == 'repeat':
out_data = data.copy()
repeat_num = int(padding/len(data))
......@@ -20,28 +20,29 @@ def pad(data,padding,mod='zero'):
pad_data = data[:padding-repeat_num*len(data)]
return np.append(out_data, pad_data)
def normliaze(data,mod = 'norm',sigma = 0,dtype=np.float64,truncated = 1):
def normliaze(data, mode = 'norm', sigma = 0, dtype=np.float64, truncated = 2):
'''
mod: norm | std | maxmin | 5_95
mode: norm | std | maxmin | 5_95
dtype : np.float64,np.float16...
'''
data = data.astype(dtype)
if mod == 'norm':
result = (data-np.mean(data))/sigma
elif mod == 'std':
mu = np.mean(data, axis=0)
sigma = np.std(data, axis=0)
data_calculate = data.copy()
if mode == 'norm':
result = (data-np.mean(data_calculate))/sigma
elif mode == 'std':
mu = np.mean(data_calculate)
sigma = np.std(data_calculate)
result = (data - mu) / sigma
elif mod == 'maxmin':
result = (data-np.mean(data))/sigma
elif mod == '5_95':
data_sort = np.sort(data)
elif mode == 'maxmin':
result = (data-np.mean(data_calculate))/(max(np.max(data_calculate),np.abs(np.min(data_calculate))))
elif mode == '5_95':
data_sort = np.sort(data_calculate,axis=None)
th5 = data_sort[int(0.05*len(data_sort))]
th95 = data_sort[int(0.95*len(data_sort))]
baseline = (th5+th95)/2
sigma = (th95-th5)/2
if sigma == 0:
sigma =1
sigma = 1e-06
result = (data-baseline)/sigma
if truncated > 1:
......
......@@ -93,17 +93,16 @@ def balance_label(signals,labels):
#load all data in datasets
def loaddataset(opt,shuffle = False):
if opt.dataset_name == 'preload':
if opt.separated:
signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
else:
signals = np.load(opt.dataset_dir+'/signals.npy')
labels = np.load(opt.dataset_dir+'/labels.npy')
if not opt.no_shuffle:
transformer.shuffledata(signals,labels)
if opt.separated:
signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
else:
signals = np.load(opt.dataset_dir+'/signals.npy')
labels = np.load(opt.dataset_dir+'/labels.npy')
if not opt.no_shuffle:
transformer.shuffledata(signals,labels)
if opt.separated:
return signals_train,labels_train,signals_eval,labels_eval
......
......@@ -2,93 +2,92 @@ import scipy.signal
import scipy.fftpack as fftpack
import numpy as np
b1 = scipy.signal.firwin(31, [0.5, 4], pass_zero=False,fs=100)
b2 = scipy.signal.firwin(31, [4,8], pass_zero=False,fs=100)
b3 = scipy.signal.firwin(31, [8,12], pass_zero=False,fs=100)
b4 = scipy.signal.firwin(31, [12,16], pass_zero=False,fs=100)
b5 = scipy.signal.firwin(31, [16,45], pass_zero=False,fs=100)
def sin(f,fs,time):
x = np.linspace(0, 2*np.pi*f*time, fs*time)
return np.sin(x)
def getfir_b(fc1,fc2,fs):
if fc1==0.5 and fc2==4 and fs==100:
b=b1
elif fc1==4 and fc2==8 and fs==100:
b=b2
elif fc1==8 and fc2==12 and fs==100:
b=b3
elif fc1==12 and fc2==16 and fs==100:
b=b4
elif fc1==16 and fc2==45 and fs==100:
b=b5
else:
b=scipy.signal.firwin(51, [fc1, fc2], pass_zero=False,fs=fs)
return b
def downsample(signal,fs1=0,fs2=0,alpha=0,mod = 'just_down'):
if alpha ==0:
alpha = int(fs1/fs2)
if mod == 'just_down':
return signal[::alpha]
elif mod == 'avg':
result = np.zeros(int(len(signal)/alpha))
for i in range(int(len(signal)/alpha)):
result[i] = np.mean(signal[i*alpha:(i+1)*alpha])
return result
def medfilt(signal,x):
return scipy.signal.medfilt(signal,x)
def BPF(signal,fs,fc1,fc2,mod = 'fir'):
if mod == 'fft':
length=len(signal)#get N
k1=int(fc1*length/fs)#get k1=Nw1/fs
k2=int(fc2*length/fs)#get k1=Nw1/fs
#FFT
signal_fft=fftpack.fft(signal)
#Frequency truncation
signal_fft[0:k1]=0+0j
signal_fft[k2:length-k2]=0+0j
signal_fft[length-k1:length]=0+0j
#IFFT
signal_ifft=fftpack.ifft(signal_fft)
result = signal_ifft.real
else:
b=getfir_b(fc1,fc2,fs)
result = scipy.signal.lfilter(b, 1, signal)
def cleanoffset(signal):
return signal - np.mean(signal)
def bpf_fir(signal,fs,fc1,fc2,numtaps=101):
b=scipy.signal.firwin(numtaps, [fc1, fc2], pass_zero=False,fs=fs)
result = scipy.signal.lfilter(b, 1, signal)
return result
def getfeature(signal,mod = 'fft',ch_num = 5):
result=[]
signal =signal - np.mean(signal)
eeg=signal
def fft_filter(signal,fs,fc=[],type = 'bandpass'):
'''
signal: Signal
fs: Sampling frequency
fc: [fc1,fc2...] Cut-off frequency
type: bandpass | bandstop
'''
k = []
N=len(signal)#get N
beta=BPF(eeg,100,16,45,mod) # β
theta=BPF(eeg,100,4,8,mod)
sigma=BPF(eeg,100,12,16,mod) #σ spindle
alpha=BPF(eeg,100,8,12,mod)
delta=BPF(eeg,100,0.5,4,mod)
result.append(beta)
result.append(theta)
result.append(sigma)
result.append(alpha)
result.append(delta)
for i in range(len(fc)):
k.append(int(fc[i]*N/fs))
if ch_num == 6:
fft = abs(fftpack.fft(eeg))
fft = fft - np.median(fft)
result.append(fft)
#FFT
signal_fft=scipy.fftpack.fft(signal)
#Frequency truncation
result=np.array(result)
result=result.reshape(ch_num*len(signal),)
if type == 'bandpass':
a = np.zeros(N)
for i in range(int(len(fc)/2)):
a[k[2*i]:k[2*i+1]] = 1
a[N-k[2*i+1]:N-k[2*i]] = 1
elif type == 'bandstop':
a = np.ones(N)
for i in range(int(len(fc)/2)):
a[k[2*i]:k[2*i+1]] = 0
a[N-k[2*i+1]:N-k[2*i]] = 0
signal_fft = a*signal_fft
signal_ifft=scipy.fftpack.ifft(signal_fft)
result = signal_ifft.real
return result
# def signal2spectrum(data):
# # window : ('tukey',0.5) hann
def rms(signal):
signal = signal.astype('float64')
return np.mean((signal*signal))**0.5
# zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-12, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
# zxx =np.abs(zxx)[:512]
# spectrum=np.zeros((256,251))
# spectrum[0:128]=zxx[0:128]
# spectrum[128:192]=zxx[128:256][::2]
# spectrum[192:256]=zxx[256:512][::4]
# spectrum = np.log(spectrum+1)
# return spectrum
def energy(signal,kernel_size,stride,padding = 0):
_signal = np.zeros(len(signal)+padding)
_signal[0:len(signal)] = signal
signal = _signal
out_len = int((len(signal)+1-kernel_size)/stride)
energy = np.zeros(out_len)
for i in range(out_len):
energy[i] = rms(signal[i*stride:i*stride+kernel_size])
return energy
def signal2spectrum(data):
def signal2spectrum(data,window_size,stride,log = True):
# window : ('tukey',0.5) hann
zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
zxx =np.abs(zxx)[:512]
spectrum=np.zeros((256,126))
spectrum[0:128]=zxx[0:128]
spectrum[128:192]=zxx[128:256][::2]
spectrum[192:256]=zxx[256:512][::4]
spectrum = np.log(spectrum+1)
zxx = scipy.signal.stft(data, window='hann', nperseg=window_size,noverlap=window_size-stride)[2]
spectrum = np.abs(zxx)
if log:
spectrum = np.log1p(spectrum)
h = window_size//2+1
tmp = np.linspace(0, h-1,num=h,dtype=np.int64)
index = np.log1p(tmp)*(h/np.log1p(h))
spectrum_new = np.zeros_like(spectrum)
for i in range(h-1):
spectrum_new[int(index[i]):int(index[i+1])] = spectrum[i]
spectrum = spectrum_new
return spectrum
\ No newline at end of file
......@@ -2,7 +2,7 @@ import argparse
import os
import time
import numpy as np
from . import util
from . import util,dsp
class Options():
def __init__(self):
......@@ -10,7 +10,7 @@ class Options():
self.initialized = False
def initialize(self):
#base
# ------------------------Base------------------------
self.parser.add_argument('--gpu_id', type=int, default=0,help='choose which gpu want to use, 0 | 1 | 2 ...')
self.parser.add_argument('--no_cudnn', action='store_true', help='if specified, do not use cudnn')
self.parser.add_argument('--label', type=str, default='auto',help='number of labels')
......@@ -18,14 +18,35 @@ class Options():
self.parser.add_argument('--loadsize', type=str, default='auto', help='load data in this size')
self.parser.add_argument('--finesize', type=str, default='auto', help='crop your data into this size')
self.parser.add_argument('--label_name', type=str, default='auto',help='name of labels,example:"a,b,c,d,e,f"')
self.parser.add_argument('--model_name', type=str, default='micro_multi_scale_resnet_1d',help='Choose model lstm | multi_scale_resnet_1d | resnet18 | micro_multi_scale_resnet_1d...')
# ------------
# for lstm
self.parser.add_argument('--input_size', type=str, default='auto',help='input_size of LSTM')
self.parser.add_argument('--time_step', type=int, default=100,help='time_step of LSTM')
# for autoencoder
self.parser.add_argument('--normliaze', type=str, default='5_95', help='mode of normliaze, 5_95 | maxmin | None')
# ------------------------Dataset------------------------
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
self.parser.add_argument('--separated', action='store_true', help='if specified,for preload data, if input, load separated train and test datasets')
self.parser.add_argument('--no_shuffle', action='store_true', help='if specified, do not shuffle data when load(use to evaluate individual differences)')
self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
# ------------------------Network------------------------
"""Available Network
1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
micro_multi_scale_resnet_1d,autoencoder
2d: dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
densenet121, densenet201, squeezenet
"""
self.parser.add_argument('--model_name', type=str, default='micro_multi_scale_resnet_1d',help='Choose model lstm...')
self.parser.add_argument('--model_type', type=str, default='auto',help='1d | 2d')
# For lstm
self.parser.add_argument('--lstm_inputsize', type=str, default='auto',help='lstm_inputsize of LSTM')
self.parser.add_argument('--lstm_timestep', type=int, default=100,help='time_step of LSTM')
# For autoecoder
self.parser.add_argument('--feature', type=int, default=3, help='number of encoder features')
# ------------
# For 2d network(stft spectrum)
self.parser.add_argument('--stft_size', type=int, default=512, help='length of each fft segment')
self.parser.add_argument('--stft_stride', type=int, default=128, help='stride of each fft segment')
self.parser.add_argument('--stft_no_log', action='store_true', help='if specified, do not log1p spectrum')
# ------------------------Training Matters------------------------
self.parser.add_argument('--pretrained', type=str, default='',help='pretrained model path. If not specified, fo not use pretrained model')
self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train')
self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
......@@ -35,23 +56,10 @@ class Options():
self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.if 0 or 1,no k-fold')
self.parser.add_argument('--mergelabel', type=str, default='None',
help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" , label(0,1,4) regard as 0,label(2,3,5) regard as 1')
help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" -> label(0,1,4) regard as 0,label(2,3,5) regard as 1')
self.parser.add_argument('--mergelabel_name', type=str, default='None',help='name of labels,example:"a,b,c,d,e,f"')
self.parser.add_argument('--plotfreq', type=int, default=100,help='frequency of plotting results')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
help='your dataset path')
self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
self.parser.add_argument('--dataset_name', type=str, default='preload',
help='Choose dataset preload | sleep-edfx | cc2018 ,preload:your data->shape:(num,ch,length), sleep-edfx&cc2018:sleep stage')
self.parser.add_argument('--separated', action='store_true', help='if specified,for preload data, if input, load separated train and test datasets')
self.parser.add_argument('--no_shuffle', action='store_true', help='if specified,do not shuffle data when load(use to evaluate individual differences)')
#for EEG datasets
self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None')
self.parser.add_argument('--select_sleep_time', action='store_true', help='if specified, for sleep-cassette only use sleep time to train')
self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1 | EEG Fpz-Cz |...')
self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load')
self.initialized = True
......@@ -63,7 +71,7 @@ class Options():
if self.opt.gpu_id != -1:
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.gpu_id)
if self.opt.label !='auto':
if self.opt.label != 'auto':
self.opt.label = int(self.opt.label)
if self.opt.input_nc !='auto':
self.opt.input_nc = int(self.opt.input_nc)
......@@ -71,16 +79,19 @@ class Options():
self.opt.loadsize = int(self.opt.loadsize)
if self.opt.finesize !='auto':
self.opt.finesize = int(self.opt.finesize)
if self.opt.input_size !='auto':
self.opt.input_size = int(self.opt.input_size)
if self.opt.dataset_name == 'sleep-edf':
self.opt.sample_num = 8
if self.opt.dataset_name not in ['sleep-edf','sleep-edfx','cc2018']:
self.opt.BID = 'not-supported'
self.opt.select_sleep_time = 'not-supported'
self.opt.signal_name = 'not-supported'
self.opt.sample_num = 'not-supported'
if self.opt.lstm_inputsize != 'auto':
self.opt.lstm_inputsize = int(self.opt.lstm_inputsize)
if self.opt.model_type == 'auto':
if self.opt.model_name in ['lstm', 'cnn_1d', 'resnet18_1d', 'resnet34_1d',
'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder']:
self.opt.model_type = '1d'
elif self.opt.model_name in ['dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50',
'resnet101','densenet121', 'densenet201', 'squeezenet']:
self.opt.model_type = '2d'
else:
print('\033[1;31m'+'Error: do not support this network '+self.opt.model_name+'\033[0m')
exit(0)
if self.opt.k_fold == 0 :
self.opt.k_fold = 1
......@@ -121,8 +132,8 @@ def get_auto_options(opt,label_cnt_per,label_num,shape):
opt.loadsize = shape[2]
if opt.finesize =='auto':
opt.finesize = int(shape[2]*0.9)
if opt.input_size =='auto':
opt.input_size = opt.finesize//opt.time_step
if opt.lstm_inputsize =='auto':
opt.lstm_inputsize = opt.finesize//opt.lstm_timestep
# weight
opt.weight = np.ones(opt.label)
......@@ -137,14 +148,21 @@ def get_auto_options(opt,label_cnt_per,label_num,shape):
# label name
if opt.label_name == 'auto':
if opt.dataset_name in ['sleep-edf','sleep-edfx','cc2018']:
opt.label_name = ["N3", "N2", "N1", "REM","W"]
else:
names = []
for i in range(opt.label):
names.append(str(i))
opt.label_name = names
names = []
for i in range(opt.label):
names.append(str(i))
opt.label_name = names
elif not isinstance(opt.label_name,list):
opt.label_name = opt.label_name.replace(" ", "").split(",")
# check stft spectrum
if opt.model_type =='2d':
h, w = opt.stft_size//2+1, opt.loadsize//opt.stft_stride
print('Shape of stft spectrum h,w:',(h,w))
if h<64 or w<64:
print('\033[1;33m'+'Warning: spectrum is too small'+'\033[0m')
if h>512 or w>512:
print('\033[1;33m'+'Warning: spectrum is too large'+'\033[0m')
return opt
\ No newline at end of file
......@@ -247,6 +247,10 @@ def showscatter3d(data):
plt.show()
def draw_spectrum(spectrum,opt):
plt.imshow(spectrum)
plt.savefig(os.path.join(opt.save_dir,'spectrum_eg.png'))
plt.close('all')
def main():
......
......@@ -42,14 +42,6 @@ def batch_generator(data,target,sequence,shuffle = True):
return out_data,out_target
def Normalize(data,maxmin,avg,sigma,is_01=False):
data = np.clip(data, -maxmin, maxmin)
if is_01:
return (data-avg)/sigma/2+0.5 #(0,1)
else:
return (data-avg)/sigma #(-1,1)
def ToTensor(data,target=None,gpu_id=0):
if target is not None:
......@@ -105,33 +97,26 @@ def random_transform_2d(img,finesize = (224,122),test_flag = True):
def ToInputShape(data,opt,test_flag = False):
#data = data.astype(np.float32)
batchsize = data.shape[0]
if opt.model_name in['lstm','cnn_1d','resnet18_1d','resnet34_1d','multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder']:
if opt.model_type == '1d':
if opt.normliaze != 'None':
for i in range(opt.batchsize):
for j in range(opt.input_nc):
data[i][j] = arr.normliaze(data[i][j],mode = opt.normliaze)
result = random_transform_1d(data, opt.finesize, test_flag=test_flag)
# unsupported now
# elif opt.model_name=='lstm':
# result =[]
# for i in range(0,batchsize):
# randomdata=random_transform_1d(data[i],finesize = _finesize,test_flag=test_flag)
# result.append(dsp.getfeature(randomdata))
# result = np.array(result).reshape(batchsize,_finesize*5)
elif opt.model_type == '2d':
result = []
for i in range(opt.batchsize):
for j in range(opt.input_nc):
spectrum = dsp.signal2spectrum(data[i][j],opt.stft_size,opt.stft_stride, not opt.stft_no_log)
#spectrum = arr.normliaze(spectrum, mode = opt.normliaze)
spectrum = (spectrum-2)/5
# print(spectrum.shape)
#spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
result.append(spectrum)
h,w = spectrum.shape
result = (np.array(result)).reshape(opt.batchsize,opt.input_nc,h,w)
# elif opt.model_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
# result =[]
# data = (data-0.5)*2
# for i in range(0,batchsize):
# spectrum = dsp.signal2spectrum(data[i])
# spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
# result.append(spectrum)
# result = np.array(result)
# #datasets th_95 avg mid
# # sleep_edfx 0.0458 0.0128 0.0053
# # CC2018 0.0507 0.0161 0.00828
# result = Normalize(result, maxmin=0.5, avg=0.0150, sigma=0.0500)
# result = result.reshape(batchsize,1,224,122)
return result.astype(np.float32)
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册