transformer.py 5.5 KB
Newer Older
H
hypox64 已提交
1 2
import os
import random
H
hypox64 已提交
3 4
import numpy as np
import torch
H
hypox64 已提交
5
from . import dsp
H
hypox64 已提交
6
from . import array_operation as arr
H
hypox64 已提交
7
# import dsp
H
hypox64 已提交
8 9 10 11 12 13 14 15 16 17 18

def trimdata(data,num):
    """Drop trailing items so the result length is a multiple of ``num``.

    Args:
        data: any sliceable sequence (list, numpy array, ...).
        num: chunk size the result length should be divisible by.

    Returns:
        ``data`` sliced to its first ``int(len(data) / num) * num`` items.
    """
    whole_chunks = int(len(data) / num)
    return data[:whole_chunks * num]

def shuffledata(data,target):
    """Shuffle ``data`` and ``target`` in place with the same permutation.

    Both arrays are shuffled along their first axis by ``np.random.shuffle``.
    The global NumPy RNG state is rewound between the two calls, so the
    second shuffle replays the same random draws and every sample keeps its
    label. Nothing is returned.

    Raises:
        ValueError: if ``data`` and ``target`` have different lengths; the
            shared permutation would otherwise silently misalign them.
    """
    if len(data) != len(target):
        raise ValueError(
            'shuffledata: data and target must have the same length, got %d and %d'
            % (len(data), len(target)))
    state = np.random.get_state()
    np.random.shuffle(data)
    # Replay exactly the same random draws for the labels.
    np.random.set_state(state)
    np.random.shuffle(target)

H
hypox64 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32
def k_fold_generator(length,fold_num):
    """Build index sets for a hold-out split or contiguous k-fold splits.

    Args:
        length: number of samples; indices run from 0 to ``length - 1``.
        fold_num: 0 or 1 selects a single hold-out split. The first
            ``int(length*0.8)`` indices are used for training and all
            remaining indices for testing. Any larger value selects k-fold
            splits.

    Returns:
        ``(train_sequence, test_sequence)``: 2-D int arrays with one row per
        fold (a single row for the hold-out split).

    In k-fold mode, fold ``i`` tests the i-th contiguous slice of
    ``length // fold_num`` indices. It trains on the first
    ``length*(fold_num-1)//fold_num`` of the remaining indices, in order.
    When ``length`` is not divisible by ``fold_num``, the trailing indices are
    never tested. Each train row is also truncated so that all rows have
    equal length.
    """
    if fold_num == 0 or fold_num == 1:
        train_length = int(length*0.8)
        # Test on every remaining index. The previous
        # np.linspace(..., int(length*0.2)) spread too few points over the
        # tail and skipped indices (e.g. length=14 gave [11, 13]).
        train_sequence = np.arange(train_length, dtype=int)[None]
        test_sequence = np.arange(train_length, length, dtype=int)[None]
        return train_sequence, test_sequence

    sequence = np.arange(length, dtype=int)
    # Exact integer arithmetic (equal to the former int(float) expressions
    # for positive inputs).
    test_length = length // fold_num
    train_length = length * (fold_num - 1) // fold_num
    train_sequence = np.zeros((fold_num, train_length), dtype=int)
    test_sequence = np.zeros((fold_num, test_length), dtype=int)
    for i in range(fold_num):
        start, stop = test_length * i, test_length * (i + 1)
        test_sequence[i] = sequence[start:stop]
        train_sequence[i] = np.concatenate((sequence[:start], sequence[stop:]))[:train_length]
    return train_sequence, test_sequence
H
hypox64 已提交
33

H
hypox64 已提交
34 35 36 37 38 39 40
def batch_generator(data,target,sequence,shuffle = True):
    """Gather the samples listed in ``sequence`` into a batch.

    Args:
        data: numpy array indexed along its first axis.
        target: numpy array of labels aligned with ``data``.
        sequence: iterable of integer indices to gather, in order.
        shuffle: unused; kept so existing call sites stay valid.

    Returns:
        ``(out_data, out_target)``: new arrays (copies, not views) holding
        ``data[sequence[i]]`` and ``target[sequence[i]]`` at position ``i``,
        with the same dtypes as the inputs.
    """
    # dtype=int keeps an empty sequence usable as an index array
    # (np.asarray([]) would be float and rejected by fancy indexing).
    index = np.asarray(sequence, dtype=int)
    # A single vectorized gather replaces the per-sample Python loop and
    # works for data/target of any rank.
    return data[index], target[index]
H
hypox64 已提交
43

H
hypox64 已提交
44
def Normalize(data,maxmin,avg,sigma,is_01=False):
    """Clip ``data`` to ``[-maxmin, maxmin]`` and standardize it.

    Returns ``(clipped - avg) / sigma``. With ``is_01=True``, that value is
    additionally halved and shifted by 0.5. The result lies in (-1, 1) or
    (0, 1) respectively only when the clipped data is within ``sigma`` of
    ``avg``; nothing here enforces those ranges.
    """
    standardized = (np.clip(data, -maxmin, maxmin) - avg) / sigma
    if not is_01:
        return standardized
    return standardized / 2 + 0.5
H
hypox64 已提交
50

H
hypox64 已提交
51 52 53 54
def Balance_individualized_differences(signals,BID):
    """Rescale signals to reduce amplitude differences between recordings.

    Every strategy finishes with
    ``Normalize(signals, maxmin=10e3, avg=0, sigma=..., is_01=True)``. The
    strategies differ only in their preprocessing and in the ``sigma`` used.

    Args:
        signals: numpy array of samples.
        BID: strategy name:
            '5_95_th' -- sigma is minus the value found at position
                         ``int(0.05*n)`` of the sorted, flattened samples;
            'median'  -- first scale the signals so the median absolute
                         value becomes 8, then use sigma=30;
            other     -- use a fixed sigma=30.

    Returns:
        Whatever ``Normalize`` returns for the chosen sigma.
    """
    if BID == '5_95_th':
        flat_sorted = np.sort(signals.reshape(-1))
        # NOTE(review): sigma is positive only when this low-percentile value
        # is negative; non-negative signals would flip sign or divide by zero.
        sigma = -flat_sorted[int(0.05*len(flat_sorted))]
    elif BID == 'median':
        # NOTE(review): a zero median absolute value yields inf/nan here.
        median_abs = np.median(abs(signals))
        signals = signals*8/median_abs
        sigma = 30
    else:
        # Per-dataset statistics recorded by the author
        # (columns: 5_95_th(-1,1), median):
        #   CC2018                24.75   7.438
        #   sleep edfx            37.4    9.71
        #   sleep edfx sleeptime  39.03   10.125
        sigma = 30
    return Normalize(signals, maxmin=10e3, avg=0, sigma=sigma, is_01=True)

H
hypox64 已提交
68
def ToTensor(data,target=None,gpu_id=0):
    """Convert numpy arrays to torch tensors, optionally moving them to CUDA.

    Args:
        data: numpy array, converted with ``torch.from_numpy(...).float()``.
        target: optional numpy array, converted with ``.long()``.
        gpu_id: -1 keeps the tensors on the CPU. Any other value calls
            ``.cuda()`` with no argument (the current default CUDA device);
            the numeric id itself is not used here.

    Returns:
        ``(data, target)`` when ``target`` is given, otherwise just ``data``.
    """
    has_target = target is not None
    converted = [torch.from_numpy(data).float()]
    if has_target:
        converted.append(torch.from_numpy(target).long())
    if gpu_id != -1:
        converted = [tensor.cuda() for tensor in converted]
    if has_target:
        return converted[0], converted[1]
    return converted[0]
H
hypox64 已提交
81 82

def random_transform_1d(data,finesize,test_flag):
    """Crop a batch of 1-D signals and, in training mode, augment it.

    Args:
        data: 3-D numpy array; the crop is taken along the last axis.
        finesize: crop length; must not exceed ``data.shape[2]``.
        test_flag: if true, take a deterministic centered crop. Otherwise,
            apply training augmentation:
            - take one random crop offset, shared by the whole batch;
            - reverse the crop along the last axis with probability 0.5;
            - multiply it by a random gain in [0.9, 1.1].

    Returns:
        Array of shape ``(data.shape[0], data.shape[1], finesize)``. In the
        test path this is a view of ``data``; in the training path it is a
        new array.

    Raises:
        ValueError: if ``finesize`` exceeds the signal length. This
            previously produced a negative offset and a silently truncated
            crop.
    """
    _, _, length = data.shape
    if finesize > length:
        raise ValueError('random_transform_1d: finesize (%s) exceeds signal length (%s)'
                         % (finesize, length))
    slack = length - finesize

    if test_flag:
        move = int(slack*0.5)
        return data[:,:,move:move+finesize]

    # random crop
    move = int(slack*random.random())
    result = data[:,:,move:move+finesize]
    # random flip along the last axis
    if random.random()<0.5:
        result = result[:,:,::-1]
    # random amplitude
    result = result*random.uniform(0.9,1.1)
    # Additive uniform noise in [-0.005, 0.005) was tried here and is
    # currently disabled.
    return result

def random_transform_2d(img,finesize = (224,122),test_flag = True):
    """Crop a 2-D image and, in training mode, augment it.

    Args:
        img: array with at least two dimensions; the crop is taken over the
            first two.
        finesize: ``(rows, cols)`` of the crop.
        test_flag: if true, crop from row 2 with the columns centered.
            Otherwise, apply training augmentation:
            - start at a random row offset in [0, 9] (the original author's
              intent was to avoid losing low-frequency information);
            - choose a random column offset;
            - mirror the columns with probability 0.5;
            - apply a random gain in [0.9, 1.1] plus an offset in
              [-0.05, 0.05].

    Returns:
        The cropped, and possibly augmented, array.
    """
    _, width = img.shape[:2]
    crop_h, crop_w = finesize[0], finesize[1]
    if not test_flag:
        top = int(10*random.random())
        left = int((width-crop_w)*random.random())
    else:
        top = 2
        left = int((width-crop_w)*0.5)
    patch = img[top:top+crop_h, left:left+crop_w]
    if test_flag:
        return patch

    if random.random()<0.5:
        patch = patch[:, ::-1]
    gain = random.uniform(0.9,1.1)
    offset = random.uniform(-0.05,0.05)
    return patch*gain + offset

H
hypox64 已提交
120
def ToInputShape(data,opt,test_flag = False):
    """Convert a batch of raw signals into the float32 input a model expects.

    Only 1-D models are supported. For ``opt.model_name`` in the list below,
    the batch goes through ``random_transform_1d`` with crop length
    ``opt.finesize``. The crop is centered when ``test_flag`` is true and
    randomly augmented otherwise.

    Returns:
        float32 numpy array produced by the transform.

    Raises:
        ValueError: for any other ``opt.model_name``. Previously such names
            fell through and crashed with UnboundLocalError on ``result``.
    """
    one_d_models = ('lstm','cnn_1d','resnet18_1d','resnet34_1d',
                    'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder')
    if opt.model_name not in one_d_models:
        # Spectrum-based 2-D models (squeezenet, multi_scale_resnet, dfcnn,
        # resnet18/50/101, densenet121/201) were once handled here and are
        # currently unsupported.
        raise ValueError('ToInputShape: unsupported model_name %r' % (opt.model_name,))
    result = random_transform_1d(data, opt.finesize, test_flag=test_flag)
    return result.astype(np.float32)