options.py 10.0 KB
Newer Older
HypoX64's avatar
HypoX64 已提交
1 2
import argparse
import os
3
import time
H
hypox64 已提交
4
import numpy as np
H
hypox64 已提交
5
from . import util,dsp,plot
H
hypox64 已提交
6

HypoX64's avatar
HypoX64 已提交
7 8 9 10 11 12
class Options():
    """Command-line option container built on ``argparse``.

    Typical use is ``opt = Options().getparse()``: the arguments are
    registered on first use, ``sys.argv`` is parsed, string-typed options are
    converted to their working types, and the resulting namespace is logged
    and returned.
    """

    # Model names used to infer ``--model_type`` when it is left as 'auto'.
    _MODELS_1D = ('lstm', 'cnn_1d', 'resnet18_1d', 'resnet34_1d',
                  'multi_scale_resnet_1d', 'micro_multi_scale_resnet_1d', 'autoencoder', 'mlp')
    _MODELS_2D = ('dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50',
                  'resnet101', 'densenet121', 'densenet201', 'squeezenet', 'mobilenet')

    # Options declared as str so that 'auto' is accepted; any other value is an int.
    _INT_OR_AUTO = ('label', 'input_nc', 'loadsize', 'finesize', 'lstm_inputsize')

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register every supported command-line argument on ``self.parser``."""
        # ------------------------Base------------------------
        self.parser.add_argument('--gpu_id', type=int, default=0,help='choose which gpu want to use, 0 | 1 | 2 ...')
        self.parser.add_argument('--no_cudnn', action='store_true', help='if specified, do not use cudnn')
        self.parser.add_argument('--label', type=str, default='auto',help='number of labels')
        self.parser.add_argument('--input_nc', type=str, default='auto', help='of input channels')
        self.parser.add_argument('--loadsize', type=str, default='auto', help='load data in this size')
        self.parser.add_argument('--finesize', type=str, default='auto', help='crop your data into this size')
        self.parser.add_argument('--label_name', type=str, default='auto',help='name of labels,example:"a,b,c,d,e,f"')

        # ------------------------Dataset------------------------
        self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
        self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
        self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
        self.parser.add_argument('--normliaze', type=str, default='5_95', help='mode of normliaze, 5_95 | maxmin | None')
        self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.If 0 or 1, no k-fold and cut 0.8 to train and other to eval')
        # --fold_index
        # 5-fold:
        #   Cut the dataset into sub-sets using the given indices, then run
        #   k-fold over those sub-sets.
        #   If 'auto', the dataset is shuffled and then cut equally.
        #   If [2,4,6,7] and len(dataset) == 10, the sub-sets are:
        #   dataset[0:2], dataset[2:4], dataset[4:6], dataset[6:7], dataset[7:]
        # ---------------------------------------------------------------
        # No-fold:
        #   If 'auto', the dataset is shuffled, 80% is used to train and the
        #   rest to eval.
        #   If [5] and len(dataset) == 10:
        #   train-set: dataset[0:5]  eval-set: dataset[5:]
        self.parser.add_argument('--fold_index', type=str, default='auto',
            help='where to fold, eg. when 5-fold and input: [2,4,6,7] -> sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]')
        self.parser.add_argument('--mergelabel', type=str, default='None',
            help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" -> label(0,1,4) regard as 0,label(2,3,5) regard as 1')
        self.parser.add_argument('--mergelabel_name', type=str, default='None',help='name of labels,example:"a,b,c,d,e,f"')

        # ------------------------Network------------------------
        # Available networks
        # 1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
        #     micro_multi_scale_resnet_1d, autoencoder, mlp
        # 2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
        #     densenet121, densenet201, squeezenet
        self.parser.add_argument('--model_name', type=str, default='micro_multi_scale_resnet_1d',help='Choose model  lstm...')
        self.parser.add_argument('--model_type', type=str, default='auto',help='1d | 2d')
        # For lstm
        self.parser.add_argument('--lstm_inputsize', type=str, default='auto',help='lstm_inputsize of LSTM')
        self.parser.add_argument('--lstm_timestep', type=int, default=100,help='time_step of LSTM')
        # For autoencoder
        self.parser.add_argument('--feature', type=int, default=3, help='number of encoder features')
        # For 2d network (stft spectrum)
        # Please check ./save_dir/spectrum_eg.jpg to change the following parameters
        self.parser.add_argument('--stft_size', type=int, default=512, help='length of each fft segment')
        self.parser.add_argument('--stft_stride', type=int, default=128, help='stride of each fft segment')
        self.parser.add_argument('--stft_n_downsample', type=int, default=1, help='downsample befor stft')
        self.parser.add_argument('--stft_no_log', action='store_true', help='if specified, do not log1p spectrum')
        self.parser.add_argument('--stft_shape', type=str, default='auto', help='shape of stft. It depend on \
            stft_size,stft_stride,stft_n_downsample. Do not input this parameter.')

        # ------------------------Training Matters------------------------
        self.parser.add_argument('--pretrained', type=str, default='',help='pretrained model path. If not specified, fo not use pretrained model')
        self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train')
        self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
        self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
        self.parser.add_argument('--weight_mod', type=str, default='auto',help='Choose weight mode: auto | normal')
        self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
        self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')

        self.initialized = True

    def getparse(self):
        """Parse ``sys.argv``, normalise the options, log them and return them.

        Returns:
            The ``argparse.Namespace`` stored in ``self.opt``.

        Raises:
            SystemExit: on argparse errors, on an unparsable ``--fold_index`` /
                ``--mergelabel`` value, or when ``--model_type`` is 'auto' and
                ``--model_name`` is not a known network.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self._resolve_options()
        self._log_options()
        return self.opt

    def _parse_literal(self, option, text):
        """Safely turn a command-line string into a Python literal."""
        # Function-scope import keeps this block self-contained.
        import ast
        try:
            # literal_eval only accepts literals (lists, numbers, None, ...),
            # so command-line text can never execute arbitrary code.
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            # parser.error prints the usage message and raises SystemExit(2).
            self.parser.error('invalid value for --%s: %r' % (option, text))

    def _resolve_options(self):
        """Convert the raw parsed strings in ``self.opt`` to their working types."""
        opt = self.opt

        # gpu_id == -1 leaves CUDA_VISIBLE_DEVICES untouched.
        if opt.gpu_id != -1:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id)

        for name in self._INT_OR_AUTO:
            value = getattr(opt, name)
            if value != 'auto':
                setattr(opt, name, int(value))

        if opt.model_type == 'auto':
            if opt.model_name in self._MODELS_1D:
                opt.model_type = '1d'
            elif opt.model_name in self._MODELS_2D:
                opt.model_type = '2d'
            else:
                print('\033[1;31m'+'Error: do not support this network '+opt.model_name+'\033[0m')
                # Non-zero status so callers/scripts can detect the failure.
                raise SystemExit(1)

        if opt.k_fold == 0:
            opt.k_fold = 1

        if opt.fold_index != 'auto':
            opt.fold_index = self._parse_literal('fold_index', opt.fold_index)

        # An index.npy file inside the dataset directory takes precedence over
        # whatever --fold_index was given on the command line.
        index_path = os.path.join(opt.dataset_dir, 'index.npy')
        if os.path.isfile(index_path):
            opt.fold_index = (np.load(index_path)).tolist()

        # The default string 'None' becomes the None object here.
        opt.mergelabel = self._parse_literal('mergelabel', opt.mergelabel)
        if opt.mergelabel_name != 'None':
            opt.mergelabel_name = opt.mergelabel_name.replace(" ", "").split(",")

    def _log_options(self):
        """Format every option (noting non-default values) and write it to the log.

        The save directory is created first; the formatted text is then handed
        to ``util.writelog`` (presumably printing and saving it under
        ``save_dir`` -- confirm against ``util``).
        """
        lines = ['----------------- Options ---------------\n']
        for k, v in sorted(vars(self.opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>20}: {:<30}{}\n'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        message = ''.join(lines)
        localtime = time.asctime(time.localtime(time.time()))
        util.makedirs(self.opt.save_dir)
        util.writelog(str(localtime)+'\n'+message, self.opt,True)
H
hypox64 已提交
143

H
hypox64 已提交
144
def get_auto_options(opt,label_cnt_per,label_num,signals):
    """Fill in every option still set to 'auto' from the loaded data.

    Args:
        opt: parsed options namespace; it is modified in place and returned.
        label_cnt_per: per-label values whose reciprocal is used as the loss
            weight when ``opt.weight_mod == 'auto'`` (presumably per-class
            sample proportions -- confirm against the caller). Must support
            ``1/label_cnt_per`` element-wise, e.g. a numpy array.
        label_num: number of labels, used when ``opt.label`` is 'auto'.
        signals: array with at least three dimensions; ``shape[0]`` is treated
            as the number of samples, ``shape[1]`` becomes ``input_nc`` and
            ``shape[2]`` becomes ``loadsize``.

    Returns:
        The same ``opt`` object with ``label``, ``input_nc``, ``loadsize``,
        ``finesize``, ``lstm_inputsize``, ``weight`` and ``label_name``
        resolved, plus ``stft_shape`` when ``opt.model_type == '2d'``.
    """
    shape = signals.shape
    if opt.label == 'auto':
        opt.label = label_num
    if opt.input_nc == 'auto':
        opt.input_nc = shape[1]
    if opt.loadsize == 'auto':
        opt.loadsize = shape[2]
    if opt.finesize == 'auto':
        # Crop size defaults to 90% of the signal length.
        opt.finesize = int(shape[2]*0.9)
    if opt.lstm_inputsize == 'auto':
        opt.lstm_inputsize = opt.finesize//opt.lstm_timestep

    # Loss weight: uniform by default; in 'auto' mode use the reciprocal of
    # label_cnt_per, scaled so the smallest weight is 1.
    opt.weight = np.ones(opt.label)
    if opt.weight_mod == 'auto':
        opt.weight = 1/label_cnt_per
        opt.weight = opt.weight/np.min(opt.weight)
    util.writelog('Loss_weight:'+str(opt.weight),opt,True)
    # Imported here rather than at module level, as in the original code.
    import torch
    opt.weight = torch.from_numpy(opt.weight).float()
    if opt.gpu_id != -1:
        opt.weight = opt.weight.cuda()

    # Label names: default to "0", "1", ...; otherwise split a comma-separated string.
    if opt.label_name == 'auto':
        opt.label_name = [str(i) for i in range(opt.label)]
    elif not isinstance(opt.label_name,list):
        opt.label_name = opt.label_name.replace(" ", "").split(",")

    # For 2d models, compute the spectrum of one random sample so the user can
    # inspect it and so the spectrum shape is known.
    if opt.model_type == '2d':
        # randint's upper bound is exclusive, so use shape[0] directly: this
        # works with a single sample and lets the last sample be chosen too.
        data = signals[np.random.randint(0,shape[0])]
        spectrums = [
            dsp.signal2spectrum(data[i],opt.stft_size, opt.stft_stride, opt.stft_n_downsample, not opt.stft_no_log)
            for i in range(shape[1])
        ]
        plot.draw_spectrums(spectrums,opt)
        opt.stft_shape = spectrums[0].shape
        h,w = opt.stft_shape
        print('Shape of stft spectrum h,w:',opt.stft_shape)
        print('\033[1;37m'+'Please cheek ./save_dir/spectrum_eg.jpg to change parameters'+'\033[0m')

        if h<64 or w<64:
            print('\033[1;33m'+'Warning: spectrum is too small'+'\033[0m')
        if h>512 or w>512:
            print('\033[1;33m'+'Warning: spectrum is too large'+'\033[0m')

    return opt