提交 7a0c7ef4 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mge/module): add module for extern-c-opr

GitOrigin-RevId: a2d9fa067a5db245a3f1546268888a747dab01f1
上级 09d2b7c3
...@@ -8,9 +8,10 @@ ...@@ -8,9 +8,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import megengine._internal as mgb
import numpy as np import numpy as np
import yaml import yaml
from megengine import jit
from megengine.module.external import ExternOprSubgraph
# "1,3,224,224" -> (1,3,224,224) # "1,3,224,224" -> (1,3,224,224)
...@@ -89,26 +90,19 @@ def main(): ...@@ -89,26 +90,19 @@ def main():
+ raw_param + raw_param
) )
# cn not ensured net = ExternOprSubgraph(wk_raw_content, "mace", osizes)
cn = mgb.comp_node("xpux") net.eval()
cg = mgb.comp_graph()
inp = [
mgb.make_shared(
comp_node=cn,
comp_graph=cg,
shape=isizes[i],
name=input_names[i],
dtype=np.float32,
)
for i in range(len(isizes))
]
oup = mgb.opr.extern_c_opr_placeholder( @jit.trace(symbolic=True)
inp, osizes, dump_name="mace", dump_data=wk_raw_content, def inference(inputs):
) return net(inputs)
inputs = [
np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes))
]
mgb.serialize_comp_graph_to_file(args.output, oup) inference.trace(inputs)
inference.dump(args.output)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册