未验证 提交 20265d73 编写于 作者: J Jason 提交者: GitHub

Merge pull request #357 from mamingjie-China/develop

fix bug in 1.8.2
......@@ -58,7 +58,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
|--save_dir | 指定转换后的模型保存目录路径 |
|--model | 当framework为tensorflow/onnx时,该参数指定tensorflow的pb模型文件或onnx模型路径 |
|--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
|--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数时,关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) |
|--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数为False时,打开NHWC->NCHW的优化,见[文档Q2](FAQ.md),默认为True|
|--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](FAQ.md) |
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
|--onnx_opset | **[可选]** 当framework为paddle2onnx时,该参数可设置转换为ONNX的OpSet版本,目前支持9、10、11,默认为10 |
......
......@@ -66,8 +66,8 @@ def arg_parser():
parser.add_argument(
"--without_data_format_optimization",
"-wo",
action="store_true",
default=False,
type=_text_type,
default="True",
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
......@@ -93,7 +93,7 @@ def arg_parser():
def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
without_data_format_optimization,
define_input_shape=False,
params_merge=False):
# check tensorflow installation and version
......@@ -240,11 +240,12 @@ def main():
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False
assert args.without_data_format_optimization in [
"True", "False"
], "--the param without_data_format_optimization should be defined True or False"
define_input_shape = False
params_merge = False
if args.without_data_format_optimization:
without_data_format_optimization = True
without_data_format_optimization = True if args.without_data_format_optimization == "True" else False
if args.define_input_shape:
define_input_shape = True
if args.params_merge:
......
......@@ -1068,13 +1068,25 @@ class TFOpMapperNHWC(OpMapper):
axis = axis.value.tolist()
assert axis == 0, "Only support axis=0 in GatherV2 OP"
attr = {'overwrite': False}
embeddings_shape = embeddings.out_shapes[0][-1]
reshape_list = list()
reshape_name = index.layer_name
if len(index.out_shapes[0]) != 1:
reshape_list = index.out_shapes[0]
reshape_attr = {"shape": [-1]}
reshape_name = "{}_reshape".format(index.layer_name)
node.fluid_code.add_layer(
"reshape", inputs=index, output=index, param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': index}
"reshape",
inputs=index,
output=reshape_name,
param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': reshape_name}
node.fluid_code.add_layer(
"gather", inputs=inputs, output=node, param_attr=attr)
if len(index.out_shapes[0]) != 1:
reshape_attr = {"shape": reshape_list + [embeddings_shape]}
node.fluid_code.add_layer(
"reshape", inputs=node, output=node, param_attr=reshape_attr)
def OneShotIterator(self, node):
return self.Placeholder(node)
......
......@@ -864,8 +864,8 @@ class TFOptimizer(object):
weight = numpy.expand_dims(weight, 3)
self.op_mapper.weights[in_nodes3[0].layer_name] = weight
# fix bug in Paddle1.8.3 and may change in next version.
self.op_mapper.weights[in_nodes3[0].layer_name +
'_1'] = weight.reshape(1, -1)
# self.op_mapper.weights[in_nodes3[0].layer_name +
# '_1'] = weight.reshape(1, -1)
in_nodes3[0].fluid_code.layers[0].param_attr["shape"] = [
1, in_shape[-1], 1, 1
]
......@@ -888,7 +888,7 @@ class TFOptimizer(object):
node.fluid_code.clear()
attr = {
"mode": string(mode),
"param_attr": string(in_nodes3[0].layer_name + "_1")
"param_attr": string(in_nodes3[0].layer_name)
}
node.fluid_code.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册