From 69c60ce67b53f759e638752b0225342d15cf5a7e Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Mon, 8 Mar 2021 04:20:43 +0800 Subject: [PATCH] update data iter --- data_iter/datasets.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/data_iter/datasets.py b/data_iter/datasets.py index 27ec9c1..f036814 100644 --- a/data_iter/datasets.py +++ b/data_iter/datasets.py @@ -97,7 +97,7 @@ def M_rotate_image(image , angle , cx , cy): return cv2.warpAffine(image , M , (nW , nH)) , M class LoadImagesAndLabels(Dataset): # for training/testing - def __init__(self, path, img_size=(224,224), flag_agu = False,fix_res = True,val_split = []): + def __init__(self, path, img_size=(224,224), flag_agu = False,fix_res = True,val_split = [],have_label_file = False): print('img_size (height,width) : ',img_size[0],img_size[1]) labels_ = [] files_ = [] @@ -117,6 +117,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing self.img_size = img_size self.flag_agu = flag_agu self.fix_res = fix_res + self.have_label_file = have_label_file def __len__(self): return len(self.files) @@ -129,22 +130,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing img = cv2.imread(img_path) # BGR #-------------------------------------------- - xml_ = img_path.replace(".jpg",".xml") + if self.have_label_file: + xml_ = img_path.replace(".jpg",".xml") - list_x = get_xml_msg(xml_)# 获取 xml 文件 的 object + list_x = get_xml_msg(xml_)# 获取 xml 文件 的 object - # 绘制 bbox - choose_idx = random.randint(0,int(len(list_x)-1)) - for j in range(len(list_x)): - if j ==choose_idx: - _,bbox_ = list_x[j] - x1,y1,x2,y2 = bbox_ - x1 = int(np.clip(x1,0,img.shape[1]-1)) - y1 = int(np.clip(y1,0,img.shape[0]-1)) - x2 = int(np.clip(x2,0,img.shape[1]-1)) - y2 = int(np.clip(y2,0,img.shape[0]-1)) - img = img[y1:y2,x1:x2,:] - break + # 绘制 bbox + choose_idx = random.randint(0,int(len(list_x)-1)) + for j in range(len(list_x)): + if j ==choose_idx: + _,bbox_ = list_x[j] + x1,y1,x2,y2 = bbox_ + x1 = int(np.clip(x1,0,img.shape[1]-1)) + y1 = int(np.clip(y1,0,img.shape[0]-1)) + x2 = int(np.clip(x2,0,img.shape[1]-1)) + y2 = int(np.clip(y2,0,img.shape[0]-1)) + img = img[y1:y2,x1:x2,:] + break @@ -158,9 +160,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing if self.flag_agu == True and random.random()>0.6: cx = int(img.shape[1]/2) cy = int(img.shape[0]/2) - angle = random.randint(-45,45) - offset_x = random.randint(-3,3) - offset_y = random.randint(-3,3) + # 手势(gesture)分类建议是全角度旋转, 对于 Stanford dogs 数据集适当角度旋转扰动,目的是为了符合真实样本旋转角度样本分布情况。 + angle = random.randint(-180,180) + range_limit_x = int(min(6,img.shape[1]/16)) + range_limit_y = int(min(6,img.shape[0]/16)) + offset_x = random.randint(-range_limit_x,range_limit_x) + offset_y = random.randint(-range_limit_y,range_limit_y) if not(angle==0 and offset_x==0 and offset_y==0): img,_ = M_rotate_image(img , angle , cx+offset_x , cy+offset_y) -- GitLab