diff --git a/sdk/c-opr-loaders/mace/dump_model.py b/sdk/c-opr-loaders/mace/dump_model.py index ed5e0bb1e09b3e95e405eed345a7f73358e83f8a..21e41022694c4fdf8a0c680f24f108fed038c97b 100644 --- a/sdk/c-opr-loaders/mace/dump_model.py +++ b/sdk/c-opr-loaders/mace/dump_model.py @@ -8,9 +8,10 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse -import megengine._internal as mgb import numpy as np import yaml +from megengine import jit +from megengine.module.external import ExternOprSubgraph # "1,3,224,224" -> (1,3,224,224) @@ -89,26 +90,19 @@ def main(): + raw_param ) - # cn not ensured - cn = mgb.comp_node("xpux") - cg = mgb.comp_graph() - - inp = [ - mgb.make_shared( - comp_node=cn, - comp_graph=cg, - shape=isizes[i], - name=input_names[i], - dtype=np.float32, - ) - for i in range(len(isizes)) - ] + net = ExternOprSubgraph(wk_raw_content, "mace", osizes) + net.eval() - oup = mgb.opr.extern_c_opr_placeholder( - inp, osizes, dump_name="mace", dump_data=wk_raw_content, - ) + @jit.trace(symbolic=True) + def inference(inputs): + return net(inputs) + + inputs = [ + np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes)) + ] - mgb.serialize_comp_graph_to_file(args.output, oup) + inference.trace(inputs) + inference.dump(args.output) if __name__ == "__main__":