提交 9748aebe 编写于 作者: M Megvii Engine Team

refactor(mge): tensor_shape -> symbolic_shape

GitOrigin-RevId: 366dc048bfd7473a6bd148cb5d1ab70235aa43f1
上级 8acc3acf
...@@ -9,20 +9,20 @@ ...@@ -9,20 +9,20 @@
import os import os
_use_tensor_shape = False _use_symbolic_shape = False
if os.environ.get("MEGENGINE_USE_TENSOR_SHAPE"): if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
_use_tensor_shape = True _use_symbolic_shape = True
def use_tensor_shape() -> bool: def use_symbolic_shape() -> bool:
"""Returns whether tensor.shape returns a tensor instead of a tuple """Returns whether tensor.shape returns a tensor instead of a tuple
""" """
return _use_tensor_shape return _use_symbolic_shape
def set_tensor_shape(option: bool): def set_symbolic_shape(option: bool):
""" Sets whether tensor.shape returns a tensor instead of a tuple """ Sets whether tensor.shape returns a tensor instead of a tuple
""" """
global _use_tensor_shape global _use_symbolic_shape
_use_tensor_shape = option _use_symbolic_shape = option
...@@ -10,7 +10,7 @@ from typing import Iterable ...@@ -10,7 +10,7 @@ from typing import Iterable
import numpy as np import numpy as np
from .._trace_option import use_tensor_shape from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply from .core import TensorBase, TensorWrapperBase, apply
...@@ -58,7 +58,7 @@ def check_bool_index(tensor, tuple_val): ...@@ -58,7 +58,7 @@ def check_bool_index(tensor, tuple_val):
) )
) )
i = i.reshape(-1) i = i.reshape(-1)
if not use_tensor_shape(): if not use_symbolic_shape():
cur_shape = ( cur_shape = (
cur_shape[:idx] cur_shape[:idx]
+ (i.shape[0],) + (i.shape[0],)
...@@ -76,7 +76,7 @@ def check_bool_index(tensor, tuple_val): ...@@ -76,7 +76,7 @@ def check_bool_index(tensor, tuple_val):
offset += 1 offset += 1
tensor = tensor.reshape(cur_shape) tensor = tensor.reshape(cur_shape)
tdim += tot tdim += tot
if use_tensor_shape(): if use_symbolic_shape():
cur_shape = make_shape_tuple(cur_shape) cur_shape = make_shape_tuple(cur_shape)
new_tuple_val.append(i) new_tuple_val.append(i)
else: else:
......
...@@ -11,7 +11,7 @@ import collections ...@@ -11,7 +11,7 @@ import collections
import numpy as np import numpy as np
from .._trace_option import use_tensor_shape from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.builtin import GetVarShape from ..ops.builtin import GetVarShape
from ..ops.special import Const from ..ops.special import Const
...@@ -342,7 +342,7 @@ class ArrayMethodMixin(abc.ABC): ...@@ -342,7 +342,7 @@ class ArrayMethodMixin(abc.ABC):
def __len__(self): def __len__(self):
shape = self.shape shape = self.shape
if use_tensor_shape(): if use_symbolic_shape():
shape = shape.numpy() shape = shape.numpy()
if shape: if shape:
return int(shape[0]) return int(shape[0])
...@@ -372,7 +372,7 @@ class ArrayMethodMixin(abc.ABC): ...@@ -372,7 +372,7 @@ class ArrayMethodMixin(abc.ABC):
@property @property
def size(self): def size(self):
if use_tensor_shape(): if use_symbolic_shape():
return self.shape.prod() return self.shape.prod()
return np.prod(self.shape).item() return np.prod(self.shape).item()
...@@ -462,7 +462,7 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): ...@@ -462,7 +462,7 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
@property @property
def shape(self): def shape(self):
if use_tensor_shape(): if use_symbolic_shape():
return apply(GetVarShape(), self)[0] return apply(GetVarShape(), self)[0]
else: else:
return self.__wrapped__.shape return self.__wrapped__.shape
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_tensor_shape from ..core._trace_option import set_symbolic_shape
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
...@@ -121,7 +121,7 @@ class trace: ...@@ -121,7 +121,7 @@ class trace:
sublinear_memory_config: SublinearMemoryConfig = None, sublinear_memory_config: SublinearMemoryConfig = None,
profiling: bool = False, profiling: bool = False,
opt_level: int = None, opt_level: int = None,
tensor_shape: bool = True, symbolic_shape: bool = True,
): ):
self.__wrapped__ = function self.__wrapped__ = function
self._symbolic = symbolic self._symbolic = symbolic
...@@ -130,7 +130,7 @@ class trace: ...@@ -130,7 +130,7 @@ class trace:
self._profiling = profiling self._profiling = profiling
self._profiler = None self._profiler = None
self._graph_opt_level = opt_level self._graph_opt_level = opt_level
self._tensor_shape = tensor_shape self._symbolic_shape = symbolic_shape
self._reset() self._reset()
...@@ -152,7 +152,7 @@ class trace: ...@@ -152,7 +152,7 @@ class trace:
self._output_bindings = None self._output_bindings = None
self._output_names = None self._output_names = None
set_tensor_shape(self._tensor_shape) set_symbolic_shape(self._symbolic_shape)
def _new_handle(self): def _new_handle(self):
handle = len(self._tinfo) handle = len(self._tinfo)
......
...@@ -18,7 +18,7 @@ import megengine as mge ...@@ -18,7 +18,7 @@ import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import jit from megengine import jit
from megengine.core._trace_option import set_tensor_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_conv_execution_strategy from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemoryConfig from megengine.jit import SublinearMemoryConfig
......
...@@ -13,7 +13,7 @@ import pytest ...@@ -13,7 +13,7 @@ import pytest
import megengine.core.ops.builtin import megengine.core.ops.builtin
import megengine.core.tensor.raw_tensor import megengine.core.tensor.raw_tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops._internal import all_ops from megengine.core.ops._internal import all_ops
from megengine.core.tensor import Tensor from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply from megengine.core.tensor.core import apply
...@@ -532,7 +532,7 @@ def test_advance_indexing_with_bool(): ...@@ -532,7 +532,7 @@ def test_advance_indexing_with_bool():
np.testing.assert_equal(a, aa.numpy()) np.testing.assert_equal(a, aa.numpy())
# XXX: trace does not expect empty condtake tensor # XXX: trace does not expect empty condtake tensor
if not use_tensor_shape(): if not use_symbolic_shape():
a = np.ones((2, 2), dtype=np.int32) a = np.ones((2, 2), dtype=np.int32)
b = np.array([[False, False], [False, False]]) b = np.array([[False, False], [False, False]])
aa = Tensor(a) aa = Tensor(a)
......
...@@ -17,7 +17,7 @@ import megengine.core.ops.builtin as builtin ...@@ -17,7 +17,7 @@ import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype import megengine.core.tensor.dtype as dtype
import megengine.functional as F import megengine.functional as F
from megengine import Parameter, Tensor, is_cuda_available, tensor from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
......
...@@ -15,7 +15,7 @@ from utils import opr_test ...@@ -15,7 +15,7 @@ from utils import opr_test
import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
......
...@@ -16,7 +16,7 @@ import pytest ...@@ -16,7 +16,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine import Tensor from megengine import Tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
......
...@@ -15,7 +15,7 @@ import pytest ...@@ -15,7 +15,7 @@ import pytest
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F import megengine.functional as F
from megengine import cgtools, tensor from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.core.tensor.raw_tensor import as_raw_tensor
...@@ -238,7 +238,7 @@ def test_optimize_for_inference(): ...@@ -238,7 +238,7 @@ def test_optimize_for_inference():
def test_optimize_for_inference_broadcast(): def test_optimize_for_inference_broadcast():
a = tensor(np.ones(1, dtype=np.float32)) a = tensor(np.ones(1, dtype=np.float32))
@trace(capture_as_const=True, tensor_shape=True) @trace(capture_as_const=True, symbolic_shape=True)
def f(): def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b return b
...@@ -248,7 +248,7 @@ def test_optimize_for_inference_broadcast(): ...@@ -248,7 +248,7 @@ def test_optimize_for_inference_broadcast():
def test_trace_cvt_bool(): def test_trace_cvt_bool():
set_tensor_shape(True) set_symbolic_shape(True)
x = tensor([0], dtype=np.int32) x = tensor([0], dtype=np.int32)
@trace(symbolic=True) @trace(symbolic=True)
...@@ -261,7 +261,7 @@ def test_trace_cvt_bool(): ...@@ -261,7 +261,7 @@ def test_trace_cvt_bool():
def test_trace_reshape(): def test_trace_reshape():
for symbolic in [False, True]: for symbolic in [False, True]:
set_tensor_shape(True) set_symbolic_shape(True)
x1 = tensor(np.random.randn(2, 10, 10)) x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10)) x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10)) x3 = tensor(np.random.randn(8, 10, 10))
...@@ -344,7 +344,7 @@ def test_raise_on_trace(): ...@@ -344,7 +344,7 @@ def test_raise_on_trace():
def test_trace_broadcast(): def test_trace_broadcast():
for symbolic in [False, True]: for symbolic in [False, True]:
set_tensor_shape(True) set_symbolic_shape(True)
x1 = tensor(np.random.randn(3, 1, 1)) x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1)) x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5)) x3 = tensor(np.random.randn(1, 1, 5))
...@@ -382,7 +382,7 @@ def test_trace_nms(): ...@@ -382,7 +382,7 @@ def test_trace_nms():
def test_trace_valid_broadcast(): def test_trace_valid_broadcast():
set_tensor_shape(True) set_symbolic_shape(True)
x1 = tensor(np.random.randn(1, 1)) x1 = tensor(np.random.randn(1, 1))
x2 = tensor(np.random.randn(1, 2)) x2 = tensor(np.random.randn(1, 2))
shape = (tensor([2]), tensor([2])) shape = (tensor([2]), tensor([2]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册