提交 3f51c5ce 编写于 作者: C Channingss

fix bug of nearest opset11

上级 0a40bebb
......@@ -29,6 +29,7 @@ class OpSet11(OpSet10):
super(OpSet11, self).__init__()
def relu6(self, op, block):
print('relu6')
min_name = self.get_name(op.type, 'min')
max_name = self.get_name(op.type, 'max')
min_node = self.make_constant_node(min_name, onnx_pb.TensorProto.FLOAT,
......@@ -42,6 +43,7 @@ class OpSet11(OpSet10):
return [min_node, max_node, node]
def pad2d(self, op, block):
print('pad2d')
x_shape = block.var(op.input('X')[0]).shape
paddings = op.attr('paddings')
onnx_pads = []
......@@ -69,6 +71,7 @@ class OpSet11(OpSet10):
return [pads_node, constant_value_node, node]
def clip(self, op, block):
print('clip')
min_name = self.get_name(op.type, 'min')
max_name = self.get_name(op.type, 'max')
min_node = self.make_constant_node(min_name, onnx_pb.TensorProto.FLOAT,
......@@ -82,6 +85,7 @@ class OpSet11(OpSet10):
return [min_node, max_node, node]
def bilinear_interp(self, op, block):
print('bilinear')
input_names = op.input_names
coordinate_transformation_mode = ''
align_corners = op.attr('align_corners')
......@@ -196,11 +200,12 @@ class OpSet11(OpSet10):
if align_corners:
coordinate_transformation_mode = 'align_corners'
else:
coordinate_transformation_mode = 'asymmetric'
coordinate_transformation_mode = 'half_pixel'
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
print('0000')
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], roi_name, op.input('OutSize')[0]],
......@@ -208,6 +213,7 @@ class OpSet11(OpSet10):
mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode)
elif 'Scale' in input_names and len(op.input('Scale')) > 0:
print('1111')
node = helper.make_node(
'Resize',
inputs=[op.input('X')[0], roi_name, op.input('Scale')[0]],
......@@ -215,6 +221,7 @@ class OpSet11(OpSet10):
mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode)
else:
print('2222')
out_shape = [op.attr('out_h'), op.attr('out_w')]
scale = op.attr('scale')
if out_shape.count(-1) > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册