未验证 提交 edcf04ca 编写于 作者: S songyouwei 提交者: GitHub

[cherry-pick] fix pickle between python 2 & 3 (#22620)

* cherry-pick #22555
test=release/1.7, test=develop

* cherry-pick #22621
test=release/1.7, test=develop
上级 c000f8a2
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import collections import collections
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
import pickle import pickle
import six
from . import learning_rate_scheduler from . import learning_rate_scheduler
import warnings import warnings
from .. import core from .. import core
...@@ -88,7 +89,7 @@ def save_dygraph(state_dict, model_path): ...@@ -88,7 +89,7 @@ def save_dygraph(state_dict, model_path):
os.makedirs(dir_name) os.makedirs(dir_name)
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
pickle.dump(model_dict, f) pickle.dump(model_dict, f, protocol=2)
@dygraph_only @dygraph_only
...@@ -130,7 +131,8 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -130,7 +131,8 @@ def load_dygraph(model_path, keep_name_table=False):
params_file_path)) params_file_path))
with open(params_file_path, 'rb') as f: with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
if not keep_name_table and "StructuredToParameterName@@" in para_dict: if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
...@@ -138,6 +140,7 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -138,6 +140,7 @@ def load_dygraph(model_path, keep_name_table=False):
opti_file_path = model_path + ".pdopt" opti_file_path = model_path + ".pdopt"
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f: with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
return para_dict, opti_dict return para_dict, opti_dict
...@@ -800,7 +800,7 @@ def load_vars(executor, ...@@ -800,7 +800,7 @@ def load_vars(executor,
var_temp = paddle.fluid.global_scope().find_var(each_var.name) var_temp = paddle.fluid.global_scope().find_var(each_var.name)
assert var_temp != None, "can't not find var: " + each_var.name assert var_temp != None, "can't not find var: " + each_var.name
new_shape = (np.array(var_temp.get_tensor())).shape new_shape = (np.array(var_temp.get_tensor())).shape
assert each_var.name in orig_para_shape, earch_var.name + "MUST in var list" assert each_var.name in orig_para_shape, each_var.name + "MUST in var list"
orig_shape = orig_para_shape.get(each_var.name) orig_shape = orig_para_shape.get(each_var.name)
if new_shape != orig_shape: if new_shape != orig_shape:
raise RuntimeError( raise RuntimeError(
...@@ -1579,14 +1579,14 @@ def save(program, model_path): ...@@ -1579,14 +1579,14 @@ def save(program, model_path):
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
param_dict = {p.name: get_tensor(p) for p in parameter_list} param_dict = {p.name: get_tensor(p) for p in parameter_list}
with open(model_path + ".pdparams", 'wb') as f: with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f) pickle.dump(param_dict, f, protocol=2)
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f: with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f) pickle.dump(opt_dict, f, protocol=2)
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
...@@ -1733,7 +1733,8 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1733,7 +1733,8 @@ def load(program, model_path, executor=None, var_list=None):
global_scope(), global_scope(),
executor._default_executor) executor._default_executor)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f) load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
for v in parameter_list: for v in parameter_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format( "Can not find [{}] in model file [{}]".format(
...@@ -1753,7 +1754,8 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1753,7 +1754,8 @@ def load(program, model_path, executor=None, var_list=None):
optimizer_var_list, global_scope(), executor._default_executor) optimizer_var_list, global_scope(), executor._default_executor)
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f) load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
for v in optimizer_var_list: for v in optimizer_var_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format( "Can not find [{}] in model file [{}]".format(
...@@ -1877,12 +1879,14 @@ def load_program_state(model_path, var_list=None): ...@@ -1877,12 +1879,14 @@ def load_program_state(model_path, var_list=None):
"Parameter file [{}] not exits".format(parameter_file_name) "Parameter file [{}] not exits".format(parameter_file_name)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f) opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict.update(opti_dict) para_dict.update(opti_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册