提交 ba9492a4 编写于 作者: B BBuf

fix export onnx script bug

上级 2489e500
""" """
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -22,8 +19,6 @@ from typing import Callable, Text ...@@ -22,8 +19,6 @@ from typing import Callable, Text
import numpy as np import numpy as np
import oneflow as flow import oneflow as flow
import oneflow_onnx
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
import oneflow.typing as tp import oneflow.typing as tp
import onnx import onnx
import onnxruntime as ort import onnxruntime as ort
...@@ -31,6 +26,7 @@ import onnxruntime as ort ...@@ -31,6 +26,7 @@ import onnxruntime as ort
from resnet_model import resnet50 from resnet_model import resnet50
import config as configs import config as configs
from imagenet1000_clsidx_to_labels import clsidx_2_labels from imagenet1000_clsidx_to_labels import clsidx_2_labels
from oneflow_onnx.oneflow2onnx.util import export_onnx_model
parser = configs.get_parser() parser = configs.get_parser()
args = parser.parse_args() args = parser.parse_args()
...@@ -94,9 +90,15 @@ def oneflow_to_onnx( ...@@ -94,9 +90,15 @@ def oneflow_to_onnx(
assert os.path.exists(flow_weights_path) and os.path.isdir(onnx_model_dir) assert os.path.exists(flow_weights_path) and os.path.isdir(onnx_model_dir)
onnx_model_path = os.path.join( onnx_model_path = os.path.join(
onnx_model_dir, os.path.basename(flow_weights_path) + ".onnx" onnx_model_dir, "model.onnx"
)
export_onnx_model(
job_func,
flow_weight_dir=flow_weights_path,
onnx_model_path=onnx_model_dir,
opset=11,
external_data=external_data,
) )
convert_to_onnx_and_check(job_func, flow_weight_dir=flow_weights_path, onnx_model_path=onnx_model_path, opset=11, external_data=external_data)
print("Convert to onnx success! >> ", onnx_model_path) print("Convert to onnx success! >> ", onnx_model_path)
return onnx.load_model(onnx_model_path) return onnx.load_model(onnx_model_path)
...@@ -116,8 +118,6 @@ if __name__ == "__main__": ...@@ -116,8 +118,6 @@ if __name__ == "__main__":
# set up your model path # set up your model path
flow_weights_path = "resnet_v15_of_best_model_val_top1_77318" flow_weights_path = "resnet_v15_of_best_model_val_top1_77318"
onnx_model_dir = "onnx/model" onnx_model_dir = "onnx/model"
flow.train.CheckPoint().init()
flow.load_variables(flow.checkpoint.get(flow_weights_path)) flow.load_variables(flow.checkpoint.get(flow_weights_path))
...@@ -130,4 +130,4 @@ if __name__ == "__main__": ...@@ -130,4 +130,4 @@ if __name__ == "__main__":
are_equal, onnx_res = check_equality(InferenceNet, onnx_model, image_path) are_equal, onnx_res = check_equality(InferenceNet, onnx_model, image_path)
clsidx_onnx = onnx_res.argmax() clsidx_onnx = onnx_res.argmax()
print("Are the results equal? {}".format("Yes" if are_equal else "No")) print("Are the results equal? {}".format("Yes" if are_equal else "No"))
print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max())) print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册