提交 30664b37 编写于 作者: B BBuf

add reshape op

上级 231cec63
...@@ -13,21 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,21 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import tempfile
import oneflow as flow import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
class Reshape(flow.nn.Module):
def __init__(self) -> None:
super(Reshape, self).__init__()
def forward(self, x: flow.Tensor) -> flow.Tensor:
return flow.reshape(x, (1, 3, -1))
def test_reshape(): reshape = Reshape()
@flow.global_function() class reshapeOpGraph(flow.nn.Graph):
def reshape(x: tp.Numpy.Placeholder((3, 4, 2, 5))): def __init__(self):
return flow.reshape(x, (4, 30)) super().__init__()
self.m = reshape
def build(self, x):
out = self.m(x)
return out
convert_to_onnx_and_check(reshape)
def test_reshape_negative_dim(): def test_reshape():
@flow.global_function()
def reshape(x: tp.Numpy.Placeholder((3, 4, 2, 5))): reshape_graph = reshapeOpGraph()
return flow.reshape(x, (3, -1)) reshape_graph._compile(flow.randn(1, 3, 224, 224))
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(reshape.state_dict(), tmpdirname)
convert_to_onnx_and_check(reshape_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
convert_to_onnx_and_check(reshape) test_reshape()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册