提交 1da94211 编写于 作者: C Channingss

support ONNX >=1.6.0

上级 9270b81d
......@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0
**按需安装以下依赖**
tensorflow : tensorflow == 1.14.0
caffe : 无
onnx : onnx == 1.6.0
onnx : onnx >= 1.6.0
## 安装
### 安装方式一(推荐)
......
......@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try:
import onnx
version = onnx.version.version
if version != '1.6.0':
print("[ERROR] onnx==1.6.0 is required")
if version < '1.6.0':
print("[ERROR] onnx>=1.6.0 is required")
return
except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......
......@@ -642,14 +642,15 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=indices,
output=indices,
output=indices_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'embedding',
inputs=indices,
inputs=indices_cast,
output=node,
use_fluid=True,
param_attr={
......@@ -1140,7 +1141,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1:
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册