Commit bb5ad100 authored by: H HypoX64

Data Augmentation

Parent 085dde45
......@@ -139,6 +139,7 @@ checkpoints/
/tools/client_data
/tools/server_data
/trainscript.py
*.out
*.pth
*.edf
*log*
\ No newline at end of file
import os
import time
import numpy as np
import torch
from torch import nn, optim
from multiprocessing import Process, Queue
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# import torch.multiprocessing as mp
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append("..")
from util import util,transformer,dataloader,statistics,plot,options
from models.net_1d.gan import Generator,Discriminator,GANloss,weights_init_normal
from models.core import show_paramsnumber
def gan(opt,signals,labels):
print('Augment dataset using gan...')
if opt.gpu_id != -1:
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id)
if not opt.no_cudnn:
torch.backends.cudnn.benchmark = True
signals_train = signals[:opt.fold_index[0]]
labels_train = labels[:opt.fold_index[0]]
signals_eval = signals[opt.fold_index[0]:]
labels_eval = labels[opt.fold_index[0]:]
order = labels_train.argsort()
signals_train = signals_train[order]
labels_train = labels_train[order]
out_signals = signals_train.copy()
out_labels = labels_train.copy()
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, signals_train, labels_train)
generator = Generator(opt.loadsize,opt.input_nc,opt.gan_latent_dim)
discriminator = Discriminator(opt.loadsize,opt.input_nc)
show_paramsnumber(generator, opt)
show_paramsnumber(discriminator, opt)
ganloss = GANloss(opt.gpu_id,opt.batchsize)
if opt.gpu_id != -1:
generator.cuda()
discriminator.cuda()
ganloss.cuda()
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.gan_lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.gan_lr, betas=(0.5, 0.999))
index_cnt = 0
for which_label in range(len(label_cnt)):
if which_label in opt.gan_labels:
sub_signals = signals_train[index_cnt:index_cnt+label_cnt[which_label]]
sub_labels = labels_train[index_cnt:index_cnt+label_cnt[which_label]]
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
generator.train()
discriminator.train()
for epoch in range(opt.gan_epochs):
epoch_g_loss = 0
epoch_d_loss = 0
iter_per_epoch = len(sub_labels)//opt.batchsize
transformer.shuffledata(sub_signals, sub_labels)
t1 = time.time()
for i in range(iter_per_epoch):
real_signal = sub_signals[i*opt.batchsize:(i+1)*opt.batchsize].reshape(opt.batchsize,opt.input_nc,opt.loadsize)
real_signal = transformer.ToTensor(real_signal,gpu_id=opt.gpu_id)
# Train Generator
optimizer_G.zero_grad()
z = transformer.ToTensor(np.random.normal(0, 1, (opt.batchsize, opt.gan_latent_dim)),gpu_id = opt.gpu_id)
gen_signal = generator(z)
g_loss = ganloss(discriminator(gen_signal),True)
epoch_g_loss += g_loss.item()
g_loss.backward()
optimizer_G.step()
# Train Discriminator
optimizer_D.zero_grad()
d_real = ganloss(discriminator(real_signal), True)
d_fake = ganloss(discriminator(gen_signal.detach()), False)
d_loss = (d_real + d_fake) / 2
epoch_d_loss += d_loss.item()
d_loss.backward()
optimizer_D.step()
t2 = time.time()
print(
"[Label %d] [Epoch %d/%d] [D loss: %.4f] [G loss: %.4f] [time: %.2f]"
% (sub_labels[0], epoch+1, opt.gan_epochs, epoch_d_loss/iter_per_epoch, epoch_g_loss/iter_per_epoch, t2-t1)
)
plot.draw_gan_result(real_signal.data.cpu().numpy(), gen_signal.data.cpu().numpy(),opt)
generator.eval()
for i in range(int(len(sub_labels)*(opt.gan_augment_times-1))//opt.batchsize):
z = transformer.ToTensor(np.random.normal(0, 1, (opt.batchsize, opt.gan_latent_dim)),gpu_id = opt.gpu_id)
gen_signal = generator(z)
out_signals = np.concatenate((out_signals, gen_signal.data.cpu().numpy()))
#print(np.ones((opt.batchsize),dtype=np.int64)*which_label)
out_labels = np.concatenate((out_labels,np.ones((opt.batchsize),dtype=np.int64)*which_label))
index_cnt += label_cnt[which_label]
opt.fold_index = [len(out_labels)]
out_signals = np.concatenate((out_signals, signals_eval))
out_labels = np.concatenate((out_labels, labels_eval))
# return signals,labels
return out_signals,out_labels
def base(opt,signals,labels):
pass
def augment(opt,signals,labels):
pass
if __name__ == '__main__':
opt = options.Options().getparse()
signals,labels = dataloader.loaddataset(opt)
out_signals,out_labels = gan(opt,signals,labels)
print(out_signals.shape,out_labels.shape)
\ No newline at end of file
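For reference, the per-label GAN training above relies on labels_train.argsort() making each class a contiguous slice, so cumulative label counts can index the sorted arrays directly. A minimal numpy sketch with synthetic shapes (all names hypothetical):

import numpy as np
signals = np.random.rand(6, 1, 8).astype(np.float32)
labels = np.array([2, 0, 1, 0, 2, 1])
order = labels.argsort()
signals, labels = signals[order], labels[order]   # classes now contiguous: 0,0,1,1,2,2
label_cnt = np.bincount(labels)                   # -> [2, 2, 2]
index_cnt = 0
for which_label, cnt in enumerate(label_cnt):
    sub_signals = signals[index_cnt:index_cnt+cnt]  # one contiguous class slice per GAN run
    index_cnt += cnt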
......@@ -26,7 +26,7 @@ class Core(object):
self.epoch = 1
if self.opt.gpu_id != -1:
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.gpu_id)
#torch.cuda.set_device(self.opt.gpu_id)
# torch.cuda.set_device(self.opt.gpu_id)
if not self.opt.no_cudnn:
torch.backends.cudnn.benchmark = True
......@@ -71,8 +71,12 @@ class Core(object):
self.net.cuda()
def preprocessing(self,signals, labels, sequences):
for i in range(len(sequences)//self.opt.batchsize):
signal,label = transformer.batch_generator(signals, labels, sequences[i*self.opt.batchsize:(i+1)*self.opt.batchsize])
_times = int(np.ceil(len(sequences)/self.opt.batchsize))
for i in range(_times):
if i != _times-1:
signal,label = transformer.batch_generator(signals, labels, sequences[i*self.opt.batchsize:(i+1)*self.opt.batchsize])
else:
signal,label = transformer.batch_generator(signals, labels, sequences[i*self.opt.batchsize:])
signal = transformer.ToInputShape(signal,self.opt,test_flag =self.test_flag)
self.queue.put([signal,label])
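The switch from floor to ceil division keeps the final partial batch instead of silently dropping it. A standalone sketch of the same pattern:

import numpy as np
sequences = np.arange(10)
batchsize = 4
for i in range(int(np.ceil(len(sequences)/batchsize))):  # 3 iterations, not 2
    batch = sequences[i*batchsize:(i+1)*batchsize]       # the last batch holds the 2 leftovers
    print(batch)

Since numpy slicing past the end is safe, the explicit last-batch branch above is a readability choice rather than a necessity.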
......@@ -121,13 +125,15 @@ class Core(object):
np.random.shuffle(sequences)
self.process_pool_init(signals, labels, sequences)
for i in range(len(sequences)//self.opt.batchsize):
for i in range(int(np.ceil(len(sequences)/self.opt.batchsize))):
self.optimizer.zero_grad()
signal,label = self.queue.get()
signal,label = transformer.ToTensor(signal,label,gpu_id =self.opt.gpu_id)
output,loss,features,confusion_mat = self.forward(signal, label, features, confusion_mat)
epoch_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
......@@ -145,7 +151,7 @@ class Core(object):
np.random.shuffle(sequences)
self.process_pool_init(signals, labels, sequences)
for i in range(len(sequences)//self.opt.batchsize):
for i in range(int(np.ceil(len(sequences)/self.opt.batchsize))):
signal,label = self.queue.get()
signal,label = transformer.ToTensor(signal,label,gpu_id =self.opt.gpu_id)
with torch.no_grad():
......@@ -153,9 +159,10 @@ class Core(object):
epoch_loss += loss.item()
if self.opt.model_name == 'autoencoder':
plot.draw_autoencoder_result(signal.data.cpu().numpy(), output.data.cpu().numpy(),self.opt)
print('epoch:'+str(self.epoch),' loss: '+str(round(epoch_loss/(i+1),5)))
plot.draw_scatter(features, self.opt)
if self.epoch%10 == 0:
plot.draw_autoencoder_result(signal.data.cpu().numpy(), output.data.cpu().numpy(),self.opt)
print('epoch:'+str(self.epoch),' loss: '+str(round(epoch_loss/(i+1),5)))
plot.draw_scatter(features, self.opt)
else:
recall,acc,sp,err,k = statistics.report(confusion_mat)
#plot.draw_heatmap(confusion_mat,self.opt,name = 'current_eval')
......
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
class Generator(nn.Module):
def __init__(self,signal_size,output_nc,latent_dim):
super(Generator, self).__init__()
self.init_size = signal_size // 4
self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size))
self.conv_blocks = nn.Sequential(
nn.BatchNorm1d(128),
nn.Upsample(scale_factor=2),
nn.Conv1d(128, 128, 31, stride=1, padding=15),
nn.BatchNorm1d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv1d(128, 64, 31, stride=1, padding=15),
nn.BatchNorm1d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(64, output_nc, 31, stride=1, padding=15),
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size)
signal = self.conv_blocks(out)
return signal
class Discriminator(nn.Module):
def __init__(self,signal_size,input_nc):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=True):
block = [nn.Conv1d(in_filters, out_filters, 31, 2, 15), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm1d(out_filters, 0.8))
return block
self.model = nn.Sequential(
*discriminator_block(input_nc, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
# length of the downsampled signal
ds_size = signal_size // 2 ** 4
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size, 1), nn.Sigmoid())
def forward(self, signal):
out = self.model(signal)
# print(out.size())
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
class GANloss(nn.Module):
def __init__(self,gpu_id,batchsize):
super(GANloss,self).__init__()
self.Tensor = torch.cuda.FloatTensor if gpu_id != -1 else torch.FloatTensor
self.valid = Variable(self.Tensor(batchsize, 1).fill_(1.0), requires_grad=False)
self.fake = Variable(self.Tensor(batchsize, 1).fill_(0.0), requires_grad=False)
self.loss_function = torch.nn.BCELoss()
def forward(self,tensor,target_is_real):
if target_is_real:
loss = self.loss_function(tensor,self.valid)
else:
loss = self.loss_function(tensor,self.fake)
return loss
\ No newline at end of file
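A minimal CPU smoke test for the three classes above (a sketch; it assumes a signal length of 2048, which is divisible by the discriminator's 2**4 downsampling, one channel, and the default latent dim of 100):

import torch
batchsize, latent_dim, signal_size = 4, 100, 2048
G = Generator(signal_size, 1, latent_dim)
D = Discriminator(signal_size, 1)
criterion = GANloss(-1, batchsize)       # gpu_id=-1 -> CPU tensors
z = torch.randn(batchsize, latent_dim)
fake = G(z)                              # -> torch.Size([4, 1, 2048])
validity = D(fake)                       # -> torch.Size([4, 1]), sigmoid outputs
loss = criterion(validity, True)         # BCE against an all-ones target
print(fake.shape, validity.shape, loss.item())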
......@@ -12,7 +12,8 @@ class lstm_block(nn.Module):
input_size=input_size,
hidden_size=Hidden_size,
num_layers=Num_layers,
batch_first=True,
batch_first=True,
# bidirectional = True
)
def forward(self, x):
......@@ -30,7 +31,7 @@ class lstm(nn.Module):
self.point = input_size*time_step
for i in range(input_nc):
exec('self.lstm'+str(i) + '=lstm_block(input_size, time_step)')
exec('self.lstm'+str(i) + '=lstm_block(input_size, time_step, '+str(Hidden_size)+')')
self.fc = nn.Linear(Hidden_size*input_nc, num_classes)
def forward(self, x):
......@@ -39,5 +40,6 @@ class lstm(nn.Module):
for i in range(self.input_nc):
y.append(eval('self.lstm'+str(i)+'(x[:,i,:])'))
x = torch.cat(tuple(y), 1)
# print(x.size())
x = self.fc(x)
return x
\ No newline at end of file
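The exec/eval pattern above builds one lstm_block per channel through generated attribute names. An equivalent, arguably more idiomatic structure uses nn.ModuleList; a hedged sketch (hypothetical class name, not the repo's code, assuming lstm_block(input_size, time_step, hidden_size) returns a (batch, hidden_size) tensor per channel):

import torch
from torch import nn

class lstm_modulelist(nn.Module):   # hypothetical alternative, not in the repo
    def __init__(self, input_nc, input_size, time_step, hidden_size, num_classes):
        super().__init__()
        self.blocks = nn.ModuleList(
            lstm_block(input_size, time_step, hidden_size) for _ in range(input_nc)
        )
        self.fc = nn.Linear(hidden_size*input_nc, num_classes)
    def forward(self, x):           # x: (batch, input_nc, length)
        y = [blk(x[:, i, :]) for i, blk in enumerate(self.blocks)]
        return self.fc(torch.cat(y, 1))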
......@@ -25,7 +25,7 @@ signals,labels = dataloader.loaddataset(opt)
ori_signals_train,ori_labels_train,ori_signals_eval,ori_labels_eval = \
signals[:opt.fold_index[0]].copy(),labels[:opt.fold_index[0]].copy(),signals[opt.fold_index[0]:].copy(),labels[opt.fold_index[0]:].copy()
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, ori_signals_train)
opt = options.get_auto_options(opt, ori_signals_train, ori_labels_train)
# -----------------------------def network-----------------------------
core = core.Core(opt)
......@@ -42,7 +42,7 @@ def train(opt):
received_signals = [];received_labels = []
sample_num = 1000
sample_num = 5000
for i in range(category_num):
samples = os.listdir(os.path.join(opt.rec_tmp,categorys[i]))
random.shuffle(samples)
......@@ -55,7 +55,7 @@ def train(opt):
signal_ori[point] = float(txt_split[point])
for x in range(sample_num//len(samples)):
ran = random.randint(1000, len(signal_ori)-2000-1)
ran = random.randint(0, len(signal_ori)-2000-1)
this_signal = signal_ori[ran:ran+2000]
this_signal = arr.normliaze(this_signal,'5_95',truncated=4)
......@@ -63,25 +63,28 @@ def train(opt):
received_labels.append(i)
received_signals = np.array(received_signals).reshape(-1,opt.input_nc,opt.loadsize)
received_labels = np.array(received_labels).reshape(-1,1)
received_labels = np.array(received_labels).reshape(-1)
received_signals_train,received_labels_train,received_signals_eval,received_labels_eval=\
dataloader.segment_dataset(received_signals, received_labels, 0.8,random=False)
print(received_signals_train.shape,received_signals_eval.shape)
dataloader.segment_traineval_dataset(received_signals, received_labels, 0.8,random=False)
#print(received_signals_train.shape,received_signals_eval.shape)
'''merge data'''
signals_train,labels_train = dataloader.del_labels(ori_signals_train,ori_labels_train, np.linspace(0, category_num-1,category_num,dtype=np.int64))
signals_eval,labels_eval = dataloader.del_labels(ori_signals_eval,ori_labels_eval, np.linspace(0, category_num-1,category_num,dtype=np.int64))
signals_train = np.concatenate((signals_train, received_signals_train))
#print(labels_train.shape, received_labels_train.shape)
labels_train = np.concatenate((labels_train, received_labels_train))
signals_eval = np.concatenate((signals_eval, received_signals_eval))
labels_eval = np.concatenate((labels_eval, received_labels_eval))
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train)
opt = options.get_auto_options(opt, signals_train, labels_train)
train_sequences = np.linspace(0, len(labels_train)-1,len(labels_train),dtype=np.int64)
eval_sequences = np.linspace(0, len(labels_eval)-1,len(labels_eval),dtype=np.int64)
print('train.shape:',signals_train.shape,'eval.shape:',signals_eval.shape)
print('train_label_cnt:',label_cnt,'eval_label_cnt:',statistics.label_statistics(labels_eval))
for epoch in range(opt.epochs):
t1 = time.time()
......@@ -132,4 +135,4 @@ def handlepost():
return {'return':'error'}
app.run("0.0.0.0", port= 4000, debug=False)
app.run("0.0.0.0", port= 4000, debug=True)
......@@ -8,6 +8,7 @@ import warnings
warnings.filterwarnings("ignore")
from util import util,transformer,dataloader,statistics,plot,options
from data import augmenter
from models import core
opt = options.Options().getparse()
......@@ -26,9 +27,11 @@ labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
#----------------------------Load Data----------------------------
t1 = time.time()
signals,labels = dataloader.loaddataset(opt)
if opt.gan:
signals,labels = augmenter.gan(opt,signals,labels)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals)
opt = options.get_auto_options(opt, signals, labels)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold,opt.fold_index)
t2 = time.time()
print('Cost time: %.2f'% (t2-t1),'s')
......
......@@ -6,20 +6,23 @@ def interp(y, length):
x = np.linspace(0, len(y)-1,num = length)
return np.interp(x, xp, fp)
def pad(data, padding, mode = 'zero'):
if mode == 'zero':
def pad(data,padding,mod='zero'):
if mod == 'zero':
pad_data = np.zeros(padding, dtype = data.dtype)
return np.append(data, pad_data)
elif mode == 'repeat':
elif mod == 'repeat':
out_data = data.copy()
repeat_num = int(padding/len(data))
for i in range(repeat_num):
out_data = np.append(out_data, data)
pad_data = data[:padding-repeat_num*len(data)]
return np.append(out_data, pad_data)
elif mod == 'reflect':
pad_data = data[::-1][:padding]
return np.append(data, pad_data)
def normliaze(data, mode = 'norm', sigma = 0, dtype=np.float32, truncated = 2):
'''
mode: norm | std | maxmin | 5_95
......
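A quick check of the three pad modes (calling the pad defined above; expected outputs shown as comments):

import numpy as np
x = np.array([1, 2, 3])
print(pad(x, 2, mod='zero'))     # [1 2 3 0 0]
print(pad(x, 4, mod='repeat'))   # [1 2 3 1 2 3 1]
print(pad(x, 2, mod='reflect'))  # [1 2 3 3 2]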
......@@ -20,7 +20,7 @@ def del_labels(signals,labels,dels):
return signals,labels
def segment_dataset(signals,labels,a=0.8,random=True):
def segment_traineval_dataset(signals,labels,a=0.8,random=True):
length = len(labels)
if random:
transformer.shuffledata(signals, labels)
......@@ -50,9 +50,6 @@ def segment_dataset(signals,labels,a=0.8,random=True):
cnt += label_cnt[i]
return signals_train,labels_train,signals_eval,labels_eval
def balance_label(signals,labels):
label_sta,_,label_num = statistics.label_statistics(labels)
......@@ -86,17 +83,44 @@ def balance_label(signals,labels):
cnt +=1
return new_signals,new_labels
#load all data in datasets
def loaddataset(opt):
print('Loading dataset...')
signals = np.load(os.path.join(opt.dataset_dir,'signals.npy'))
labels = np.load(os.path.join(opt.dataset_dir,'labels.npy'))
num,ch,size = signals.shape
# normliaze
if opt.normliaze != 'None':
for i in range(signals.shape[0]):
for j in range(signals.shape[1]):
for i in range(num):
for j in range(ch):
signals[i][j] = arr.normliaze(signals[i][j], mode = opt.normliaze, truncated=5)
# filter
if opt.filter != 'None':
for i in range(num):
for j in range(ch):
if opt.filter == 'fft':
signals[i][j] = dsp.fft_filter(signals[i][j], opt.filter_fs, opt.filter_fc,type = opt.filter_mod)
elif opt.filter == 'iir':
signals[i][j] = dsp.bpf(signals[i][j], opt.filter_fs, opt.filter_fc[0], opt.filter_fc[1], numtaps=3, mode='iir')
elif opt.filter == 'fir':
signals[i][j] = dsp.bpf(signals[i][j], opt.filter_fs, opt.filter_fc[0], opt.filter_fc[1], numtaps=101, mode='fir')
# wave filter
if opt.wave != 'None':
for i in range(num):
for j in range(ch):
signals[i][j] = dsp.wave_filter(signals[i][j],opt.wave,opt.wave_level,opt.wave_usedcoeffs)
# append each channel's FFT spectrum as an extra channel (frequency-domain information)
if opt.augment_fft:
new_signals = np.zeros((num,ch*2,size), dtype=np.float32)
new_signals[:,:ch,:] = signals
for i in range(num):
for j in range(ch):
new_signals[i,ch+j,:] = dsp.fft(signals[i,j,:],half=False)
signals = new_signals
if opt.fold_index == 'auto':
transformer.shuffledata(signals,labels)
......
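The --augment_fft branch doubles the channel count by appending one spectrum per input channel. A standalone numpy sketch of the stacking, with a plain |FFT| standing in for the repo's dsp.fft:

import numpy as np
num, ch, size = 8, 2, 256
signals = np.random.rand(num, ch, size).astype(np.float32)
new_signals = np.zeros((num, ch*2, size), dtype=np.float32)
new_signals[:, :ch, :] = signals
for i in range(num):
    for j in range(ch):
        new_signals[i, ch+j, :] = np.abs(np.fft.fft(signals[i, j, :]))  # stand-in for dsp.fft(..., half=False)
print(new_signals.shape)   # (8, 4, 256)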
......@@ -2,6 +2,7 @@ import scipy.signal
import scipy.fftpack as fftpack
import numpy as np
import pywt
from . import array_operation as arr
def sin(f,fs,time):
x = np.linspace(0, 2*np.pi*f*time, fs*time)
......@@ -24,7 +25,7 @@ def medfilt(signal,x):
def cleanoffset(signal):
return signal - np.mean(signal)
def showfreq(signal,fs,fc=0):
def showfreq(signal,fs,fc=0,db=False):
"""
return f,fft
"""
......@@ -34,7 +35,23 @@ def showfreq(signal,fs,fc=0):
kc = int(len(signal)/fs*fc)
signal_fft = np.abs(scipy.fftpack.fft(signal))
f = np.linspace(0,fs/2,num=int(len(signal_fft)/2))
return f[:kc],signal_fft[0:int(len(signal_fft)/2)][:kc]
out_f = f[:kc]
out_fft = signal_fft[0:int(len(signal_fft)/2)][:kc]
if db:
out_fft = 20*np.log10(out_fft/np.max(out_fft))
out_fft = out_fft-np.max(out_fft)
out_fft = np.clip(out_fft,-100,0)
return out_f,out_fft
def fft(signal,half = True,db=True,normliaze=True):
signal_fft = np.abs(scipy.fftpack.fft(signal))
if half:
signal_fft = signal_fft[:len(signal_fft)//2]
if db:
signal_fft = 20*np.log10(signal_fft)
if normliaze:
signal_fft = arr.normliaze(signal_fft,mode = '5_95',truncated = 4)
return signal_fft
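Example use of the new fft helper (a sketch; normliaze=False keeps it independent of array_operation):

import numpy as np
fs = 1000
t = np.arange(fs)/fs                         # 1 s of data -> bin index equals Hz
sig = np.sin(2*np.pi*50*t)
spec = fft(sig, half=True, db=True, normliaze=False)
print(int(np.argmax(spec)))                  # 50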
def bpf(signal, fs, fc1, fc2, numtaps=3, mode='iir'):
if mode == 'iir':
......@@ -45,7 +62,13 @@ def bpf(signal, fs, fc1, fc2, numtaps=3, mode='iir'):
return scipy.signal.lfilter(b, a, signal)
def wave_filter(signal,wave,level,usedcoeffs):
'''
wave : wavelet name (eg. dbN, symN, haar, gaus, mexh)
level : decomposition level
usedcoeffs : coeffs used for reconstruction, eg. when level = 6, usedcoeffs=[1,1,0,0,0,0,0] reconstructs the signal from cA6 and cD6
'''
coeffs = pywt.wavedec(signal, wave, level=level)
#[cAn, cDn, cDn-1, …, cD2, cD1]
for i in range(len(usedcoeffs)):
if usedcoeffs[i] == 0:
coeffs[i] = np.zeros_like(coeffs[i])
......
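Per the docstring and the coefficient ordering noted above, keeping only the approximation and the coarsest detail acts as a low-pass reconstruction; a short usage sketch (assuming pywt is installed):

import numpy as np
sig = np.random.rand(1024)
# level=5 -> pywt.wavedec returns [cA5, cD5, cD4, cD3, cD2, cD1]
low_passed = wave_filter(sig, 'db4', 5, [1, 1, 0, 0, 0, 0])   # keep cA5 and cD5 only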
......@@ -2,7 +2,7 @@ import argparse
import os
import time
import numpy as np
from . import util,dsp,plot
from . import util,dsp,plot,statistics
class Options():
def __init__(self):
......@@ -19,12 +19,38 @@ class Options():
self.parser.add_argument('--finesize', type=str, default='auto', help='crop your data into this size')
self.parser.add_argument('--label_name', type=str, default='auto',help='name of labels,example:"a,b,c,d,e,f"')
# ------------------------Dataset------------------------
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
# ------------------------Preprocessing------------------------
self.parser.add_argument('--normliaze', type=str, default='5_95', help='mode of normliaze, 5_95 | maxmin | None')
self.parser.add_argument('--k_fold', type=int, default=0,help='number of folds for k-fold cross-validation. If 0 or 1, no k-fold; 80% is used for training and the rest for evaluation')
# filter
self.parser.add_argument('--filter', type=str, default='None', help='type of filter, fft | fir | iir |None')
self.parser.add_argument('--filter_mod', type=str, default='bandpass', help='mode of fft_filter, bandpass | bandstop')
self.parser.add_argument('--filter_fs', type=int, default=1000, help='fs of filter')
self.parser.add_argument('--filter_fc', type=str, default='[]', help='fc of filter, eg. [0.1,10]')
# filter by wavelet
self.parser.add_argument('--wave', type=str, default='None', help='wavelet name string, wavelet(eg. dbN symN haar gaus mexh) | None')
self.parser.add_argument('--wave_level', type=int, default=5, help='decomposition level')
self.parser.add_argument('--wave_usedcoeffs', type=str, default='[]', help='Coeff used for reconstruction, \
eg. when level = 6 usedcoeffs=[1,1,0,0,0,0,0] : reconstruct signal with cA6, cD6')
self.parser.add_argument('--wave_channel', action='store_true', help='if specified, reconstruct each wavelet coefficient as a separate input channel.')
# ------------------------Data Augmentation------------------------
# base
self.parser.add_argument('--augment', type=str, default='all', help='all | any comma-separated subset of scale,flip,amp,noise')
# fft channel --> append each channel's spectrum as an extra channel
self.parser.add_argument('--augment_fft', action='store_true', help='if specified, append FFT spectra as extra channels to add frequency-domain information')
# GAN augmentation; currently only supported when k_fold = 0 or 1
self.parser.add_argument('--gan', action='store_true', help='if specified, use a GAN to augment the dataset')
self.parser.add_argument('--gan_lr', type=float, default=0.0002,help='learning rate')
self.parser.add_argument('--gan_augment_times', type=float, default=2,help='expand each selected label to this many times its original size with DCGAN-generated samples')
self.parser.add_argument('--gan_latent_dim', type=int, default=100,help='dimensionality of the latent space')
self.parser.add_argument('--gan_labels', type=str, default='[]',help='which labels will be augmented by the DCGAN, eg: [0,1,2,3]')
self.parser.add_argument('--gan_epochs', type=int, default=100,help='number of epochs of gan training')
# ------------------------Dataset------------------------
"""--fold_index
5-fold:
Cut the dataset into sub-sets using the indices, then run k-fold on those sub-sets
......@@ -32,7 +58,7 @@ class Options():
If input: [2,4,6,7]
when len(dataset) == 10
sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]
---------------------------------------------------------------
-------
No-fold:
If input 'auto', it will shuffle the dataset, then use 80% for training and the rest for evaluation
If input: [5]
......@@ -41,6 +67,10 @@ class Options():
"""
self.parser.add_argument('--fold_index', type=str, default='auto',
help='where to fold, eg. when 5-fold and input: [2,4,6,7] -> sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]')
self.parser.add_argument('--k_fold', type=int, default=0,help='number of folds for k-fold cross-validation. If 0 or 1, no k-fold; 80% is used for training and the rest for evaluation')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
self.parser.add_argument('--mergelabel', type=str, default='None',
help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" -> label(0,1,4) regard as 0,label(2,3,5) regard as 1')
self.parser.add_argument('--mergelabel_name', type=str, default='None',help='name of labels,example:"a,b,c,d,e,f"')
......@@ -118,6 +148,15 @@ class Options():
if os.path.isfile(os.path.join(self.opt.dataset_dir,'index.npy')):
self.opt.fold_index = (np.load(os.path.join(self.opt.dataset_dir,'index.npy'))).tolist()
if self.opt.augment == 'all':
self.opt.augment = ["scale","filp","amp","noise"]
else:
self.opt.augment = str2list(self.opt.augment)
self.opt.filter_fc = eval(self.opt.filter_fc)
self.opt.wave_usedcoeffs = eval(self.opt.wave_usedcoeffs)
self.opt.gan_labels = eval(self.opt.gan_labels)
self.opt.mergelabel = eval(self.opt.mergelabel)
if self.opt.mergelabel_name != 'None':
self.opt.mergelabel_name = self.opt.mergelabel_name.replace(" ", "").split(",")
......@@ -141,9 +180,24 @@ class Options():
return self.opt
def get_auto_options(opt,label_cnt_per,label_num,signals):
def str2list(string,out_type = 'string'):
out_list = []
string = string.replace(' ','').replace('[','').replace(']','')
strings = string.split(',')
for string in strings:
if out_type == 'string':
out_list.append(string)
elif out_type == 'int':
out_list.append(int(string))
elif out_type == 'float':
out_list.append(float(string))
return out_list
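For example, with the str2list above:

str2list('[scale, flip]')               # -> ['scale', 'flip']
str2list('[0.1,10]', out_type='float')  # -> [0.1, 10.0]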
def get_auto_options(opt,signals,labels):
shape = signals.shape
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
if opt.label =='auto':
opt.label = label_num
if opt.input_nc =='auto':
......@@ -175,7 +229,6 @@ def get_auto_options(opt,label_cnt_per,label_num,signals):
elif not isinstance(opt.label_name,list):
opt.label_name = opt.label_name.replace(" ", "").split(",")
# check stft spectrum
if opt.model_type =='2d':
spectrums = []
......
......@@ -248,6 +248,27 @@ def draw_autoencoder_result(true_signal,pred_signal,opt):
plt.savefig(os.path.join(opt.save_dir,'autoencoder_result.png'))
plt.close('all')
def draw_gan_result(real_signal,gan_signal,opt):
if real_signal.shape[0]>=4:
fig = plt.figure(figsize=(18,4))
for i in range(4):
plt.subplot(2,4,i+1)
plt.plot(real_signal[i][0])
plt.title('real')
for i in range(4):
plt.subplot(2,4,4+i+1)
plt.plot(gan_signal[i][0])
plt.title('gan')
else:
plt.subplot(211)
plt.plot(real_signal[0][0])
plt.title('real')
plt.subplot(212)
plt.plot(gan_signal[0][0])
plt.title('gan')
plt.savefig(os.path.join(opt.save_dir,'gan_result.png'))
plt.close('all')
def showscatter3d(data):
label_cnt,_,label_num = label_statistics(data[:,3])
......
......@@ -4,7 +4,6 @@ import numpy as np
import torch
from . import dsp
from . import array_operation as arr
# import dsp
def shuffledata(data,target):
state = np.random.get_state()
......@@ -64,24 +63,40 @@ def ToTensor(data,target=None,gpu_id=0):
data = data.cuda()
return data
def random_transform_1d(data,finesize,test_flag):
batch_size,ch,length = data.shape
def random_transform_1d(data,opt,test_flag):
batchsize,ch,length = data.shape
if test_flag:
move = int((length-finesize)*0.5)
result = data[:,:,move:move+finesize]
move = int((length-opt.finesize)*0.5)
result = data[:,:,move:move+opt.finesize]
else:
#random scale
if 'scale' in opt.augment:
length = np.random.randint(opt.finesize, int(length*1.1), dtype=np.int64)
result = np.zeros((batchsize,ch,length))
for i in range(batchsize):
for j in range(ch):
result[i][j] = arr.interp(data[i][j], length)
data = result
#random crop
move = int((length-finesize)*random.random())
result = data[:,:,move:move+finesize]
move = int((length-opt.finesize)*random.random())
result = data[:,:,move:move+opt.finesize]
#random flip
if random.random()<0.5:
result = result[:,:,::-1]
if 'flip' in opt.augment:
if random.random()<0.5:
result = result[:,:,::-1]
#random amp
result = result*random.uniform(0.9,1.1)
if 'amp' in opt.augment:
result = result*random.uniform(0.9,1.1)
#add noise
# noise = np.random.rand(ch,finesize)
# result = result + (noise-0.5)*0.01
if 'noise' in opt.augment:
noise = np.random.rand(ch,opt.finesize)
result = result + (noise-0.5)*0.01
return result
def random_transform_2d(img,finesize = (224,224),test_flag = True):
......@@ -104,18 +119,19 @@ def random_transform_2d(img,finesize = (224,244),test_flag = True):
def ToInputShape(data,opt,test_flag = False):
#data = data.astype(np.float32)
_batchsize,_ch,_size = data.shape
if opt.model_type == '1d':
result = random_transform_1d(data, opt.finesize, test_flag=test_flag)
result = random_transform_1d(data, opt, test_flag = test_flag)
elif opt.model_type == '2d':
result = []
h,w = opt.stft_shape
for i in range(opt.batchsize):
for i in range(_batchsize):
for j in range(opt.input_nc):
spectrum = dsp.signal2spectrum(data[i][j],opt.stft_size,opt.stft_stride, opt.stft_n_downsample, not opt.stft_no_log)
spectrum = random_transform_2d(spectrum,(h,int(w*0.9)),test_flag=test_flag)
result.append(spectrum)
result = (np.array(result)).reshape(opt.batchsize,opt.input_nc,h,int(w*0.9))
result = (np.array(result)).reshape(_batchsize,opt.input_nc,h,int(w*0.9))
return result