提交 7f06bb94 编写于 作者: M Megvii Engine Team

refactor(mge/functional): remove dependence to trace in functional implementations

GitOrigin-RevId: 0b18479fccd551a9ab2902ae5f086176e6c58d0a
上级 dac2b9e7
...@@ -16,7 +16,6 @@ from ..core.tensor import utils ...@@ -16,7 +16,6 @@ from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype from ..core.tensor.utils import astype
from ..device import get_default_device from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_func from ..utils.deprecation import deprecated_func
...@@ -560,8 +559,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: ...@@ -560,8 +559,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None" ), "At least one of 'lower' or 'upper' must not be None"
if lower is not None: if lower is not None:
if upper is not None: if upper is not None:
if not is_tracing(): # FIXME: following assertion won't work during trace if upper and lower are Tensors
assert lower <= upper, "clip lower bound is bigger that upper bound" # assert lower <= upper, "clip lower bound is bigger that upper bound"
return minimum(maximum(x, lower), upper) return minimum(maximum(x, lower), upper)
else: else:
return maximum(x, lower) return maximum(x, lower)
......
...@@ -12,7 +12,6 @@ from ..core._imperative_rt.core2 import apply ...@@ -12,7 +12,6 @@ from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d from ..core.tensor.utils import astensor1d
from ..jit.tracing import is_tracing
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import floor from .elemwise import floor
from .math import argsort from .math import argsort
...@@ -226,6 +225,10 @@ def nms( ...@@ -226,6 +225,10 @@ def nms(
otherwise it required to be specified; if it is not specified, all boxes are kept. otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS, sorted by scores. :return: indices of the elements that have been kept by NMS, sorted by scores.
.. note::
max_output should be specified and should have valid positive value under tracing
Examples: Examples:
.. testcode:: .. testcode::
...@@ -263,11 +266,6 @@ def nms( ...@@ -263,11 +266,6 @@ def nms(
sorted_idx = argsort(scores, descending=True) sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx] boxes = boxes[sorted_idx]
if is_tracing():
assert (
max_output is not None and max_output > 0
), "max_output should be specified under tracing"
if max_output is None: if max_output is None:
max_output = boxes.shape[0] max_output = boxes.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册