提交 fbbe82c2 编写于 作者: H hypox64

Support mobilenet. Fix some BUG.

上级 21105adc
......@@ -138,5 +138,4 @@ dmypy.json
/train_backup.py
*.pth
*.edf
*log*
*.png
\ No newline at end of file
*log*
\ No newline at end of file
<div align="center">
<img src="./imgs/compare.png " alt="image" style="zoom:100%;" />
</div>
# candock
[这原本是一个用于记录毕业设计的日志仓库](<https://github.com/HypoX64/candock/tree/Graduation_Project>),其目的是尝试多种不同的深度神经网络结构(如LSTM,ResNet,DFCNN等)对单通道EEG进行自动化睡眠阶段分期.<br>目前,项目重点将转变为如何建立一个通用的一维时序信号分析,分类框架.<br>它将包含多种网络结构,并提供数据预处理,读取,训练,评估,测试等功能.<br>
一些训练时的输出样例: [heatmap](./image/heatmap_eg.png) [running_err](./image/running_err_eg.png) [log.txt](./docs/log_eg.txt)
| English | [中文版](./README_CN.md) |<br>
A time series signal analysis and classification framework.<br>
It contain multiple network and provide data preprocessing, reading, training, evaluation, testing and other functions.<br>
Some output examples: [heatmap](./image/heatmap_eg.png) [running_err](./image/running_err_eg.png) [log.txt](./docs/log_eg.txt)<br>Supported network:<br>
## 注意
为了适应新的项目,代码已被大幅更改,不能确保仍然能正常运行如sleep-edfx等睡眠数据集,如果仍然需要运行,请参照下文按照输入格式标准自行加载数据,如果有时间我会修复这个问题。
当然,如果需要加载睡眠数据集也可以直接使用[老的版本](https://github.com/HypoX64/candock/tree/f24cc44933f494d2235b3bf965a04cde5e6a1ae9)<br>
感谢[@swalltail99](https://github.com/swalltail99)指出的错误,为了适应sleep-edfx数据集的读取,使用这个版本的代码时,请安装mne==0.18.0<br>
>1d
>
>>lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d, micro_multi_scale_resnet_1d
>2d(stft spectrum)
>
>>mobilenet, resnet18, resnet50, resnet101, densenet121, densenet201, squeezenet, dfcnn, multi_scale_resnet,
## A example: Use EEG to classify sleep stage
[sleep-edfx](https://github.com/HypoX64/candock/tree/f24cc44933f494d2235b3bf965a04cde5e6a1ae9)<br>
Thank [@swalltail99](https://github.com/swalltail99)for the bug. In other to load sleep-edfx dataset,please install mne==0.18.0<br>
```bash
pip install mne==0.18.0
```
## Getting Started
### Prerequisites
- Linux, Windows,mac
......@@ -18,7 +29,7 @@ pip install mne==0.18.0
- Python 3
- Pytroch 1.0+
### Dependencies
This code depends on torchvision, numpy, scipy , matplotlib,available via pip install.<br>
This code depends on torchvision, numpy, scipy , matplotlib, available via pip install.<br>
For example:<br>
```bash
......@@ -41,7 +52,13 @@ python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --
# if you want to use cpu to train, please input --gpu_id -1
```
* More [options](./util/options.py).
#### Use your own data to train
### Test
```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# if you want to use cpu to test, please input --gpu_id -1
```
## Training with your own dataset
* step1: Generate signals.npy and labels.npy in the following format.
```python
#1.type:numpydata signals:np.float64 labels:np.int64
......@@ -51,10 +68,5 @@ python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --
signals = np.zeros((10,1,10),dtype='np.float64')
labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
```
* step2: input ```--dataset_dir your_dataset_dir``` when running code.
### Test
```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# if you want to use cpu to test, please input --gpu_id -1
```
* step2: input ```--dataset_dir "your_dataset_dir"``` when running code.
<div align="center">
<img src="./imgs/compare.png " alt="image" style="zoom:100%;" />
</div>
# candock
| [English](./README.md) | 中文版 |<br>
一个通用的一维时序信号分析,分类框架.<br>
它将包含多种网络结构,并提供数据预处理,读取,训练,评估,测试等功能.<br>
一些训练时的输出样例: [heatmap](./image/heatmap_eg.png) [running_err](./image/running_err_eg.png) [log.txt](./docs/log_eg.txt)<br>
目前支持的网络结构:<br>
>1d
>
>>lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d, micro_multi_scale_resnet_1d
>2d(stft spectrum)
>
>>mobilenet, resnet18, resnet50, resnet101, densenet121, densenet201, squeezenet, dfcnn, multi_scale_resnet,
## 关于EEG睡眠分期数据的实例
为了适应新的项目,代码已被大幅更改,不能正常运行如sleep-edfx等睡眠数据集,如果仍然需要运行,请参照下文按照输入格式标准自行加载数据,如果有时间我会修复这个问题。
当然,如果需要加载睡眠数据集也可以直接使用[老的版本](https://github.com/HypoX64/candock/tree/f24cc44933f494d2235b3bf965a04cde5e6a1ae9)<br>
感谢[@swalltail99](https://github.com/swalltail99)指出的错误,为了适应sleep-edfx数据集的读取,使用这个版本的代码时,请安装mne==0.18.0<br>
```bash
pip install mne==0.18.0
```
## 入门
### 前提要求
- Linux, Windows,mac
- CPU or NVIDIA GPU + CUDA CuDNN
- Python 3
- Pytroch 1.0+
### 依赖
This code depends on torchvision, numpy, scipy , matplotlib, available via pip install.<br>
For example:<br>
```bash
pip3 install matplotlib
```
### 克隆仓库:
```bash
git clone https://github.com/HypoX64/candock
cd candock
```
### 下载数据集以及预训练模型
[[Google Drive]](https://drive.google.com/open?id=1NTtLmT02jqlc81lhtzQ7GlPK8epuHfU5) [[百度云,y4ks]](https://pan.baidu.com/s/1WKWZL91SekrSlhOoEC1bQA)
* 数据集包括 signals.npy(shape:18207, 1, 2000) 以及 labels.npy(shape:18207) 可以使用"np.load()"加载
* 样本量:18207, 通道数:1, 每个样本的长度:2000, 总类别数:50
* Top1 err: 2.09%
### 训练
```bash
python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5
# 如果需要使用cpu进行训练, 请输入 --gpu_id -1
```
* 更多可选参数 [options](./util/options.py).
### 测试
```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# 如果需要使用cpu进行训练, 请输入 --gpu_id -1
```
## 使用自己的数据进行训练
* step1: 按照如下格式生成 signals.npy 以及 labels.npy.
```python
#1.type:numpydata signals:np.float64 labels:np.int64
#2.shape signals:[num,ch,length] labels:[num]
#num:samples_num, ch :channel_num, num:length of each sample
#for example:
signals = np.zeros((10,1,10),dtype='np.float64')
labels = np.array([0,0,0,0,0,1,1,1,1,1]) #0->class0 1->class1
```
* step2: 输入 ```--dataset_dir "your_dataset_dir"``` 当运行代码的时候.
......@@ -37,7 +37,7 @@ class Core(object):
self.criterion_class = nn.CrossEntropyLoss(self.opt.weight)
self.criterion_auto = nn.MSELoss()
self.epoch = 1
self.plot_result = {'train':[],'eval':[]}
self.plot_result = {'train':[],'eval':[],'F1':[]}
self.confusion_mats = []
self.test_flag = True
......@@ -145,13 +145,14 @@ class Core(object):
signal,label = self.queue.get()
signal,label = transformer.ToTensor(signal,label,gpu_id =self.opt.gpu_id)
with torch.no_grad():
loss,features,confusion_mat=self.forward(signal, label, features, confusion_mat)
loss,features,confusion_mat = self.forward(signal, label, features, confusion_mat)
epoch_loss += loss.item()
if self.opt.model_name != 'autoencoder':
recall,acc,sp,err,k = statistics.report(confusion_mat)
#plot.draw_heatmap(confusion_mat,self.opt,name = 'current_eval')
print('epoch:'+str(self.epoch),' macro-prec,reca,F1,err,kappa: '+str(statistics.report(confusion_mat)))
self.plot_result['F1'].append(statistics.report(confusion_mat)[2])
else:
plot.draw_autoencoder_result(signal.data.cpu().numpy(), out.data.cpu().numpy(),self.opt)
print('epoch:'+str(self.epoch),' loss: '+str(round(epoch_loss/i,5)))
......
......@@ -30,9 +30,10 @@ def creatnet(opt):
#---------------------------------2d---------------------------------
elif name == 'dfcnn':
net = dfcnn.dfcnn(num_classes = opt.label)
net = dfcnn.dfcnn(num_classes = opt.label, input_nc = opt.input_nc)
elif name == 'multi_scale_resnet':
net = multi_scale_resnet.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
net = multi_scale_resnet.Multi_Scale_ResNet(input_nc = opt.input_nc, num_classes=opt.label)
elif name in ['resnet101','resnet50','resnet18']:
if name =='resnet101':
net = resnet.resnet101(pretrained=True)
......@@ -47,10 +48,14 @@ def creatnet(opt):
elif 'densenet' in name:
if name =='densenet121':
net = densenet.densenet121(pretrained=True,num_classes=opt.label)
net = densenet.densenet121(pretrained=False,num_classes = opt.label)
elif name == 'densenet201':
net = densenet.densenet201(pretrained=True,num_classes=opt.label)
elif name =='squeezenet':
net = squeezenet.squeezenet1_1(pretrained=True,num_classes=opt.label,inchannel = 1)
net = densenet.densenet201(pretrained=False,num_classes = opt.label)
net.features.conv0 = nn.Conv2d(opt.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False)
elif name == 'squeezenet':
net = squeezenet.squeezenet1_1(pretrained=False,num_classes = opt.label,inchannel = opt.input_nc)
elif name == 'mobilenet':
net = mobilenet.mobilenet_v2(pretrained=False, num_classes = opt.label, input_nc = opt.input_nc)
return net
\ No newline at end of file
......@@ -75,7 +75,7 @@ class DenseNet(nn.Module):
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
......
......@@ -4,10 +4,10 @@ import torch.nn.functional as F
class dfcnn(nn.Module):
def __init__(self, num_classes):
def __init__(self, num_classes, input_nc):
super(dfcnn, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, 1, bias=False),
nn.Conv2d(input_nc, 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.Conv2d(32, 32, 3, 1, 1, bias=False),
......
......@@ -42,7 +42,7 @@ class InvertedResidual(nn.Module):
class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1.0):
def __init__(self, num_classes=1000, input_nc=3, width_mult=1.0):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
......@@ -61,7 +61,7 @@ class MobileNetV2(nn.Module):
# building first layer
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * max(1.0, width_mult))
features = [ConvBNReLU(3, input_channel, stride=2)]
features = [ConvBNReLU(input_nc, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = int(c * width_mult)
......
......@@ -2,18 +2,18 @@ import torch
from torch import nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, inchannel, outchannel,kernel_size,stride=2):
def __init__(self, input_nc, outchannel,kernel_size,stride=2):
super(ResidualBlock, self).__init__()
self.stride = stride
self.conv = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=int((kernel_size-1)/2), bias=False),
nn.Conv2d(input_nc, outchannel, kernel_size=kernel_size, stride=stride, padding=int((kernel_size-1)/2), bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
nn.Conv2d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=int((kernel_size-1)/2), bias=False),
nn.BatchNorm2d(outchannel)
)
self.shortcut = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=2, bias=False),
nn.Conv2d(input_nc, outchannel, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(outchannel)
)
......@@ -42,10 +42,10 @@ class Route(nn.Module):
return x
class Multi_Scale_ResNet(nn.Module):
def __init__(self, inchannel, num_classes):
def __init__(self, input_nc, num_classes):
super(Multi_Scale_ResNet, self).__init__()
self.pre_conv = nn.Sequential(
nn.Conv2d(inchannel, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.Conv2d(input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
......
......@@ -30,22 +30,22 @@ if opt.separated:
signals_train,labels_train,signals_eval,labels_eval = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels_train)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train.shape)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals_train)
train_sequences= transformer.k_fold_generator(len(labels_train),opt.k_fold,opt.separated)
eval_sequences= transformer.k_fold_generator(len(labels_eval),opt.k_fold,opt.separated)
else:
signals,labels = dataloader.loaddataset(opt)
label_cnt,label_cnt_per,label_num = statistics.label_statistics(labels)
util.writelog('label statistics: '+str(label_cnt),opt,True)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals)
train_sequences,eval_sequences = transformer.k_fold_generator(len(labels),opt.k_fold)
t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
print('Cost time: %.2f'% (t2-t1),'s')
core = core.Core(opt)
core.network_init(printflag=True)
print('begin to train ...')
print('Begin to train ...')
fold_final_confusion_mat = np.zeros((opt.label,opt.label), dtype=int)
for fold in range(opt.k_fold):
if opt.k_fold != 1:util.writelog('------------------------------ k-fold:'+str(fold+1)+' ------------------------------',opt,True)
......@@ -69,7 +69,7 @@ for fold in range(opt.k_fold):
#save result
if opt.model_name != 'autoencoder':
pos = core.plot_result['eval'].index(min(core.plot_result['eval']))-1
pos = core.plot_result['F1'].index(max(core.plot_result['F1']))
final_confusion_mat = core.confusion_mats[pos]
if opt.k_fold==1:
statistics.statistics(final_confusion_mat, opt, 'final', 'final_eval')
......
......@@ -20,7 +20,7 @@ def pad(data, padding, mode = 'zero'):
pad_data = data[:padding-repeat_num*len(data)]
return np.append(out_data, pad_data)
def normliaze(data, mode = 'norm', sigma = 0, dtype=np.float64, truncated = 2):
def normliaze(data, mode = 'norm', sigma = 0, dtype=np.float32, truncated = 2):
'''
mode: norm | std | maxmin | 5_95
dtype : np.float64,np.float16...
......
......@@ -6,6 +6,7 @@ import scipy.io as sio
import numpy as np
from . import dsp,transformer,statistics
from . import array_operation as arr
def del_labels(signals,labels,dels):
......@@ -92,15 +93,27 @@ def balance_label(signals,labels):
#load all data in datasets
def loaddataset(opt,shuffle = False):
print('Loading dataset...')
if opt.separated:
signals_train = np.load(opt.dataset_dir+'/signals_train.npy')
labels_train = np.load(opt.dataset_dir+'/labels_train.npy')
signals_eval = np.load(opt.dataset_dir+'/signals_eval.npy')
labels_eval = np.load(opt.dataset_dir+'/labels_eval.npy')
if opt.normliaze != 'None':
for i in range(signals_train.shape[0]):
for j in range(signals_train.shape[1]):
signals_train[i][j] = arr.normliaze(signals_train[i][j], mode = opt.normliaze, truncated=5)
for i in range(signals_eval.shape[0]):
for j in range(signals_eval.shape[1]):
signals_eval[i][j] = arr.normliaze(signals_eval[i][j], mode = opt.normliaze, truncated=5)
else:
signals = np.load(opt.dataset_dir+'/signals.npy')
labels = np.load(opt.dataset_dir+'/labels.npy')
if opt.normliaze != 'None':
for i in range(signals.shape[0]):
for j in range(signals.shape[1]):
signals[i][j] = arr.normliaze(signals[i][j], mode = opt.normliaze, truncated=5)
if not opt.no_shuffle:
transformer.shuffledata(signals,labels)
......
......@@ -7,7 +7,7 @@ def sin(f,fs,time):
return np.sin(x)
def downsample(signal,fs1=0,fs2=0,alpha=0,mod = 'just_down'):
if alpha ==0:
if alpha == 0:
alpha = int(fs1/fs2)
if mod == 'just_down':
return signal[::alpha]
......@@ -74,8 +74,10 @@ def energy(signal,kernel_size,stride,padding = 0):
energy[i] = rms(signal[i*stride:i*stride+kernel_size])
return energy
def signal2spectrum(data,window_size,stride,log = True):
def signal2spectrum(data,window_size, stride, n_downsample=1, log = True, log_alpha = 0.1):
# window : ('tukey',0.5) hann
if n_downsample != 1:
data = downsample(data,alpha=n_downsample)
zxx = scipy.signal.stft(data, window='hann', nperseg=window_size,noverlap=window_size-stride)[2]
spectrum = np.abs(zxx)
......@@ -83,11 +85,25 @@ def signal2spectrum(data,window_size,stride,log = True):
if log:
spectrum = np.log1p(spectrum)
h = window_size//2+1
tmp = np.linspace(0, h-1,num=h,dtype=np.int64)
index = np.log1p(tmp)*(h/np.log1p(h))
x = np.linspace(h*log_alpha, h-1,num=h+1,dtype=np.int64)
index = (np.log1p(x)-np.log1p(h*log_alpha))/(np.log1p(h)-np.log1p(h*log_alpha))*h
spectrum_new = np.zeros_like(spectrum)
for i in range(h-1):
for i in range(h):
spectrum_new[int(index[i]):int(index[i+1])] = spectrum[i]
spectrum = spectrum_new
spectrum = (spectrum-0.05)/0.25
# spectrum = np.log1p(spectrum)
# h = window_size//2+1
# tmp = np.linspace(0, h-1,num=h,dtype=np.int64)
# index = np.log2(tmp+1)*(h/np.log2(h+1))
# spectrum_new = np.zeros_like(spectrum)
# for i in range(h-1):
# spectrum_new[int(index[i]):int(index[i+1])] = spectrum[i]
# spectrum = spectrum_new
# spectrum = (spectrum-0.05)/0.25
else:
spectrum = (spectrum-0.02)/0.05
return spectrum
\ No newline at end of file
......@@ -2,7 +2,7 @@ import argparse
import os
import time
import numpy as np
from . import util,dsp
from . import util,dsp,plot
class Options():
def __init__(self):
......@@ -18,20 +18,20 @@ class Options():
self.parser.add_argument('--loadsize', type=str, default='auto', help='load data in this size')
self.parser.add_argument('--finesize', type=str, default='auto', help='crop your data into this size')
self.parser.add_argument('--label_name', type=str, default='auto',help='name of labels,example:"a,b,c,d,e,f"')
self.parser.add_argument('--normliaze', type=str, default='5_95', help='mode of normliaze, 5_95 | maxmin | None')
# ------------------------Dataset------------------------
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
self.parser.add_argument('--separated', action='store_true', help='if specified,for preload data, if input, load separated train and test datasets')
self.parser.add_argument('--no_shuffle', action='store_true', help='if specified, do not shuffle data when load(use to evaluate individual differences)')
self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
self.parser.add_argument('--normliaze', type=str, default='5_95', help='mode of normliaze, 5_95 | maxmin | None')
# ------------------------Network------------------------
"""Available Network
1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
micro_multi_scale_resnet_1d,autoencoder
2d: dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
densenet121, densenet201, squeezenet
"""
self.parser.add_argument('--model_name', type=str, default='micro_multi_scale_resnet_1d',help='Choose model lstm...')
......@@ -42,9 +42,13 @@ class Options():
# For autoecoder
self.parser.add_argument('--feature', type=int, default=3, help='number of encoder features')
# For 2d network(stft spectrum)
# Please cheek ./save_dir/spectrum_eg.jpg to change the following parameters
self.parser.add_argument('--stft_size', type=int, default=512, help='length of each fft segment')
self.parser.add_argument('--stft_stride', type=int, default=128, help='stride of each fft segment')
self.parser.add_argument('--stft_n_downsample', type=int, default=1, help='downsample befor stft')
self.parser.add_argument('--stft_no_log', action='store_true', help='if specified, do not log1p spectrum')
self.parser.add_argument('--stft_shape', type=str, default='auto', help='shape of stft. It depend on \
stft_size,stft_stride,stft_n_downsample. Do not input this parameter.')
# ------------------------Training Matters------------------------
self.parser.add_argument('--pretrained', type=str, default='',help='pretrained model path. If not specified, fo not use pretrained model')
......@@ -58,8 +62,6 @@ class Options():
self.parser.add_argument('--mergelabel', type=str, default='None',
help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" -> label(0,1,4) regard as 0,label(2,3,5) regard as 1')
self.parser.add_argument('--mergelabel_name', type=str, default='None',help='name of labels,example:"a,b,c,d,e,f"')
self.parser.add_argument('--plotfreq', type=int, default=100,help='frequency of plotting results')
self.initialized = True
......@@ -87,7 +89,7 @@ class Options():
'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder']:
self.opt.model_type = '1d'
elif self.opt.model_name in ['dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50',
'resnet101','densenet121', 'densenet201', 'squeezenet']:
'resnet101','densenet121', 'densenet201', 'squeezenet', 'mobilenet']:
self.opt.model_type = '2d'
else:
print('\033[1;31m'+'Error: do not support this network '+self.opt.model_name+'\033[0m')
......@@ -122,8 +124,9 @@ class Options():
return self.opt
def get_auto_options(opt,label_cnt_per,label_num,shape):
def get_auto_options(opt,label_cnt_per,label_num,signals):
shape = signals.shape
if opt.label =='auto':
opt.label = label_num
if opt.input_nc =='auto':
......@@ -158,8 +161,16 @@ def get_auto_options(opt,label_cnt_per,label_num,shape):
# check stft spectrum
if opt.model_type =='2d':
h, w = opt.stft_size//2+1, opt.loadsize//opt.stft_stride
print('Shape of stft spectrum h,w:',(h,w))
spectrums = []
data = signals[np.random.randint(0,shape[0]-1)]
for i in range(shape[1]):
spectrums.append(dsp.signal2spectrum(data[i],opt.stft_size, opt.stft_stride, opt.stft_n_downsample, not opt.stft_no_log))
plot.draw_spectrums(spectrums,opt)
opt.stft_shape = spectrums[0].shape
h,w = opt.stft_shape
print('Shape of stft spectrum h,w:',opt.stft_shape)
print('\033[1;37m'+'Please cheek ./save_dir/spectrum_eg.jpg to change parameters'+'\033[0m')
if h<64 or w<64:
print('\033[1;33m'+'Warning: spectrum is too small'+'\033[0m')
if h>512 or w>512:
......
......@@ -247,9 +247,15 @@ def showscatter3d(data):
plt.show()
def draw_spectrum(spectrum,opt):
plt.imshow(spectrum)
plt.savefig(os.path.join(opt.save_dir,'spectrum_eg.png'))
def draw_spectrums(spectrums,opt):
if len(spectrums) > 1:
plt.subplots(figsize=(6.4*2,4.8*2))
for i in range(len(spectrums)):
plt.subplot(len(spectrums)//2+1,2,i+1)
plt.imshow(spectrums[i])
else:
plt.imshow(spectrums[0])
plt.savefig(os.path.join(opt.save_dir,'spectrum_eg.jpg'))
plt.close('all')
......
......@@ -77,15 +77,15 @@ def random_transform_1d(data,finesize,test_flag):
# result = result + (noise-0.5)*0.01
return result
def random_transform_2d(img,finesize = (224,122),test_flag = True):
def random_transform_2d(img,finesize = (224,244),test_flag = True):
h,w = img.shape[:2]
if test_flag:
h_move = 2
h_move = int((h-finesize[0])*0.5)
w_move = int((w-finesize[1])*0.5)
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
else:
#random crop
h_move = int(10*random.random()) #do not loss low freq signal infos
h_move = int((h-finesize[0])*random.random())
w_move = int((w-finesize[1])*random.random())
result = img[h_move:h_move+finesize[0],w_move:w_move+finesize[1]]
#random flip
......@@ -99,24 +99,16 @@ def ToInputShape(data,opt,test_flag = False):
#data = data.astype(np.float32)
if opt.model_type == '1d':
if opt.normliaze != 'None':
for i in range(opt.batchsize):
for j in range(opt.input_nc):
data[i][j] = arr.normliaze(data[i][j],mode = opt.normliaze)
result = random_transform_1d(data, opt.finesize, test_flag=test_flag)
elif opt.model_type == '2d':
result = []
h,w = opt.stft_shape
for i in range(opt.batchsize):
for j in range(opt.input_nc):
spectrum = dsp.signal2spectrum(data[i][j],opt.stft_size,opt.stft_stride, not opt.stft_no_log)
#spectrum = arr.normliaze(spectrum, mode = opt.normliaze)
spectrum = (spectrum-2)/5
# print(spectrum.shape)
#spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
spectrum = dsp.signal2spectrum(data[i][j],opt.stft_size,opt.stft_stride, opt.stft_n_downsample, not opt.stft_no_log)
spectrum = random_transform_2d(spectrum,(h,int(w*0.9)),test_flag=test_flag)
result.append(spectrum)
h,w = spectrum.shape
result = (np.array(result)).reshape(opt.batchsize,opt.input_nc,h,w)
result = (np.array(result)).reshape(opt.batchsize,opt.input_nc,h,int(w*0.9))
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册