未验证 提交 09d35587 编写于 作者: J Jason 提交者: GitHub

Merge pull request #396 from Channingss/fix_nms

fix bug of multiclass_nms when attr:keep_top_k==-1
......@@ -72,6 +72,8 @@ def multiclass_nms(op, block):
dims=(),
vals=[float(attrs['nms_threshold'])]))
boxes_num = block.var(outputs['Out'][0]).shape[0]
top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k'])
node_keep_top_k = onnx.helper.make_node(
'Constant',
inputs=[],
......@@ -80,7 +82,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=(),
vals=[np.int64(attrs['keep_top_k'])]))
vals=[top_k_value]))
node_keep_top_k_2D = onnx.helper.make_node(
'Constant',
......@@ -90,7 +92,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k_2D[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=[1, 1],
vals=[np.int64(attrs['keep_top_k'])]))
vals=[top_k_value]))
# the paddle data format is x1,y1,x2,y2
kwargs = {'center_point_box': 0}
......
......@@ -72,6 +72,8 @@ def multiclass_nms(op, block):
dims=(),
vals=[float(attrs['nms_threshold'])]))
boxes_num = block.var(outputs['Out'][0]).shape[0]
top_k_value = np.int64(boxes_num if attrs['keep_top_k'] == -1 else attrs['keep_top_k'])
node_keep_top_k = onnx.helper.make_node(
'Constant',
inputs=[],
......@@ -80,7 +82,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=(),
vals=[np.int64(attrs['keep_top_k'])]))
vals=[top_k_value]))
node_keep_top_k_2D = onnx.helper.make_node(
'Constant',
......@@ -90,7 +92,7 @@ def multiclass_nms(op, block):
name=name_keep_top_k_2D[0] + "@const",
data_type=onnx.TensorProto.INT64,
dims=[1, 1],
vals=[np.int64(attrs['keep_top_k'])]))
vals=[top_k_value]))
# the paddle data format is x1,y1,x2,y2
kwargs = {'center_point_box': 0}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册