diff --git a/nets/ssd_training.py b/nets/ssd_training.py index b41fa8f0cf90dbbf56fc2bf4f525fad33323b8ce..e49b82cf705720f5441663b711326dd1028bcfde 100644 --- a/nets/ssd_training.py +++ b/nets/ssd_training.py @@ -115,7 +115,7 @@ class Generator(object): self.image_size = image_size self.num_classes = num_classes-1 - def get_random_data(self, annotation_line, input_shape, jitter=.1, hue=.1, sat=1.1, val=1.1): + def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5): '''r实时数据增强的随机预处理''' line = annotation_line.split() image = Image.open(line[0]) @@ -125,7 +125,7 @@ class Generator(object): # resize image new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter) - scale = rand(.25, 2) + scale = rand(.5, 1.5) if new_ar < 1: nh = int(scale*h) nw = int(nh*new_ar)