提交 085dde45 编写于 作者: H HypoX64

Add mlp

上级 fdbfced6
......@@ -138,6 +138,7 @@ checkpoints/
/train_backup.py
/tools/client_data
/tools/server_data
/trainscript.py
*.pth
*.edf
*log*
\ No newline at end of file
......@@ -42,7 +42,7 @@ class Core(object):
self.test_flag = True
if printflag:
util.writelog('network:\n'+str(self.net),self.opt,True)
#util.writelog('network:\n'+str(self.net),self.opt,True)
show_paramsnumber(self.net,self.opt)
if self.opt.pretrained != '':
......@@ -85,7 +85,8 @@ class Core(object):
self.queue = Queue(self.opt.load_thread*2)
process_batch_num = len(sequences)//self.opt.batchsize//self.opt.load_thread
if process_batch_num == 0:
print('\033[1;33m'+'Warning: too much load thread'+'\033[0m')
if self.epoch == 1:
print('\033[1;33m'+'Warning: too much load thread'+'\033[0m')
self.start_process(signals,labels,sequences)
else:
for i in range(self.opt.load_thread):
......@@ -130,8 +131,8 @@ class Core(object):
loss.backward()
self.optimizer.step()
self.plot_result['train'].append(epoch_loss/i)
plot.draw_loss(self.plot_result,self.epoch+i/(sequences.shape[0]/self.opt.batchsize),self.opt)
self.plot_result['train'].append(epoch_loss/(i+1))
plot.draw_loss(self.plot_result,self.epoch+(i+1)/(sequences.shape[0]/self.opt.batchsize),self.opt)
# if self.opt.model_name != 'autoencoder':
# plot.draw_heatmap(confusion_mat,self.opt,name = 'current_train')
......@@ -142,6 +143,7 @@ class Core(object):
epoch_loss = 0
confusion_mat = np.zeros((self.opt.label,self.opt.label), dtype=int)
np.random.shuffle(sequences)
self.process_pool_init(signals, labels, sequences)
for i in range(len(sequences)//self.opt.batchsize):
signal,label = self.queue.get()
......@@ -160,7 +162,7 @@ class Core(object):
print('epoch:'+str(self.epoch),' macro-prec,reca,F1,err,kappa: '+str(statistics.report(confusion_mat)))
self.plot_result['F1'].append(statistics.report(confusion_mat)[2])
self.plot_result['eval'].append(epoch_loss/i)
self.plot_result['eval'].append(epoch_loss/(i+1))
self.epoch +=1
self.confusion_mats.append(confusion_mat)
......
from torch import nn
from .net_1d import cnn_1d,lstm,resnet_1d,multi_scale_resnet_1d,micro_multi_scale_resnet_1d,autoencoder
from .net_1d import cnn_1d,lstm,resnet_1d,multi_scale_resnet_1d,micro_multi_scale_resnet_1d,autoencoder,mlp
from .net_2d import densenet,dfcnn,mobilenet,resnet,squeezenet,multi_scale_resnet
......@@ -9,6 +9,9 @@ def creatnet(opt):
#encoder
if name =='autoencoder':
net = autoencoder.Autoencoder(opt.input_nc, opt.feature, opt.label,opt.finesize)
#mlp
if name =='mlp':
net = mlp.mlp(opt.input_nc, opt.label, opt.finesize)
#lstm
elif name =='lstm':
net = lstm.lstm(opt.lstm_inputsize,opt.lstm_timestep,input_nc=opt.input_nc,num_classes=opt.label)
......
import torch
from torch import nn
import torch.nn.functional as F
class mlp(nn.Module):
def __init__(self, input_nc,num_classes,datasize):
super(mlp, self).__init__()
self.net = nn.Sequential(
nn.Linear(datasize*input_nc, 128),
nn.Tanh(),
nn.Linear(128, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, num_classes),
)
def forward(self, x):
x = x.view(x.size(0),-1)
x = self.net(x)
return x
\ No newline at end of file
......@@ -132,4 +132,4 @@ def handlepost():
return {'return':'error'}
app.run("0.0.0.0", port= 4000, debug=True)
app.run("0.0.0.0", port= 4000, debug=False)
import scipy.signal
import scipy.fftpack as fftpack
import numpy as np
import pywt
def sin(f,fs,time):
x = np.linspace(0, 2*np.pi*f*time, fs*time)
......@@ -23,10 +24,32 @@ def medfilt(signal,x):
def cleanoffset(signal):
return signal - np.mean(signal)
def bpf_fir(signal,fs,fc1,fc2,numtaps=101):
b=scipy.signal.firwin(numtaps, [fc1, fc2], pass_zero=False,fs=fs)
result = scipy.signal.lfilter(b, 1, signal)
return result
def showfreq(signal,fs,fc=0):
"""
return f,fft
"""
if fc==0:
kc = int(len(signal)/2)
else:
kc = int(len(signal)/fs*fc)
signal_fft = np.abs(scipy.fftpack.fft(signal))
f = np.linspace(0,fs/2,num=int(len(signal_fft)/2))
return f[:kc],signal_fft[0:int(len(signal_fft)/2)][:kc]
def bpf(signal, fs, fc1, fc2, numtaps=3, mode='iir'):
if mode == 'iir':
b,a = scipy.signal.iirfilter(numtaps, [fc1,fc2], fs=fs)
elif mode == 'fir':
b = scipy.signal.firwin(numtaps, [fc1, fc2], pass_zero=False,fs=fs)
a = 1
return scipy.signal.lfilter(b, a, signal)
def wave_filter(signal,wave,level,usedcoeffs):
coeffs = pywt.wavedec(signal, wave, level=level)
for i in range(len(usedcoeffs)):
if usedcoeffs[i] == 0:
coeffs[i] = np.zeros_like(coeffs[i])
return pywt.waverec(coeffs, wave, mode='symmetric', axis=-1)
def fft_filter(signal,fs,fc=[],type = 'bandpass'):
'''
......
......@@ -48,7 +48,7 @@ class Options():
# ------------------------Network------------------------
"""Available Network
1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
micro_multi_scale_resnet_1d,autoencoder
micro_multi_scale_resnet_1d,autoencoder,mlp
2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
densenet121, densenet201, squeezenet
"""
......@@ -100,7 +100,7 @@ class Options():
if self.opt.model_type == 'auto':
if self.opt.model_name in ['lstm', 'cnn_1d', 'resnet18_1d', 'resnet34_1d',
'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder']:
'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder','mlp']:
self.opt.model_type = '1d'
elif self.opt.model_name in ['dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50',
'resnet101','densenet121', 'densenet201', 'squeezenet', 'mobilenet']:
......
import os
import string
import random
import shutil
def randomstr(num):
return ''.join(random.sample(string.ascii_letters + string.digits, num))
......@@ -38,4 +39,10 @@ def loadfile(path):
def savefile(file,path):
wf = open(path,'wb')
wf.write(file)
wf.close()
\ No newline at end of file
wf.close()
def copyfile(src,dst):
try:
shutil.copyfile(src, dst)
except Exception as e:
print(e)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册