提交 7919a9bb 编写于 作者: D dengkaipeng

enable load with postfix

上级 cfb0f787
......@@ -798,6 +798,13 @@ class Model(fluid.dygraph.Layer):
format(key, list(state.shape), list(param.shape)))
return param, state
def _strip_postfix(path):
path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
"Unknown postfix {} from weights".format(ext)
return path
path = _strip_postfix(path)
param_state = _load_state_from_path(path + ".pdparams")
assert param_state, "Failed to load parameters, please check path."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册