提交 2531fdb8 编写于 作者: C channingss

fix bug for output_nodes

上级 edc2430a
......@@ -162,6 +162,17 @@ class ONNXGraph(Graph):
if ipt_data not in inner_nodes:
self.place_holder_nodes.append(ipt_data)
def get_output_nodes(self):
"""
generate output_nodes node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
output_nodes = [value.name for value in self.model.output]
for opt_data in output_nodes:
if opt_data not in inner_nodes:
self.output_nodes.append(opt_data)
print(opt_data)
def is_place_holder_nodes(self, layer):
"""
return layer is or not place_holder node
......
......@@ -140,6 +140,7 @@ class ONNXOpMapper(OpMapper):
model.graph.output.MergeFrom(outputs)
onnx.save(model, os.path.join(self.tmp_data_dir,
'onnx_model_infer.onnx'))
os.system('onnx_infer --save_dir=' + self.tmp_data_dir)
return
......@@ -336,7 +337,8 @@ class ONNXOpMapper(OpMapper):
node = parameter
dtype = node.dtype
shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
shape = [1]
self.weights[node.layer_name] = node.weight
attr = {
'dtype': string(dtype),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册