提交 b793e5b3 编写于 作者: W wjj19950828

Add dynamic shape

上级 f40171a4
...@@ -14,7 +14,7 @@ treelib ...@@ -14,7 +14,7 @@ treelib
## 使用方式 ## 使用方式
``` python ```python
from x2paddle.convert import pytorch2paddle from x2paddle.convert import pytorch2paddle
pytorch2paddle(module=torch_module, pytorch2paddle(module=torch_module,
save_dir="./pd_model", save_dir="./pd_model",
...@@ -27,11 +27,14 @@ pytorch2paddle(module=torch_module, ...@@ -27,11 +27,14 @@ pytorch2paddle(module=torch_module,
``` ```
**注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; **注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静;
当jit_type为"script"时",input_examples不为None时,才可以进行动转静。
当jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。
## 使用示例 ## 使用示例
``` python ### Trace 模式
```python
import torch import torch
import numpy as np import numpy as np
from torchvision.models import AlexNet from torchvision.models import AlexNet
...@@ -51,3 +54,48 @@ pytorch2paddle(torch_module, ...@@ -51,3 +54,48 @@ pytorch2paddle(torch_module,
jit_type="trace", jit_type="trace",
input_examples=[torch.tensor(input_data)]) input_examples=[torch.tensor(input_data)])
``` ```
### Script 模式动态 shape 导出
```python
import torch
import numpy as np
from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url
# 获取PyTorch Module
torch_module = AlexNet()
torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_module,
save_dir="pd_model_script",
jit_type="script",
input_examples=None)
```
在自动生成的x2paddle_code.py中添加如下代码:
```python
def main(x0):
# There are 0 inputs.
paddle.disable_static()
params = paddle.load('model.pdparams')
model = AlexNet()
model.set_dict(params)
model.eval()
## convert to jit
sepc_list = list()
sepc_list.append(
paddle.static.InputSpec(
shape=[-1, 3, -1, -1], name="x0", dtype="float32"))
static_model = paddle.jit.to_static(model, input_spec=sepc_list)
paddle.jit.save(static_model, "pd_model_script/inference_model/model")
out = model(x0)
return out
```
运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册