From fc2d90b7b5e99b5cead6a7c5d2435c01c51db04f Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Sun, 17 Jan 2021 20:34:54 +0800 Subject: [PATCH] Update ssd.py --- ssd.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ssd.py b/ssd.py index 6386349..70ef294 100644 --- a/ssd.py +++ b/ssd.py @@ -47,6 +47,7 @@ class SSD(object): self.__dict__.update(self._defaults) self.class_names = self._get_class() self.generate() + #---------------------------------------------------# # 获得所有的分类 #---------------------------------------------------# @@ -56,21 +57,23 @@ class SSD(object): class_names = f.readlines() class_names = [c.strip() for c in class_names] return class_names + #---------------------------------------------------# - # 获得所有的分类 + # 载入模型 #---------------------------------------------------# def generate(self): - # 计算总的种类 + #-------------------------------# + # 计算总的类的数量 + #-------------------------------# self.num_classes = len(self.class_names) + 1 - # 载入模型 + #-------------------------------# + # 载入模型与权值 + #-------------------------------# model = ssd.get_ssd("test", self.num_classes, self.confidence, self.nms_iou) - - # 载入权重 print('Loading weights into state dict...') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.load_state_dict(torch.load(self.model_path, map_location=device)) - self.net = model.eval() if self.cuda: @@ -145,7 +148,9 @@ class SSD(object): top_bboxes = np.array(top_bboxes) top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0], -1),np.expand_dims(top_bboxes[:,1], -1),np.expand_dims(top_bboxes[:,2], -1),np.expand_dims(top_bboxes[:,3], -1) - # 去掉灰条 + #-----------------------------------------------------------# + # 去掉灰条部分 + #-----------------------------------------------------------# boxes = ssd_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax, np.array([self.input_shape[0],self.input_shape[1]]), image_shape) font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32')) -- GitLab