未验证 提交 fc2d90b7 编写于 作者: B Bubbliiiing 提交者: GitHub

Update ssd.py

上级 55669332
......@@ -47,6 +47,7 @@ class SSD(object):
self.__dict__.update(self._defaults)
self.class_names = self._get_class()
self.generate()
#---------------------------------------------------#
# 获得所有的分类
#---------------------------------------------------#
......@@ -56,21 +57,23 @@ class SSD(object):
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names
#---------------------------------------------------#
# 获得所有的分类
# 载入模型
#---------------------------------------------------#
def generate(self):
# 计算总的种类
#-------------------------------#
# 计算总的类的数量
#-------------------------------#
self.num_classes = len(self.class_names) + 1
# 载入模型
#-------------------------------#
# 载入模型与权值
#-------------------------------#
model = ssd.get_ssd("test", self.num_classes, self.confidence, self.nms_iou)
# 载入权重
print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load(self.model_path, map_location=device))
self.net = model.eval()
if self.cuda:
......@@ -145,7 +148,9 @@ class SSD(object):
top_bboxes = np.array(top_bboxes)
top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0], -1),np.expand_dims(top_bboxes[:,1], -1),np.expand_dims(top_bboxes[:,2], -1),np.expand_dims(top_bboxes[:,3], -1)
# 去掉灰条
#-----------------------------------------------------------#
# 去掉灰条部分
#-----------------------------------------------------------#
boxes = ssd_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax, np.array([self.input_shape[0],self.input_shape[1]]), image_shape)
font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册