Commit c5f8b583 authored by Megvii Engine Team

docs(docstring): transfer to google style

GitOrigin-RevId: a71245c553e763bf1ed5b04913da4d2a1d4cd2c5
Parent 76ce81e8
@@ -11,38 +11,37 @@ from ..core.tensor import amp

class autocast:
    r"""A class to control autocast mode for amp as a context manager or a decorator.

    Args:
        enabled: Whether autocast mode is enabled.
        low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
            the target dtype in tensor casting for better speed and memory. Default: float16.
        high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
            change the target dtype in tensor casting for better precision. Default: float32.

    Examples:
        .. code-block::

           # used as decorator
           @autocast()
           def train_step(image, label):
               with gm:
                   logits = model(image)
                   loss = F.nn.cross_entropy(logits, label)
                   gm.backward(loss)
               opt.step().clear_grad()
               return loss

           # used as context manager
           def train_step(image, label):
               with autocast():
                   with gm:
                       logits = model(image)
                       loss = F.nn.cross_entropy(logits, label)
                       gm.backward(loss)
                   opt.step().clear_grad()
               return loss
    """

    def __init__(
@@ -16,50 +16,51 @@ from ..tensor import Tensor

class GradScaler:
    r"""A helper class that performs grad scaling to prevent from data overflow in
    :class:`~.autocast` mode.

    Args:
        init_scale: Initial scale factor.
        growth_factor: Factor that the scale is multiplied by in actual
            :meth:`update` stage. If growth_factor is 0, scale_factor will not update.
        backoff_factor: Factor that the scale is multiplied by when encountering
            overflow grad.
        growth_interval: The interval between two scale update stages.

    Example:
        .. code-block::

           gm = GradManager()
           opt = ...
           scaler = GradScaler()

           gm.attach(model.parameters())

           @autocast()
           def train_step(image, label):
               with gm:
                   logits = model(image)
                   loss = F.nn.cross_entropy(logits, label)
                   scaler.backward(gm, loss)
               opt.step().clear_grad()
               return loss

        If need more flexible usage, could split ``scaler.backward`` into three lines:

        .. code-block::

           @autocast()
           def train_step(image, label):
               with gm:
                   logits = model(image)
                   loss = F.nn.cross_entropy(logits, label)
                   gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
               scaler.unscale(gm.attached_tensors())
               scaler.update()
               opt.step().clear_grad()
               return loss

        This is useful when need to accumulate grads for multi batches.
    """

    def __init__(

@@ -86,18 +87,18 @@ class GradScaler:
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""A wrapper of GradManager's :meth:`~.GradManager.backward`, used to scale
        ``y``'s grad and unscale parameters' grads.

        Args:
            gm: The to be wrapped GradManager.
            y: Same as GradManager backward's ``y``.
            dy: Same as GradManager backward's ``dy``. Will be multiplied
                by ``scale_factor``.
            unscale_grad: Whether do :meth:`unscale` at the same time. Could be
                ``False`` if needs to accumulate grads.
            update_scale: Same as :meth:`unscale`'s ``update``. Will be ignored
                if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:

@@ -121,11 +122,11 @@ class GradScaler:
            self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""Unscale all ``grad_tensors``'s grad.

        Args:
            grad_tensors: Tensors needed to unscale grads. Should be all tensors
                that are affected by ``target`` tensor in GradManager's backward.
        """
        # use float64 for better precision
        inv_scale = Tensor(1.0 / self.scale_factor)

@@ -151,7 +152,8 @@ class GradScaler:
    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether encountered overflow grad.
        If ``new_scale`` is provided, internal update mechanism will be ignored.
        """
        if self.growth_interval == 0:
            return
@@ -32,8 +32,7 @@ _global_priority = 0

class GradManager:
    r"""GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
    automatic differentiation (a.k.a. back propagation).

    Reverse mode autodiff normally reuses many intermediate tensors for best computation efficiency.

@@ -120,7 +119,6 @@ class GradManager:
           gm = GradManager()
           gm.attach(model.parameters(), callback=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):

@@ -136,8 +134,7 @@ class GradManager:
        return [spec.tensor() for spec in self._attach_specs.values()]

    def attach(self, tensors: Iterable[Tensor], callbacks=None):
        r"""Instruct GradManager to track operations on tensors, so that gradients with respect
        to those tensors could be evaluated later.

        :meth:`attach` also accepts a list of callbacks, which will be called with the tensor and

@@ -188,8 +185,9 @@ class GradManager:
        multiple uses of a GradManager, which is unrelated to whether resources is timely
        released within a single use.

        Args:
            tensors: tensor or list of tensors to track
            callbacks: callback or list of callbacks
        """
        if callbacks is None:
            callbacks = []

@@ -234,8 +232,7 @@ class GradManager:
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
    ):
        r"""Compute gradients (or vector-Jacobian product) for all attached tensors, accumulate to
        corresponding .grad attribute, and release resources along the way.

        :meth:`backward` computes the vector-Jacobian product :math:`dx_j = \sum_{i} dy_i J_{ij}`

@@ -257,8 +254,9 @@ class GradManager:
        process of this call. When the call successfully finishes, the GradManager will be put back
        to an inactive state.

        Args:
            y: tensor or list of tensors
            dy: tensor or list of tensors. Defaults to 1 if y is scalar
        """
        push_scope("backward")
        set_option("record_computing_path", 0)

@@ -310,8 +308,7 @@ class GradManager:
        pop_scope("backward")

    def record(self):
        r"""Start recording operations

        After this call, you will be able to call :meth:`backward`.
        """

@@ -342,8 +339,7 @@ class GradManager:
                self._grad.wrt(tensor, callback=callback)

    def release(self):
        r"""Stop recording operations and release resources kept for gradient computation

        After this call, you will not be able to call :meth:`backward`.
        """
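# A minimal training-step sketch wired up with the GradManager API above; `model`,
# `data`, `label`, `opt` and the functional module `F` are assumed to exist elsewhere.
gm = GradManager()
gm.attach(model.parameters())
with gm:                                    # equivalent to gm.record() ... gm.release()
    loss = F.nn.cross_entropy(model(data), label)
    gm.backward(loss)                       # accumulates gradients into each parameter's .grad
opt.step().clear_grad()
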
@@ -15,16 +15,12 @@ if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):

def use_symbolic_shape() -> bool:
    r"""Returns whether tensor.shape returns a tensor instead of a tuple"""
    return _use_symbolic_shape


def set_symbolic_shape(option: bool):
    r"""Sets whether tensor.shape returns a tensor instead of a tuple"""
    global _use_symbolic_shape
    _org = _use_symbolic_shape
    _use_symbolic_shape = option
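# A small usage sketch for the two toggles above; set_symbolic_shape is assumed only
# to flip the module-level flag, so its return value is not relied upon here.
set_symbolic_shape(True)
assert use_symbolic_shape()        # tensor.shape now returns a Tensor
set_symbolic_shape(False)
assert not use_symbolic_shape()    # tensor.shape returns a plain tuple again
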
@@ -88,67 +88,56 @@ class Grad:

class Function(ops.PyOpBase):
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in ``forward`` method, with gradient
    computation defined in ``backward`` method.

    Each instance of ``Function`` should be used only once during forwardding.

    Examples:
        .. code-block::

           class Sigmoid(Function):
               def forward(self, x):
                   y = 1 / (1 + F.exp(-x))
                   self.y = y
                   return y

               def backward(self, dy):
                   y = self.y
                   return dy * y * (1-y)
    """

    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
            * positional arguments should all be Tensor
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        r"""Compute the gradient of the forward function. It must be overriden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * In case when some tensors of outputs are not related to loss function, the corresponding
              values in ``output_grads`` would be ``None``.
            * This method should return a tuple which containing the gradients of all inputs, in the same order
              as the ``inputs`` argument of :meth:`forward` . A ``Tensor`` could be returned
              instead if there is only one input. If users want to stop the propagation of some gradients,
              the corresponding returned values should be set ``None`` .
        """
        raise NotImplementedError
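# A hedged usage sketch for the Sigmoid example above; `F` and the Sigmoid subclass
# are assumed to be defined as in the docstring.
from megengine import tensor

x = tensor([0.0, 1.0, 2.0])
sigmoid = Sigmoid()        # each Function instance should be used for a single forward pass
y = sigmoid(x)             # invokes forward(); backward() is called during autodiff
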
@@ -12,16 +12,14 @@ _low_prec_dtype = "float16"

@property
def enabled(mod):
    r"""Get or set amp autocast mode enabled or not.

    Examples:
        .. code-block::

           import megengine as mge
           mge.amp.enabled = True
    """
    return _enabled

@@ -34,17 +32,15 @@ def enabled(mod, enabled: bool):

@property
def high_prec_dtype(mod):
    r"""Get or set amp autocast mode's higher precision dtype. It will change the
    target dtype in tensor casting for better precision. Default: float32.

    Examples:
        .. code-block::

           import megengine as mge
           mge.amp.high_prec_dtype = "float32"
    """
    return _high_prec_dtype

@@ -57,17 +53,15 @@ def high_prec_dtype(mod, dtype: str):

@property
def low_prec_dtype(mod):
    r"""Get or set amp autocast mode's lower precision dtype. It will change the
    target dtype in tensor casting for better speed and memory. Default: float16.

    Examples:
        .. code-block::

           import megengine as mge
           mge.amp.low_prec_dtype = "float16"
    """
    return _low_prec_dtype
@@ -389,9 +389,7 @@ class ArrayMethodMixin(abc.ABC):
    @property
    def ndim(self):
        r"""Returns the number of dimensions of self :class:`~.Tensor`."""
        shape = self._tuple_shape
        if shape is None:
            raise ValueError("unkown ndim")

@@ -399,8 +397,7 @@ class ArrayMethodMixin(abc.ABC):
    @property
    def size(self):
        r"""Returns the size of the self :class:`~.Tensor`.
        The returned value is a subclass of :class:`tuple`.
        """
        shape = self.shape

@@ -410,14 +407,11 @@ class ArrayMethodMixin(abc.ABC):
    @property
    def T(self):
        r"""alias of :attr:`~.Tensor.transpose`."""
        return self.transpose()

    def item(self, *args):
        r"""Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
        This only works for tensors with one element. For other cases, see :meth:`~.tolist`.
        """
        if not args:

@@ -427,8 +421,7 @@ class ArrayMethodMixin(abc.ABC):
        return self[args].item()

    def tolist(self):
        r"""Returns the tensor as a (nested) list.
        For scalars, a standard Python number is returned, just like with :meth:`~.item`.
        Tensors are automatically moved to the CPU first if necessary.

@@ -437,16 +430,13 @@ class ArrayMethodMixin(abc.ABC):
        return self.numpy().tolist()

    def astype(self, dtype):
        r"""Returns a :class:`Tensor` with the same data and number of elements
        with the specified :attr:`~.Tensor.dtype`.
        """
        return astype(self, dtype)

    def reshape(self, *args):
        r"""See :func:`~.reshape`."""
        return _reshape(self, _expand_args(args))

    # FIXME: remove this method

@@ -454,9 +444,7 @@ class ArrayMethodMixin(abc.ABC):
        return _broadcast(self, _expand_args(args))

    def transpose(self, *args):
        r"""See :func:`~.transpose`."""
        if self.ndim == 0:
            assert (
                len(args) == 0

@@ -469,172 +457,170 @@ class ArrayMethodMixin(abc.ABC):
        return _transpose(self, _expand_args(args))

    def flatten(self):
        r"""See :func:`~.flatten`."""
        return self.reshape(-1)

    def sum(self, axis=None, keepdims: bool = False):
        r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axises, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.sum().numpy())
               print(b.sum().numpy())

            Outputs:

            .. testoutput::

               2
               10.0
        """
        return _reduce("sum")(self, axis, keepdims)

    def prod(self, axis=None, keepdims: bool = False):
        r"""Returns the product of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axises, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.prod().numpy())
               print(b.prod().numpy())

            Outputs:

            .. testoutput::

               0
               24.0
        """
        return _reduce("product")(self, axis, keepdims)

    def min(self, axis=None, keepdims: bool = False):
        r"""Returns the min value of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axises, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.min().numpy())
               print(b.min().numpy())

            Outputs:

            .. testoutput::

               False
               1.0
        """
        return _reduce("min")(self, axis, keepdims)

    def max(self, axis=None, keepdims: bool = False):
        r"""Returns the max value of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axises, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.max().numpy())
               print(b.max().numpy())

            Outputs:

            .. testoutput::

               True
               4.0
        """
        return _reduce("max")(self, axis, keepdims)

    def mean(self, axis=None, keepdims: bool = False):
        r"""Returns the mean value of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axises, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
        except in the dimension(s) ``axis`` where it is of size 1.
        Otherwise, ``axis`` is squeezed (see :func:`~.squeeze`).

        Args:
            axis: the dimension or dimensions to reduce.
            keepdims: whether the output tensor has ndim retained or not.

        Returns:
            output tensor.

        Examples:
            .. testcode::

               from megengine import tensor
               a = tensor([False, True, True, False])
               b = tensor([1.0, 2.0, 3.0, 4.0])
               print(a.mean().numpy())
               print(b.mean().numpy())

            Outputs:

            .. testoutput::

               0.5
               2.5
        """
        return _reduce("mean")(self, axis, keepdims)
@@ -47,17 +47,17 @@ class QuantDtypeMeta(
        ["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"],
    )
):
    r"""Store metadata for quantize dtype. Could be used to create custom quant dtype
    for QAT when the network don't need to be converted for inference, but only
    to export network metadata for third-party platform inference.

    Args:
        name: a unique name string.
        cname: used in :func:`~.create_quantized_dtype` for model dump and inference.
        np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``.
        qmin: a int number indicating quant dtype's lowerbound.
        qmax: a int number indicating quant dtype's upperbound.
        is_unsigned: a helper value that could be inference from np_dtype_str.
    """

    def __new__(

@@ -77,7 +77,7 @@ class QuantDtypeMeta(
        return self

    def __deepcopy__(self, _):
        r"""
        Ignore deepcopy so that a dtype meta can be treated as singleton, for more
        strict check in :meth:`~.FakeQuantize.fake_quant_forward`.
        """
@@ -113,17 +113,17 @@ def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta):

def create_quantized_dtype(
    dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
):
    r"""Get quantized dtype with metadata attribute according to _metadata_dict.

    Note that unsigned dtype must have ``zero_point`` and signed dtype must
    not have ``zero_point``, to be consitent with tensor generated by calling
    compiled function from `CompGraph.compile(inputs, outspec)`.

    Args:
        dtype_meta: a QuantDtypeMeta indicating which dtype to return. the
            ``cname`` attribute cannot be ``None``.
        scale: a number for scale to store in dtype's metadata
        zp: a number for zero_point to store in dtype's metadata
    """
    if dtype_meta.cname is None:
        raise ValueError("dtype {} without cname attr is not supported.")

@@ -152,8 +152,7 @@ def create_quantized_dtype(

def quint8(scale, zero_point):
    r"""Consturct a quantized unsigned int8 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint8 data type is
    float_val = scale * (uint8_val - zero_point)
    """

@@ -161,24 +160,21 @@ def quint8(scale, zero_point):

def qint8(scale):
    r"""Construct a quantized int8 data type with ``scale`` (float). The real value
    represented by a qint8 data type is float_val = scale * int8_val
    """
    return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None)


def qint32(scale):
    r"""Construct a quantized int32 data type with ``scale`` (float). The real value
    represented by a qint32 data type is float_val = scale * int32_val
    """
    return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None)


def quint4(scale, zero_point):
    r"""Consturct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)
    """

@@ -186,8 +182,7 @@ def quint4(scale, zero_point):

def qint4(scale):
    r"""Construct a quantized int4 data type with ``scale`` (float). The real value
    represented by a qint4 data type is float_val = scale * int4_val
    """
    return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)
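# A hedged sketch of building quantized dtypes from the helpers above. The qint8 call
# follows its docstring; the custom QuantDtypeMeta field values are illustrative
# assumptions, not a documented preset.
q8 = qint8(0.125)                  # int8 dtype whose metadata stores scale=0.125
custom_meta = QuantDtypeMeta(
    "qint8_narrow",                # unique name (hypothetical)
    None,                          # no cname: only usable for QAT metadata export
    "int8",
    -127,
    127,
    False,                         # passed explicitly; may also be inferred from np_dtype_str
)
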
@@ -244,95 +239,95 @@ def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta):

def convert_to_quint8(arr: np.ndarray, q: np.dtype):
    r"""Quantize a float NumPy ndarray into a quint8 one with specified params.

    Args:
        arr: Input ndarray.
        q: Target data type, should be a quint8.
    """
    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"])


def convert_from_quint8(arr: np.ndarray):
    r"""Dequantize a quint8 NumPy ndarray into a float one.

    Args:
        arr: Input ndarray.
    """
    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"])


def convert_to_qint8(arr: np.ndarray, q: np.dtype):
    r"""Quantize a float NumPy ndarray into a qint8 one with specified params.

    Args:
        arr: Input ndarray.
        q: Target data type, should be a qint8.
    """
    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"])


def convert_from_qint8(arr: np.ndarray):
    r"""Dequantize a qint8 NumPy ndarray into a float one.

    Args:
        arr: Input ndarray.
    """
    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"])


def convert_to_qint32(arr: np.ndarray, q: np.dtype):
    r"""Quantize a float NumPy ndarray into a qint32 one with specified params.

    Args:
        arr: Input ndarray.
        q: Target data type, should be a qint8.
    """
    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"])


def convert_from_qint32(arr):
    r"""Dequantize a qint32 NumPy ndarray into a float one.

    Args:
        arr: Input ndarray.
    """
    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"])


def convert_to_quint4(arr: np.ndarray, q: np.dtype):
    r"""Quantize a float NumPy ndarray into a quint4 one with specified params.

    Args:
        arr: Input ndarray.
        q: Target data type, should be a quint4.
    """
    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"])


def convert_from_quint4(arr: np.ndarray):
    r"""Dequantize a quint4 NumPy ndarray into a float one.

    Args:
        arr: Input ndarray.
    """
    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"])


def convert_to_qint4(arr: np.ndarray, q: np.dtype):
    r"""Quantize a float NumPy ndarray into a qint4 one with specified params.

    Args:
        arr: Input ndarray.
        q: Target data type, should be a qint4.
    """
    return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"])


def convert_from_qint4(arr: np.ndarray):
    r"""Dequantize a qint4 NumPy ndarray into a float one.

    Args:
        arr: Input ndarray.
    """
    return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])
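# A hedged round-trip sketch using the qint8 helpers defined above; the exact
# rounding behaviour is an assumption based on the scale semantics in the docstrings.
import numpy as np

arr = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
q = qint8(0.01)                            # scale stored in the dtype's metadata
quantized = convert_to_qint8(arr, q)       # int8 values, roughly round(arr / 0.01)
restored = convert_from_qint8(quantized)   # float values, roughly quantized * 0.01
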
@@ -24,11 +24,11 @@ from .core import TensorBase

def set_priority_to_id(dest_vars):
    r"""For all oprs in the subgraph constructed by dest_vars,
    sets its priority to id if its original priority is zero.

    Args:
        dest_vars: target vars representing the graph.
    """
    dest_vec = []
    for i in dest_vars:
@@ -220,54 +220,50 @@ class OpNode:

def optimize_for_inference(dest_vars, **kwargs):
    r"""Applies optimize_for_inference pass for computing graph.

    Args:
        dest_vars: list of output vars in the computing graph

    Keyword Arguments:

        * enable_io16xc32 --
          whether to use float16 for I/O between oprs and use
          float32 as internal computation precision. Note the output var would be
          changed to float16.
        * enable_ioc16 --
          whether to use float16 for both I/O and computation
          precision.
        * enable_hwcd4 --
          whether to use NHWCD4 data layout. This is faster on some
          OpenCL backend.
        * enable_nchw88 --
          whether to use NCHW88 data layout, currently
          used in X86 AVX backend.
        * enable_nchw44 --
          whether to use NCHW44 data layout, currently
          used in arm backend.
        * enable_nchw44_dot --
          whether to use NCHW44_dot data layout, currently
          used in armv8.2+dotprod backend.
        * enable_nchw4 --
          whether to use NCHW4 data layout, currently
          used in nvidia backend(based on cudnn).
        * enable_nchw32 --
          whether to use NCHW32 data layout, currently
          used in nvidia backend with tensorcore(based on cudnn).
        * enable_chwn4 --
          whether to use CHWN4 data layout, currently
          used in nvidia backend with tensorcore.
        * enable_nchw64 --
          whether to use NCHW64 data layout, used for fast int4
          support on Nvidia GPU.
        * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
          into one opr.
        * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
          input for inference on nvidia backend(this optimization pass will
          result in mismatch of the precision of output of training and
          inference)
        * enable_fuse_preprocess: whether to fuse astype\pad channel\dimshuffle and
          etc opr from h2d opr.
    """
    inference_options = GraphOptimizeOptions()
    inference_optimize_layout_transform_map = {
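    # A hedged call sketch for the options documented above; `output_vars` would be the
    # traced graph's output VarNodes, and treating the return value as the rewritten
    # output vars is an assumption.
    #
    #     optimized_vars = optimize_for_inference(
    #         output_vars,
    #         enable_io16xc32=True,                     # float16 I/O, float32 computation
    #         enable_fuse_conv_bias_nonlinearity=True,  # fuse conv+bias+nonlinearity oprs
    #     )
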
@@ -305,11 +301,13 @@ def optimize_for_inference(dest_vars, **kwargs):

def deserialize_infer_option(x: int) -> Dict[str, bool]:
    r"""Deserailize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.

    Args:
        x: inference options represented by int.

    Returns:
        inference options represented by dict.
    """
    inference_options = GraphOptimizeOptions.deserialize(x)

@@ -346,13 +344,12 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]:

def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
    r"""C++ graph version of :func:`~.set_execution_strategy`. Used to inplacely modify
    dumped graph's fast-run strategy.

    Args:
        dest_vars: list of output vars in the computing graph.
        strategy: fast-run algorithms strategy.
    """
    dest_vars = _unwrap(dest_vars)
    _imperative_rt.modify_opr_algo_strategy_inplace(dest_vars, strategy)
@@ -383,39 +380,40 @@ def dump_graph(
    append_json=False,
    metadata=None
) -> Tuple[bytes, CompGraphDumpResult]:
    r"""serialize the computing graph of `output_vars` and get byte result.

    Args:
        output_vars: output variables which are the graph's end point.
        keep_var_name: level for keeping variable names:

            * 0: none of the names are kept
            * 1: (default)keep names of output vars
            * 2: keep names of all (output and internal) vars
        keep_opr_name: whether to keep operator names.
        keep_param_name: whether to keep param names, so param values can be
            easily manipulated after loading model
        keep_opr_priority: whether to keep priority setting for operators
        strip_info_file: a string for path or a file handler. if is not None,
            then the dump information for code strip would be written to ``strip_info_file``
        append_json: will be check when `strip_info_file` is not None. if set
            true, the information for code strip will be append to strip_info_file.
            if set false, will rewrite strip_info_file

    Note:
        The underlying C++ API only accepts a var list. If a dict is given,
        the vars would be renamed to the given names.

    Returns:
        dump result as byte string, and an instance of namedtuple
        :class:`CompGraphDumpResult`, whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars
    """
    if isinstance(output_vars, dict):
        used_vars = set()
@@ -483,17 +481,19 @@ CompGraphLoadResult = collections.namedtuple(

def load_graph(fpath) -> CompGraphLoadResult:
    r"""Load a serialized computing graph from file.

    Args:
        fpath: Path or Handle of the input file

    Returns:
        An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` A Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
    """
    output_vars_map = []
    output_vars_list = []
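# A hedged dump/load sketch based on the two signatures above; `output_vars` is
# assumed to come from a traced or compiled graph elsewhere, and the file name is
# a placeholder.
graph_bytes, dump_info = dump_graph(output_vars, keep_var_name=2)
with open("model.mge", "wb") as f:
    f.write(graph_bytes)
print(dump_info.nr_opr, dump_info.outputs)   # fields documented above

load_result = load_graph("model.mge")
graph = load_result.graph
outputs = load_result.output_vars_list
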
@@ -24,12 +24,12 @@ _enable_convert_inputs = True

def get_convert_inputs():
    r"""get the curerent state of `_enable_convert_inputs`"""
    return _enable_convert_inputs


def set_convert_inputs(flag):
    r"""This function is a temporary workaround for reducing the overhead of operator
    invocations. The function `convert_inputs` is disabled if the global state
    `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
    internal use only, and should be removed when the tensor-like system is refactored.

@@ -137,11 +137,11 @@ def setscalar(x):

def astensor1d(x, *reference, dtype=None, device=None):
    """Convert something to 1D tensor. Support following types

    * sequence of scalar literal / tensor
    * numpy array
    * tensor (returned as is, regardless of dtype and device)
    """
    try:
        ndim = x.ndim
@@ -33,16 +33,11 @@ default_collate_err_msg_format = (

class Collator:
    r"""Used for merging a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a dataset.

    Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    """

    def apply(self, inputs):
        elem = inputs[0]
        elem_type = type(elem)
        if (
@@ -44,28 +44,28 @@ def raise_timeout_error():

class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    make it flexible to get minibatch continually from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
        transform: defined the transforming strategy for a sampled batch.
            Default: None
        collator: defined the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-process to load, transform and collate
            the batch. ``0`` means using single-process. Default: 0
        timeout: if positive, means the timeout value(second) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout, default to raise
            runtime error.
        divide: define the paralleling strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers will process these pieces parallelly. ``False`` means
            different sub-process will process different batch. Default: False
    """

    __initialized = False
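# A minimal DataLoader sketch; RandomSampler and ToMode are assumed to be the standard
# megengine.data components, and `dataset` any Dataset instance defined elsewhere.
from megengine.data import DataLoader, RandomSampler
from megengine.data.transform import ToMode

dataloader = DataLoader(
    dataset,
    sampler=RandomSampler(dataset, batch_size=64),
    transform=ToMode("CHW"),
    num_workers=4,
)
for batch_image, batch_label in dataloader:
    pass  # feed the minibatch to the training step
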
@@ -11,8 +11,7 @@ from typing import Tuple

class Dataset(ABC):
    r"""An abstract base class for all datasets.

    __getitem__ and __len__ method are aditionally needed.
    """

@@ -31,8 +30,7 @@ class Dataset(ABC):

class StreamDataset(Dataset):
    r"""An abstract class for stream data.

    __iter__ method is aditionally needed.
    """

@@ -53,10 +51,9 @@ class StreamDataset(Dataset):

class ArrayDataset(Dataset):
    r"""ArrayDataset is a dataset for numpy array data.

    One or more numpy arrays are needed to initiate the dataset.
    And the dimensions represented sample number are expected to be the same.
    """
@@ -21,8 +21,7 @@ logger = get_logger(__name__)

class CIFAR10(VisionDataset):
    r""":class:`~.Dataset` for CIFAR10 meta data."""

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"

@@ -138,8 +137,7 @@ class CIFAR10(VisionDataset):

class CIFAR100(CIFAR10):
    r""":class:`~.Dataset` for CIFAR100 meta data."""

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
...@@ -23,9 +23,7 @@ from .meta_vision import VisionDataset ...@@ -23,9 +23,7 @@ from .meta_vision import VisionDataset
class Cityscapes(VisionDataset): class Cityscapes(VisionDataset):
r""" r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset."""
`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
"""
supported_order = ( supported_order = (
"image", "image",
......
...@@ -46,9 +46,7 @@ def has_valid_annotation(anno, order): ...@@ -46,9 +46,7 @@ def has_valid_annotation(anno, order):
class COCO(VisionDataset): class COCO(VisionDataset):
r""" r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset."""
`MS COCO <http://cocodataset.org/#home>`_ Dataset.
"""
supported_order = ( supported_order = (
"image", "image",
......
...@@ -26,22 +26,21 @@ from .utils import is_img ...@@ -26,22 +26,21 @@ from .utils import is_img
class ImageFolder(VisionDataset): class ImageFolder(VisionDataset):
r""" r"""ImageFolder is a class for loading image data and labels from a organized folder.
ImageFolder is a class for loading image data and labels from a organized folder.
The folder is expected to be organized as followed: root/cls/xxx.img_ext The folder is expected to be organized as followed: root/cls/xxx.img_ext
Labels are indices of sorted classes in the root directory. Labels are indices of sorted classes in the root directory.
:param root: root directory of an image folder. Args:
:param loader: a function used to load image from path, root: root directory of an image folder.
if ``None``, default function that loads loader: a function used to load image from path,
images with PIL will be called. if ``None``, default function that loads
:param check_valid_func: a function used to check if files in folder are images with PIL will be called.
expected image files, if ``None``, default function check_valid_func: a function used to check if files in folder are
that checks file extensions will be called. expected image files, if ``None``, default function
:param class_name: if ``True``, return class name instead of class index. that checks file extensions will be called.
class_name: if ``True``, return class name instead of class index.
""" """
def __init__(self, root: str, check_valid_func=None, class_name: bool = False): def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
......
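A hedged usage sketch; the path is hypothetical, and returning an ``(image, label)`` pair from indexing is assumed from typical ImageFolder implementations:

.. code-block::

    from megengine.data.dataset import ImageFolder

    # assumed layout: /path/to/root/<class_folder>/<image files>
    dataset = ImageFolder("/path/to/root")
    image, label = dataset[0]        # label is the index of the sorted class folder

    named = ImageFolder("/path/to/root", class_name=True)
    image, name = named[0]           # returns the folder name instead of the index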
...@@ -30,11 +30,10 @@ logger = get_logger(__name__) ...@@ -30,11 +30,10 @@ logger = get_logger(__name__)
class ImageNet(ImageFolder): class ImageNet(ImageFolder):
r""" r"""Load ImageNet from raw files or folder. Expected folder looks like:
Load ImageNet from raw files or folder. Expected folder looks like:
.. code-block:: shell
.. code-block:: bash
${root}/ ${root}/
| [REQUIRED TAR FILES] | [REQUIRED TAR FILES]
|- ILSVRC2012_img_train.tar |- ILSVRC2012_img_train.tar
...@@ -45,22 +44,8 @@ class ImageNet(ImageFolder): ...@@ -45,22 +44,8 @@ class ImageNet(ImageFolder):
|- val/cls/xxx.${img_ext} |- val/cls/xxx.${img_ext}
|- ILSVRC2012_devkit_t12/data/meta.mat |- ILSVRC2012_devkit_t12/data/meta.mat
|- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
If the image folders don't exist, raw tar files are required and will be extracted and processed. If the image folders don't exist, raw tar files are required and will be extracted and processed.
"""
raw_file_meta = {
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
} # ImageNet raw files
default_train_dir = "train"
default_val_dir = "val"
default_devkit_dir = "ILSVRC2012_devkit_t12"
def __init__(self, root: str = None, train: bool = True, **kwargs):
r"""
Initialization:
* if ``root`` contains ``self.target_folder`` depending on ``train``: * if ``root`` contains ``self.target_folder`` depending on ``train``:
...@@ -77,10 +62,22 @@ class ImageNet(ImageFolder): ...@@ -77,10 +62,22 @@ class ImageNet(ImageFolder):
* raise error. * raise error.
:param root: root directory of imagenet data, if root is ``None``, use default_dataset_root. Args:
:param train: if ``True``, load the train split, otherwise load the validation split. root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
""" train: if ``True``, load the train split, otherwise load the validation split.
"""
raw_file_meta = {
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
} # ImageNet raw files
default_train_dir = "train"
default_val_dir = "val"
default_devkit_dir = "ILSVRC2012_devkit_t12"
def __init__(self, root: str = None, train: bool = True, **kwargs):
# process the root path # process the root path
if root is None: if root is None:
self.root = self._default_root self.root = self._default_root
......
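A hedged usage sketch; the root path is hypothetical, and on first use the raw tar files listed above are expected to be extracted into the train/val folders:

.. code-block::

    from megengine.data.dataset import ImageNet

    train_set = ImageNet("/data/imagenet", train=True)
    val_set = ImageNet("/data/imagenet", train=False)
    print(len(train_set), len(val_set))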
...@@ -22,8 +22,7 @@ logger = get_logger(__name__) ...@@ -22,8 +22,7 @@ logger = get_logger(__name__)
class MNIST(VisionDataset): class MNIST(VisionDataset):
r""" :class:`~.Dataset` for MNIST meta data. r""":class:`~.Dataset` for MNIST meta data."""
"""
url_path = "http://yann.lecun.com/exdb/mnist/" url_path = "http://yann.lecun.com/exdb/mnist/"
""" """
......
...@@ -23,9 +23,7 @@ from .meta_vision import VisionDataset ...@@ -23,9 +23,7 @@ from .meta_vision import VisionDataset
class Objects365(VisionDataset): class Objects365(VisionDataset):
r""" r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset."""
`Objects365 <https://www.objects365.org/overview.html>`_ Dataset.
"""
supported_order = ( supported_order = (
"image", "image",
......
...@@ -24,9 +24,7 @@ from .meta_vision import VisionDataset ...@@ -24,9 +24,7 @@ from .meta_vision import VisionDataset
class PascalVOC(VisionDataset): class PascalVOC(VisionDataset):
r""" r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
"""
supported_order = ( supported_order = (
"image", "image",
......
...@@ -17,9 +17,7 @@ import megengine.distributed as dist ...@@ -17,9 +17,7 @@ import megengine.distributed as dist
class Sampler(ABC): class Sampler(ABC):
r""" r"""An abstract base class for all Sampler"""
An abstract base class for all Sampler
"""
@abstractmethod @abstractmethod
def __init__(self): def __init__(self):
...@@ -27,19 +25,19 @@ class Sampler(ABC): ...@@ -27,19 +25,19 @@ class Sampler(ABC):
class MapSampler(Sampler): class MapSampler(Sampler):
r""" r"""Sampler for map dataset.
Sampler for map dataset.
Args:
:param dataset: dataset to sample from. dataset: dataset to sample from.
:param batch_size: batch size for batch method. batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch, drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False be smaller. Default: False
:param num_samples: number of samples assigned to one rank. num_samples: number of samples assigned to one rank.
:param world_size: number of ranks. world_size: number of ranks.
:param rank: rank id, a non-negative integer within 0 and ``world_size``. rank: rank id, a non-negative integer within 0 and ``world_size``.
:param seed: seed for random operators. seed: seed for random operators.
""" """
def __init__( def __init__(
...@@ -106,14 +104,11 @@ class MapSampler(Sampler): ...@@ -106,14 +104,11 @@ class MapSampler(Sampler):
return int(math.ceil(self.num_samples / self.batch_size)) return int(math.ceil(self.num_samples / self.batch_size))
def sample(self): def sample(self):
""" r"""Return a list contains all sample indices."""
Return a list contains all sample indices.
"""
raise NotImplementedError raise NotImplementedError
def scatter(self, indices) -> List: def scatter(self, indices) -> List:
r""" r"""Scatter method is used for splitting indices into subset, each subset
Scatter method is used for splitting indices into subset, each subset
will be assigned to a rank. Indices are evenly splitted by default. will be assigned to a rank. Indices are evenly splitted by default.
If customized indices assignment method is needed, please rewrite this method. If customized indices assignment method is needed, please rewrite this method.
""" """
...@@ -130,9 +125,7 @@ class MapSampler(Sampler): ...@@ -130,9 +125,7 @@ class MapSampler(Sampler):
return indices return indices
def batch(self) -> Iterator[List[Any]]: def batch(self) -> Iterator[List[Any]]:
r""" r"""Batch method provides a batch indices generator."""
Batch method provides a batch indices generator.
"""
indices = list(self.sample()) indices = list(self.sample())
# user might pass the world_size parameter without dist, # user might pass the world_size parameter without dist,
...@@ -150,18 +143,15 @@ class MapSampler(Sampler): ...@@ -150,18 +143,15 @@ class MapSampler(Sampler):
class StreamSampler(Sampler): class StreamSampler(Sampler):
r""" r"""Sampler for stream dataset.
Sampler for stream dataset.
.. warning::
Warning:
In the case of multiple machines, sampler should ensure that each worker gets In the case of multiple machines, sampler should ensure that each worker gets
different data. But this class cannot do it yet, please build your own different data. But this class cannot do it yet, please build your own
dataset and sampler to achieve this goal. dataset and sampler to achieve this goal.
Usually, :meth:`~.StreamDataset.__iter__` can return a different iterator based on Usually, :meth:`~.StreamDataset.__iter__` can return a different iterator based on
``rank = dist.get_rank()``, so that each rank gets different data. ``rank = dist.get_rank()``, so that each rank gets different data.
""" """
def __init__(self, batch_size=1): def __init__(self, batch_size=1):
...@@ -175,18 +165,18 @@ class StreamSampler(Sampler): ...@@ -175,18 +165,18 @@ class StreamSampler(Sampler):
class SequentialSampler(MapSampler): class SequentialSampler(MapSampler):
r""" r"""Sample elements sequentially.
Sample elements sequentially.
Args:
:param dataset: dataset to sample from. dataset: dataset to sample from.
:param batch_size: batch size for batch method. batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch, drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False be smaller. Default: False
:param indices: indices of samples. indices: indices of samples.
:param world_size: number of ranks. world_size: number of ranks.
:param rank: rank id, a non-negative integer within 0 and ``world_size``. rank: rank id, a non-negative integer within 0 and ``world_size``.
""" """
def __init__( def __init__(
...@@ -207,9 +197,7 @@ class SequentialSampler(MapSampler): ...@@ -207,9 +197,7 @@ class SequentialSampler(MapSampler):
self.indices = indices self.indices = indices
def sample(self) -> Iterator[Any]: def sample(self) -> Iterator[Any]:
r""" r"""Return a generator."""
Return a generator.
"""
if self.indices is None: if self.indices is None:
return iter(range(len(self.dataset))) return iter(range(len(self.dataset)))
else: else:
...@@ -217,19 +205,19 @@ class SequentialSampler(MapSampler): ...@@ -217,19 +205,19 @@ class SequentialSampler(MapSampler):
class RandomSampler(MapSampler): class RandomSampler(MapSampler):
r""" r"""Sample elements randomly without replacement.
Sample elements randomly without replacement.
Args:
:param dataset: dataset to sample from. dataset: dataset to sample from.
:param batch_size: batch size for batch method. batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch, drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False be smaller. Default: False
:param indices: indices of samples. indices: indices of samples.
:param world_size: number of ranks. world_size: number of ranks.
:param rank: rank id, a non-negative integer within 0 and ``world_size``. rank: rank id, a non-negative integer within 0 and ``world_size``.
:param seed: seed for random operators. seed: seed for random operators.
""" """
def __init__( def __init__(
...@@ -258,20 +246,20 @@ class RandomSampler(MapSampler): ...@@ -258,20 +246,20 @@ class RandomSampler(MapSampler):
class ReplacementSampler(MapSampler): class ReplacementSampler(MapSampler):
r""" r"""Sample elements randomly with replacement.
Sample elements randomly with replacement.
Args:
:param dataset: dataset to sample from. dataset: dataset to sample from.
:param batch_size: batch size for batch method. batch_size: batch size for batch method.
:param drop_last: set ``True`` to drop the last incomplete batch, drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False be smaller. Default: False
:param num_samples: number of samples assigned to one rank. num_samples: number of samples assigned to one rank.
:param weights: weights for sampling indices, it could be unnormalized weights. weights: weights for sampling indices, it could be unnormalized weights.
:param world_size: number of ranks. world_size: number of ranks.
:param rank: rank id, a non-negative integer within 0 and ``world_size``. rank: rank id, a non-negative integer within 0 and ``world_size``.
:param seed: seed for random operators. seed: seed for random operators.
""" """
def __init__( def __init__(
......
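A hedged sketch contrasting the three map samplers above; the ``megengine.data`` import paths are assumptions, while the ``batch()`` generator is the interface shown earlier in this file:

.. code-block::

    import numpy as np
    from megengine.data import RandomSampler, ReplacementSampler, SequentialSampler
    from megengine.data.dataset import ArrayDataset

    data = np.arange(100, dtype="float32").reshape(100, 1)
    labels = np.zeros(100, dtype="int32")
    dataset = ArrayDataset(data, labels)

    # deterministic order, typically used for validation
    val_sampler = SequentialSampler(dataset, batch_size=10)

    # shuffled without replacement, typically used for training
    train_sampler = RandomSampler(dataset, batch_size=10, drop_last=True, seed=42)

    # weighted sampling with replacement, e.g. to over-sample rare classes
    weighted_sampler = ReplacementSampler(dataset, batch_size=10, weights=[1.0] * len(dataset))

    # every sampler yields one list of indices per minibatch
    first_batch_indices = next(iter(train_sampler.batch()))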
...@@ -59,15 +59,13 @@ class _PlasmaStoreManager: ...@@ -59,15 +59,13 @@ class _PlasmaStoreManager:
class PlasmaShmQueue: class PlasmaShmQueue:
def __init__(self, maxsize: int = 0): def __init__(self, maxsize: int = 0):
r""" r"""Use pyarrow in-memory plasma store to implement shared memory queue.
Use pyarrow in-memory plasma store to implement shared memory queue.
Compared to the native `multiprocessing.Queue`, `PlasmaShmQueue` avoids pickle/unpickle Compared to the native `multiprocessing.Queue`, `PlasmaShmQueue` avoids pickle/unpickle
and communication overhead, leading to better performance in multi-process and communication overhead, leading to better performance in multi-process
applications. applications.
:type maxsize: int Args:
:param maxsize: maximum size of the queue, ``0`` means no limit. (default: ``0``) maxsize: maximum size of the queue, ``0`` means no limit. (default: ``0``)
""" """
# Lazy start the plasma store manager # Lazy start the plasma store manager
......
...@@ -11,9 +11,7 @@ from typing import Sequence, Tuple ...@@ -11,9 +11,7 @@ from typing import Sequence, Tuple
class Transform(ABC): class Transform(ABC):
""" r"""Rewrite apply method in subclass."""
Rewrite apply method in subclass.
"""
def apply_batch(self, inputs: Sequence[Tuple]): def apply_batch(self, inputs: Sequence[Tuple]):
return tuple(self.apply(input) for input in inputs) return tuple(self.apply(input) for input in inputs)
......
...@@ -15,7 +15,7 @@ import numpy as np ...@@ -15,7 +15,7 @@ import numpy as np
def wrap_keepdims(func): def wrap_keepdims(func):
"""Wraper to keep the dimension of input images unchanged.""" r"""Wraper to keep the dimension of input images unchanged."""
@functools.wraps(func) @functools.wraps(func)
def wrapper(image, *args, **kwargs): def wrapper(image, *args, **kwargs):
...@@ -33,41 +33,47 @@ def wrap_keepdims(func): ...@@ -33,41 +33,47 @@ def wrap_keepdims(func):
@wrap_keepdims @wrap_keepdims
def to_gray(image): def to_gray(image):
r""" r"""Change BGR format image's color space to gray.
Change BGR format image's color space to gray.
:param image: input BGR format image, with `(H, W, C)` shape. Args:
:return: gray format image, with `(H, W, C)` shape. image: input BGR format image, with `(H, W, C)` shape.
Returns:
gray format image, with `(H, W, C)` shape.
""" """
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@wrap_keepdims @wrap_keepdims
def to_bgr(image): def to_bgr(image):
r""" r"""Change gray format image's color space to BGR.
Change gray format image's color space to BGR.
Args:
image: input Gray format image, with `(H, W, C)` shape.
:param image: input Gray format image, with `(H, W, C)` shape. Returns:
:return: BGR format image, with `(H, W, C)` shape. BGR format image, with `(H, W, C)` shape.
""" """
return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
@wrap_keepdims @wrap_keepdims
def pad(input, size, value): def pad(input, size, value):
r""" r"""Pad input data with *value* and given *size*.
Pad input data with *value* and given *size*.
Args:
:param input: input data, with `(H, W, C)` shape. input: input data, with `(H, W, C)` shape.
:param size: padding size of input data, it could be an integer or a sequence. size: padding size of input data, it could be an integer or a sequence.
If it is an integer, the input data will be padded in four directions. If it is an integer, the input data will be padded in four directions.
If it is a sequence containing two integers, the bottom and right side If it is a sequence containing two integers, the bottom and right side
of input data will be padded. of input data will be padded.
If it is a sequence containing four integers, the top, bottom, left, right If it is a sequence containing four integers, the top, bottom, left, right
side of input data will be padded with the given size. side of input data will be padded with the given size.
:param value: padding value of data, could be a sequence of int or float. value: padding value of data, could be a sequence of int or float.
If it is a float value, the dtype of the image will be cast to float32 as well. If it is a float value, the dtype of the image will be cast to float32 as well.
:return: padded image.
Returns:
padded image.
""" """
if isinstance(size, int): if isinstance(size, int):
size = (size, size, size, size) size = (size, size, size, size)
...@@ -80,32 +86,33 @@ def pad(input, size, value): ...@@ -80,32 +86,33 @@ def pad(input, size, value):
@wrap_keepdims @wrap_keepdims
def flip(image, flipCode): def flip(image, flipCode):
r""" r"""Accordding to the flipCode (the type of flip), flip the input image.
Accordding to the flipCode (the type of flip), flip the input image.
:param image: input image, with `(H, W, C)` shape.
:param flipCode: code that indicates the type of flip.
* 1 : Flip horizontally Args:
image: input image, with `(H, W, C)` shape.
flipCode: code that indicates the type of flip.
* 0 : Flip vertically * 1 : Flip horizontally
* 0 : Flip vertically
* -1: Flip horizontally and vertically
* -1: Flip horizontally and vertically Returns:
BGR format image, with `(H, W, C)` shape.
:return: BGR format image, with `(H, W, C)` shape.
""" """
return cv2.flip(image, flipCode=flipCode) return cv2.flip(image, flipCode=flipCode)
@wrap_keepdims @wrap_keepdims
def resize(input, size, interpolation=cv2.INTER_LINEAR): def resize(input, size, interpolation=cv2.INTER_LINEAR):
r""" r"""Resize the input data to given size.
Resize the input data to given size.
Args:
input: input data, could be image or masks, with `(H, W, C)` shape.
size: target size of input data, with (height, width) shape.
interpolation: interpolation method.
:param input: input data, could be image or masks, with `(H, W, C)` shape. Returns:
:param size: target size of input data, with (height, width) shape. resized data, with `(H, W, C)` shape.
:param interpolation: interpolation method.
:return: resized data, with `(H, W, C)` shape.
""" """
if len(size) != 2: if len(size) != 2:
raise ValueError("resize needs (h, w), but got {}".format(size)) raise ValueError("resize needs (h, w), but got {}".format(size))
......
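The helpers above are thin wrappers over OpenCV; a self-contained sketch of the equivalent raw calls (using cv2 directly, so it makes no assumption about the wrapper module path):

.. code-block::

    import cv2
    import numpy as np

    image = np.zeros((4, 6, 3), dtype=np.uint8)   # (H, W, C) BGR image

    # pad: the helper accepts an int (all four sides), a 2-sequence (bottom, right)
    # or a 4-sequence (top, bottom, left, right); the raw equivalent is copyMakeBorder
    padded = cv2.copyMakeBorder(image, 1, 1, 2, 2, cv2.BORDER_CONSTANT, value=0)

    # flip: 1 = horizontal, 0 = vertical, -1 = both
    flipped = cv2.flip(image, 1)

    # resize: note cv2.resize takes (width, height), while the helper takes (height, width)
    resized = cv2.resize(image, (12, 8), interpolation=cv2.INTER_LINEAR)

    # to_gray / to_bgr are cvtColor conversions
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)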
...@@ -54,10 +54,10 @@ _device_type_set = {"cpu", "gpu", "xpu", "rocm"} ...@@ -54,10 +54,10 @@ _device_type_set = {"cpu", "gpu", "xpu", "rocm"}
def get_device_count(device_type: str) -> int: def get_device_count(device_type: str) -> int:
""" r"""Gets number of devices installed on this system.
Gets number of devices installed on this system.
:param device_type: device type, one of 'cpu', 'gpu', 'xpu' or 'rocm' Args:
device_type: device type, one of 'cpu', 'gpu', 'xpu' or 'rocm'
""" """
assert device_type in _device_type_set, "device must be one of {}".format( assert device_type in _device_type_set, "device must be one of {}".format(
_device_type_set _device_type_set
...@@ -67,73 +67,59 @@ def get_device_count(device_type: str) -> int: ...@@ -67,73 +67,59 @@ def get_device_count(device_type: str) -> int:
def is_cuda_available() -> bool: def is_cuda_available() -> bool:
""" r"""Returns whether cuda device is available on this system."""
Returns whether cuda device is available on this system.
"""
t = _str2device_type("gpu") t = _str2device_type("gpu")
return CompNode._get_device_count(t, False) > 0 return CompNode._get_device_count(t, False) > 0
def is_cambricon_available() -> bool: def is_cambricon_available() -> bool:
""" r"""Returns whether cambricon device is available on this system."""
Returns whether cambricon device is available on this system.
"""
t = _str2device_type("cambricon") t = _str2device_type("cambricon")
return CompNode._get_device_count(t, False) > 0 return CompNode._get_device_count(t, False) > 0
def is_atlas_available() -> bool: def is_atlas_available() -> bool:
""" r"""Returns whether atlas device is available on this system."""
Returns whether atlas device is available on this system.
"""
t = _str2device_type("atlas") t = _str2device_type("atlas")
return CompNode._get_device_count(t, False) > 0 return CompNode._get_device_count(t, False) > 0
def is_rocm_available() -> bool: def is_rocm_available() -> bool:
"""Returns whether rocm device is available on this system. r"""Returns whether rocm device is available on this system."""
"""
t = _str2device_type("rocm") t = _str2device_type("rocm")
return CompNode._get_device_count(t, False) > 0 return CompNode._get_device_count(t, False) > 0
def set_default_device(device: str = "xpux"): def set_default_device(device: str = "xpux"):
r""" r"""Sets default computing node.
Sets default computing node.
Args:
:param device: default device type. The type can be 'cpu0', 'cpu1', etc., device: default device type.
or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use.
'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. Note:
* The type can be 'cpu0', 'cpu1', etc., or 'gpu0', 'gpu1', etc.,
'multithread' device type is available for inference, which implements to specify the particular CPU or GPU to use.
multi-threading parallelism at the operator level. For example, * 'cpux' and 'gpux' can also be used to specify any number of CPU or GPU devices.
'multithread4' will compute with 4 threads. * The default value is 'xpux' to specify any device available.
* The priority of using GPU is higher when both GPU and CPU are available.
The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available. * 'multithread' device type is available for inference,
which implements multi-threading parallelism at the operator level.
It can also be set by environment variable `MGE_DEFAULT_DEVICE`. For example, 'multithread4' will compute with 4 threads.
* It can also be set by environment variable ``MGE_DEFAULT_DEVICE``.
""" """
assert _valid_device(device), "Invalid device name {}".format(device) assert _valid_device(device), "Invalid device name {}".format(device)
CompNode._set_default_device(device) CompNode._set_default_device(device)
def get_default_device() -> str: def get_default_device() -> str:
r""" r"""Gets default computing node.
Gets default computing node.
It returns the value set by :func:`~.set_default_device`. It returns the value set by :func:`~.set_default_device`.
""" """
return CompNode._get_default_device() return CompNode._get_default_device()
def get_mem_status_bytes(device: Optional[str] = None): def get_mem_status_bytes(device: Optional[str] = None):
r""" r"""Get total and free memory on the computing device in bytes."""
Get total and free memory on the computing device in bytes.
"""
if device is None: if device is None:
device = get_default_device() device = get_default_device()
tot, free = CompNode(device).get_mem_status_bytes tot, free = CompNode(device).get_mem_status_bytes
...@@ -150,15 +136,17 @@ def set_prealloc_config( ...@@ -150,15 +136,17 @@ def set_prealloc_config(
growth_factor=2.0, growth_factor=2.0,
device_type=DeviceType.CUDA, device_type=DeviceType.CUDA,
): ):
""" r"""Specifies how to pre-allocate from raw device allocator.
Specifies how to pre-allocate from raw device allocator.
Args:
:param alignment: specifies the alignment in bytes. alignment: specifies the alignment in bytes.
:param min_req: min request size in bytes. min_req: min request size in bytes.
:param max_overhead: max overhead above required size in bytes. max_overhead: max overhead above required size in bytes.
:param growth_factor: `request size / cur allocated` growth_factor: `request size / cur allocated`
:param device_type: the device type device_type: the device type
alignment: int:
min_req: int:
max_overhead: int:
""" """
assert alignment > 0 assert alignment > 0
assert min_req > 0 assert min_req > 0
......
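A hedged sketch of the device helpers documented above, assuming they are re-exported at the ``megengine`` top level as in recent releases (otherwise import them from ``megengine.device``):

.. code-block::

    import megengine as mge

    mge.set_default_device("gpu0")          # pin new tensors to the first GPU
    print(mge.get_default_device())         # -> "gpu0"
    print(mge.get_device_count("gpu"))      # number of visible GPUs
    print(mge.is_cuda_available())

    # the same default can also be supplied via the environment:
    #   MGE_DEFAULT_DEVICE=cpu0 python train.py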
...@@ -31,17 +31,15 @@ from .server import Client, Server ...@@ -31,17 +31,15 @@ from .server import Client, Server
@mproperty @mproperty
def backend(mod): def backend(mod):
r""" r"""Get or set backend of collective communication.
Get or set backend of collective communication.
Available backends are ['nccl', 'shm', 'rccl'] Available backends are ['nccl', 'shm', 'rccl']
Examples: Examples:
.. code-block:: .. code-block::
import megengine.distributed as dist
dist.backend = "nccl"
import megengine.distributed as dist
dist.backend = "nccl"
""" """
assert group._sd, "please call init_process_group first" assert group._sd, "please call init_process_group first"
return group._sd.backend return group._sd.backend
......
...@@ -50,7 +50,7 @@ def _backend(): ...@@ -50,7 +50,7 @@ def _backend():
def collective_comm(inp, mode, group, device): def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions.""" r"""Helper function for applying collective communication functions."""
assert isinstance(group, Group) assert isinstance(group, Group)
if group is None: if group is None:
return inp return inp
...@@ -158,8 +158,7 @@ class _ReduceSum(Function): ...@@ -158,8 +158,7 @@ class _ReduceSum(Function):
def reduce_sum( def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
r""" r"""Reduce tensor data across the specified group by sum.
Reduce tensor data across the specified group by sum.
Only root process will receive the final result. Only root process will receive the final result.
Args: Args:
...@@ -176,22 +175,20 @@ def reduce_sum( ...@@ -176,22 +175,20 @@ def reduce_sum(
Reduced tensor if in root process, None in other processes. Reduced tensor if in root process, None in other processes.
Examples: Examples:
.. code-block::
.. code-block::
input = Tensor([rank])
input = Tensor([rank]) # Rank 0 # input: Tensor([0])
# Rank 0 # input: Tensor([0]) # Rank 1 # input: Tensor([1])
# Rank 1 # input: Tensor([1]) output = reduce_sum(input)
output = reduce_sum(input) # Rank 0 # output: Tensor([1])
# Rank 0 # output: Tensor([1]) # Rank 1 # output: None
# Rank 1 # output: None
input = Tensor([rank])
input = Tensor([rank]) group = Group([1, 0]) # first rank is root
group = Group([1, 0]) # first rank is root output = reduce_sum(input, group)
output = reduce_sum(input, group) # Rank 0 # output: None
# Rank 0 # output: None # Rank 1 # output: Tensor([1])
# Rank 1 # output: Tensor([1])
""" """
op = _ReduceSum(group, device) op = _ReduceSum(group, device)
(out,) = apply(op, inp) (out,) = apply(op, inp)
...@@ -222,8 +219,7 @@ class _Broadcast(Function): ...@@ -222,8 +219,7 @@ class _Broadcast(Function):
def broadcast( def broadcast(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
r""" r"""Broadcast tensor data from root process to others.
Broadcast tensor data from root process to others.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -240,21 +236,20 @@ def broadcast( ...@@ -240,21 +236,20 @@ def broadcast(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([rank])
# Rank 0 # input: Tensor([0])
# Rank 1 # input: Tensor([1])
output = broadcast(input)
# Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([0])
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # first rank is root # Rank 0 # input: Tensor([0])
output = broadcast(input, group) # Rank 1 # input: Tensor([1])
# Rank 0 # output: Tensor([1]) output = broadcast(input)
# Rank 1 # output: Tensor([1]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([0])
input = Tensor([rank])
group = Group([1, 0]) # first rank is root
output = broadcast(input, group)
# Rank 0 # output: Tensor([1])
# Rank 1 # output: Tensor([1])
""" """
shape, dtype = _bcast_shape_dtype(group, inp) shape, dtype = _bcast_shape_dtype(group, inp)
if group.rank != 0: if group.rank != 0:
...@@ -278,8 +273,7 @@ def _bcast_param( ...@@ -278,8 +273,7 @@ def _bcast_param(
def all_gather( def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""Gather tensors across the specified group and concat them at first dimension.
Gather tensors across the specified group and concat them at first dimension.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -298,21 +292,20 @@ def all_gather( ...@@ -298,21 +292,20 @@ def all_gather(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([rank])
# Rank 0 # input: Tensor([0])
# Rank 1 # input: Tensor([1])
output = all_gather(input)
# Rank 0 # output: Tensor([0 1])
# Rank 1 # output: Tensor([0 1])
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # Rank 0 # input: Tensor([0])
output = all_gather(input, group) # Rank 1 # input: Tensor([1])
# Rank 0 # output: Tensor([1 0]) output = all_gather(input)
# Rank 1 # output: Tensor([1 0]) # Rank 0 # output: Tensor([0 1])
# Rank 1 # output: Tensor([0 1])
input = Tensor([rank])
group = Group([1, 0])
output = all_gather(input, group)
# Rank 0 # output: Tensor([1 0])
# Rank 1 # output: Tensor([1 0])
""" """
mode = CollectiveComm.Mode.ALL_GATHER mode = CollectiveComm.Mode.ALL_GATHER
out = collective_comm(inp, mode, group, device) out = collective_comm(inp, mode, group, device)
...@@ -338,8 +331,7 @@ def all_gather( ...@@ -338,8 +331,7 @@ def all_gather(
def reduce_scatter_sum( def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0 inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor: ) -> Tensor:
r""" r"""Reduce tensors across the specified group by sum and split them at first dimension.
Reduce tensors across the specified group by sum and split them at first dimension.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -358,21 +350,20 @@ def reduce_scatter_sum( ...@@ -358,21 +350,20 @@ def reduce_scatter_sum(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([0 1])
# Rank 0 # input: Tensor([0 1])
# Rank 1 # input: Tensor([0 1])
output = reduce_scatter_sum(input)
# Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([2])
input = Tensor([0 1]) input = Tensor([0 1])
group = Group([1, 0]) # Rank 0 # input: Tensor([0 1])
output = reduce_scatter_sum(input, group) # Rank 1 # input: Tensor([0 1])
# Rank 0 # output: Tensor([2]) output = reduce_scatter_sum(input)
# Rank 1 # output: Tensor([0]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([2])
input = Tensor([0 1])
group = Group([1, 0])
output = reduce_scatter_sum(input, group)
# Rank 0 # output: Tensor([2])
# Rank 1 # output: Tensor([0])
""" """
group_size = group.size if group is not None else 1 group_size = group.size if group is not None else 1
assert ( assert (
...@@ -398,8 +389,7 @@ def reduce_scatter_sum( ...@@ -398,8 +389,7 @@ def reduce_scatter_sum(
def all_reduce_sum( def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
r""" r"""Reduce tensors across the specified group by sum.
Reduce tensors across the specified group by sum.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -416,15 +406,14 @@ def all_reduce_sum( ...@@ -416,15 +406,14 @@ def all_reduce_sum(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_sum(input)
# Rank 0 # output: Tensor(1)
# Rank 1 # output: Tensor(1)
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_sum(input)
# Rank 0 # output: Tensor(1)
# Rank 1 # output: Tensor(1)
""" """
mode = CollectiveComm.Mode.ALL_REDUCE_SUM mode = CollectiveComm.Mode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -433,8 +422,7 @@ def all_reduce_sum( ...@@ -433,8 +422,7 @@ def all_reduce_sum(
def all_reduce_max( def all_reduce_max(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
r""" r"""Reduce tensors across the specified group by max.
Reduce tensors across the specified group by max.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -451,15 +439,14 @@ def all_reduce_max( ...@@ -451,15 +439,14 @@ def all_reduce_max(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_max(input)
# Rank 0 # output: Tensor(1)
# Rank 1 # output: Tensor(1)
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_max(input)
# Rank 0 # output: Tensor(1)
# Rank 1 # output: Tensor(1)
""" """
mode = CollectiveComm.Mode.ALL_REDUCE_MAX mode = CollectiveComm.Mode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -468,8 +455,7 @@ def all_reduce_max( ...@@ -468,8 +455,7 @@ def all_reduce_max(
def all_reduce_min( def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
r""" r"""Reduce tensors across the specified group by min.
Reduce tensors across the specified group by min.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -486,15 +472,14 @@ def all_reduce_min( ...@@ -486,15 +472,14 @@ def all_reduce_min(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_min(input)
# Rank 0 # output: Tensor(0)
# Rank 1 # output: Tensor(0)
input = Tensor(rank)
# Rank 0 # input: Tensor(0)
# Rank 1 # input: Tensor(1)
output = all_reduce_min(input)
# Rank 0 # output: Tensor(0)
# Rank 1 # output: Tensor(0)
""" """
mode = CollectiveComm.Mode.ALL_REDUCE_MIN mode = CollectiveComm.Mode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -520,8 +505,7 @@ class _Gather(Function): ...@@ -520,8 +505,7 @@ class _Gather(Function):
def gather( def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""Gather tensors across the specified group.
Gather tensors across the specified group.
Only root process will receive the final result. Only root process will receive the final result.
Args: Args:
...@@ -534,27 +518,23 @@ def gather( ...@@ -534,27 +518,23 @@ def gather(
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
axis: The concat axis for collective_comm result axis: The concat axis for collective_comm result
The default axis is 0
Returns:
Result tensor if in root process, None if in other process
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([rank])
# Rank 0 # input: Tensor([0])
# Rank 1 # input: Tensor([1])
output = gather(input)
# Rank 0 # output: Tensor([0 1])
# Rank 1 # output: None
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # first rank is root # Rank 0 # input: Tensor([0])
output = gather(input, group) # Rank 1 # input: Tensor([1])
# Rank 0 # output: None output = gather(input)
# Rank 1 # output: Tensor([1 0]) # Rank 0 # output: Tensor([0 1])
# Rank 1 # output: None
input = Tensor([rank])
group = Group([1, 0]) # first rank is root
output = gather(input, group)
# Rank 0 # output: None
# Rank 1 # output: Tensor([1 0])
""" """
assert ( assert (
axis < inp.ndim axis < inp.ndim
...@@ -607,8 +587,7 @@ class _Scatter(Function): ...@@ -607,8 +587,7 @@ class _Scatter(Function):
def scatter( def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""Split tensor in root process at first dimension.
Split tensor in root process at first dimension.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -627,21 +606,20 @@ def scatter( ...@@ -627,21 +606,20 @@ def scatter(
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([0 1]) + rank*2
# Rank 0 # input: Tensor([0 1])
# Rank 1 # input: Tensor([2 3])
output = scatter(input)
# Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([1])
input = Tensor([0 1]) + rank*2 input = Tensor([0 1]) + rank*2
group = Group([1, 0]) # first rank is root # Rank 0 # input: Tensor([0 1])
output = scatter(input, group) # Rank 1 # input: Tensor([2 3])
# Rank 0 # output: Tensor([3]) output = scatter(input)
# Rank 1 # output: Tensor([2]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([1])
input = Tensor([0 1]) + rank*2
group = Group([1, 0]) # first rank is root
output = scatter(input, group)
# Rank 0 # output: Tensor([3])
# Rank 1 # output: Tensor([2])
""" """
shape, dtype = _bcast_shape_dtype(group, inp) shape, dtype = _bcast_shape_dtype(group, inp)
if group.rank != 0: if group.rank != 0:
...@@ -680,8 +658,7 @@ def all_to_all( ...@@ -680,8 +658,7 @@ def all_to_all(
split_axis: int = 0, split_axis: int = 0,
concat_axis: int = 0, concat_axis: int = 0,
) -> Tensor: ) -> Tensor:
r""" r"""Each process scatter input tensor to all processes and return gathered tensor.
Each process scatter input tensor to all processes and return gathered tensor.
Args: Args:
inp: Input tensor. inp: Input tensor.
...@@ -694,29 +671,26 @@ def all_to_all( ...@@ -694,29 +671,26 @@ def all_to_all(
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
split_axis: The axis that collectivecomm will split data split_axis: The axis that collectivecomm will split data
the default axis is 0 the default axis is 0
concat_axis: The axis that collectivecomm will concat data
the default axis is 0
Returns: Returns:
Result tensor. Result tensor.
Examples: Examples:
.. code-block:: .. code-block::
input = Tensor([0 1]) + rank*2
# Rank 0 # input: Tensor([0 1])
# Rank 1 # input: Tensor([2 3])
output = all_to_all(input)
# Rank 0 # output: Tensor([0 2])
# Rank 1 # output: Tensor([1 3])
input = Tensor([0 1]) + rank*2 input = Tensor([0 1]) + rank*2
group = Group([1, 0]) # Rank 0 # input: Tensor([0 1])
output = all_to_all(input, group) # Rank 1 # input: Tensor([2 3])
# Rank 0 # output: Tensor([0 3]) output = all_to_all(input)
# Rank 1 # output: Tensor([2 1]) # Rank 0 # output: Tensor([0 2])
# Rank 1 # output: Tensor([1 3])
input = Tensor([0 1]) + rank*2
group = Group([1, 0])
output = all_to_all(input, group)
# Rank 0 # output: Tensor([0 3])
# Rank 1 # output: Tensor([2 1])
""" """
group_size = group.size if group is not None else 1 group_size = group.size if group is not None else 1
assert ( assert (
...@@ -805,8 +779,7 @@ class _RemoteRecv(Function): ...@@ -805,8 +779,7 @@ class _RemoteRecv(Function):
def remote_send(inp: Tensor, dest_rank: int): def remote_send(inp: Tensor, dest_rank: int):
r""" r"""Send tensor to another process.
Send tensor to another process.
Args: Args:
inp: Tensor to send. inp: Tensor to send.
...@@ -816,17 +789,15 @@ def remote_send(inp: Tensor, dest_rank: int): ...@@ -816,17 +789,15 @@ def remote_send(inp: Tensor, dest_rank: int):
None. None.
Examples: Examples:
.. code-block::
.. code-block::
if rank == 0:
if rank == 0: data = mge.tensor(1)
data = mge.tensor(1) # Tensor(1)
# Tensor(1) F.distributed.remote_send(data, 1) # return None
F.distributed.remote_send(data, 1) # return None else:
else: data = F.distributed.remote_recv(0)
data = F.distributed.remote_recv(0) # Tensor(1)
# Tensor(1)
""" """
group = _SendRecvGroup(get_rank(), dest_rank) group = _SendRecvGroup(get_rank(), dest_rank)
_bcast_shape_dtype(group, inp) _bcast_shape_dtype(group, inp)
...@@ -844,8 +815,7 @@ def remote_send(inp: Tensor, dest_rank: int): ...@@ -844,8 +815,7 @@ def remote_send(inp: Tensor, dest_rank: int):
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor: def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
r""" r"""Receive a tensor from another process.
Receive a tensor from another process.
Args: Args:
src_rank: Rank of source process. src_rank: Rank of source process.
...@@ -862,14 +832,13 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor ...@@ -862,14 +832,13 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
.. code-block:: .. code-block::
if rank == 0: if rank == 0:
data = mge.tensor(1) data = mge.tensor(1)
# Tensor(1) # Tensor(1)
F.distributed.remote_send(data, 1) # return None F.distributed.remote_send(data, 1) # return None
else: else:
data = F.distributed.remote_recv(0) data = F.distributed.remote_recv(0)
# Tensor(1) # Tensor(1)
""" """
group = _SendRecvGroup(src_rank, get_rank()) group = _SendRecvGroup(src_rank, get_rank())
shape, dtype = _bcast_shape_dtype(group, None) shape, dtype = _bcast_shape_dtype(group, None)
......
...@@ -36,15 +36,13 @@ _sd = None ...@@ -36,15 +36,13 @@ _sd = None
class Group: class Group:
r""" r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).
Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).
By default collectives operate on the default group (also called ``WORLD``) By default collectives operate on the default group (also called ``WORLD``)
and require all processes to enter the distributed function call. and require all processes to enter the distributed function call.
:param proc_ranks: rank list of the group, the first one is root rank. Args:
proc_ranks: rank list of the group, the first one is root rank.
""" """
def __init__(self, proc_ranks): def __init__(self, proc_ranks):
...@@ -116,15 +114,15 @@ def init_process_group( ...@@ -116,15 +114,15 @@ def init_process_group(
backend: Optional[str] = "auto", backend: Optional[str] = "auto",
device_type: str = "xpu", device_type: str = "xpu",
) -> None: ) -> None:
""" r"""Initialize the distributed process group and specify the device used in the current process
Initialize the distributed process group and specify the device used in the current process
Args:
:param master_ip: ip address of the master node. master_ip: ip address of the master node.
:param port: port available for all processes to communicate. port: port available for all processes to communicate.
:param world_size: total number of processes participating in the job. world_size: total number of processes participating in the job.
:param rank: rank of the current process. rank: rank of the current process.
:param device: the GPU device id to bind this process to. device: the GPU device id to bind this process to.
:param backend: communicator backend, currently support 'nccl' and 'shm'. backend: communicator backend, currently support 'nccl' and 'shm'.
""" """
physical_device_type = what_is_xpu() if device_type == "xpu" else device_type physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
if not isinstance(master_ip, str): if not isinstance(master_ip, str):
...@@ -180,10 +178,10 @@ def _set_machine_ranks(ranks) -> None: ...@@ -180,10 +178,10 @@ def _set_machine_ranks(ranks) -> None:
@contextmanager @contextmanager
def override_backend(new_backend: str): def override_backend(new_backend: str):
""" r"""Override distributed backend
Override distributed backend
:param new_backend: communicator backend set in this context. Args:
new_backend: communicator backend set in this context.
""" """
global _sd global _sd
assert _sd, "please call init_process_group first" assert _sd, "please call init_process_group first"
...@@ -196,51 +194,51 @@ def override_backend(new_backend: str): ...@@ -196,51 +194,51 @@ def override_backend(new_backend: str):
def is_distributed() -> bool: def is_distributed() -> bool:
"""Return True if the distributed process group has been initialized.""" r"""Return True if the distributed process group has been initialized."""
return _sd is not None return _sd is not None
def get_rank() -> int: def get_rank() -> int:
"""Get the rank of the current process.""" r"""Get the rank of the current process."""
return _sd.proc_rank if _sd is not None else 0 return _sd.proc_rank if _sd is not None else 0
def get_world_size() -> int: def get_world_size() -> int:
"""Get the total number of processes participating in the job.""" r"""Get the total number of processes participating in the job."""
return _sd.world_size if _sd is not None else 1 return _sd.world_size if _sd is not None else 1
def get_backend() -> str: def get_backend() -> str:
"""Get the backend str.""" r"""Get the backend str."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.backend if _sd is not None else None return _sd.backend if _sd is not None else None
def get_py_server_addr() -> Tuple[str, int]: def get_py_server_addr() -> Tuple[str, int]:
"""Get master_ip and port of python XML RPC server.""" r"""Get master_ip and port of python XML RPC server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.master_ip, _sd.py_server_port return _sd.master_ip, _sd.py_server_port
def get_mm_server_addr() -> Tuple[str, int]: def get_mm_server_addr() -> Tuple[str, int]:
"""Get master_ip and port of C++ mm_server.""" r"""Get master_ip and port of C++ mm_server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.master_ip, _sd.mm_server_port return _sd.master_ip, _sd.mm_server_port
def get_client() -> Client: def get_client() -> Client:
"""Get client of python XML RPC server.""" r"""Get client of python XML RPC server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.client return _sd.client
def new_group(proc_ranks: List[int]) -> Group: def new_group(proc_ranks: List[int]) -> Group:
"""Build a subgroup containing certain ranks.""" r"""Build a subgroup containing certain ranks."""
return Group(proc_ranks) return Group(proc_ranks)
def group_barrier(group: Group = WORLD) -> None: def group_barrier(group: Group = WORLD) -> None:
"""Block until all ranks in the group reach this barrier.""" r"""Block until all ranks in the group reach this barrier."""
# if running with single node, skip it # if running with single node, skip it
if _sd is None: if _sd is None:
return return
......
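A hedged end-to-end sketch of ``init_process_group`` and the query helpers above, spawning two local workers; the port number and the spawning mechanics are illustrative only:

.. code-block::

    import multiprocessing as mp
    import megengine.distributed as dist

    def worker(rank, world_size):
        # every process binds its own device and joins the default WORLD group
        dist.init_process_group(
            master_ip="localhost",
            port=23456,
            world_size=world_size,
            rank=rank,
            device=rank,
        )
        print(dist.get_rank(), dist.get_world_size(), dist.get_backend())
        dist.group_barrier()   # block until every rank reaches this point

    if __name__ == "__main__":
        world_size = 2
        procs = [mp.Process(target=worker, args=(i, world_size)) for i in range(world_size)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()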
...@@ -28,39 +28,40 @@ from .group import WORLD, Group, group_barrier, is_distributed, override_backend ...@@ -28,39 +28,40 @@ from .group import WORLD, Group, group_barrier, is_distributed, override_backend
def param_pack_split(inp: Tensor, offsets: list, shapes: list): def param_pack_split(inp: Tensor, offsets: list, shapes: list):
r""" r"""Returns split tensor to tensor list as offsets and shapes described,
Returns split tensor to tensor list as offsets and shapes described, only used for ``parampack``.
only used for ``parampack``.
:param inp: input tensor. Args:
:param offsets: offsets of outputs, length of `2 * n`, inp: input tensor.
offsets: offsets of outputs, length of `2 * n`,
where n is the number of tensors you want to split, where n is the number of tensors you want to split,
format `[begin0, end0, begin1, end1]`. format `[begin0, end0, begin1, end1]`.
:param shapes: tensor shapes of outputs. shapes: tensor shapes of outputs.
:return: split tensors.
Examples: Returns:
split tensors.
.. testcode:: Examples:
import numpy as np .. testcode::
from megengine import tensor
from megengine.distributed.helper import param_pack_split
a = tensor(np.ones((10,), np.int32)) import numpy as np
b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) from megengine import tensor
print(b.numpy()) from megengine.distributed.helper import param_pack_split
print(c.numpy())
Outputs: a = tensor(np.ones((10,), np.int32))
b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
.. testoutput:: Outputs:
[1] .. testoutput::
[[1 1 1]
[1 1 1]
[1 1 1]]
[1]
[[1 1 1]
[1 1 1]
[1 1 1]]
""" """
op = ParamPackSplit() op = ParamPackSplit()
op.offsets = offsets op.offsets = offsets
...@@ -73,36 +74,37 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): ...@@ -73,36 +74,37 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list): def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
r""" r"""Returns concated tensor, only used for ``parampack``.
Returns concated tensor, only used for ``parampack``.
:param inps: input tensors. Args:
:param offsets: device value of offsets. inps: input tensors.
:param offsets_val: offsets of inputs, length of `2 * n`, offsets: device value of offsets.
offsets_val: offsets of inputs, length of `2 * n`,
format `[begin0, end0, begin1, end1]`. format `[begin0, end0, begin1, end1]`.
:return: concatenated tensor.
Examples: Returns:
concatenated tensor.
.. testcode:: Examples:
import numpy as np .. testcode::
from megengine import tensor
from megengine.distributed.helper import param_pack_concat
a = tensor(np.ones((1,), np.int32)) import numpy as np
b = tensor(np.ones((3, 3), np.int32)) from megengine import tensor
offsets_val = [0, 1, 1, 10] from megengine.distributed.helper import param_pack_concat
offsets = tensor(offsets_val, np.int32)
c = param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs: a = tensor(np.ones((1,), np.int32))
b = tensor(np.ones((3, 3), np.int32))
offsets_val = [0, 1, 1, 10]
offsets = tensor(offsets_val, np.int32)
c = param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
.. testoutput:: Outputs:
[1 1 1 1 1 1 1 1 1 1] .. testoutput::
[1 1 1 1 1 1 1 1 1 1]
""" """
op = ParamPackConcat() op = ParamPackConcat()
op.offsets = offsets_val op.offsets = offsets_val
...@@ -165,9 +167,9 @@ class TensorFuture(Future): ...@@ -165,9 +167,9 @@ class TensorFuture(Future):
def synchronized(func: Callable): def synchronized(func: Callable):
r"""Decorator. Decorated function will synchronize when finished.
Specifically, we use this to prevent data race during hub.load
""" """
Decorator. Decorated function will synchronize when finished.
Specifically, we use this to prevent data race during hub.load"""
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
...@@ -199,23 +201,23 @@ get_device_count_by_fork = deprecated_func( ...@@ -199,23 +201,23 @@ get_device_count_by_fork = deprecated_func(
def bcast_list_(inps: list, group: Group = WORLD): def bcast_list_(inps: list, group: Group = WORLD):
""" r"""Broadcast tensors between given group.
Broadcast tensors between given group.
:param inps: input tensors. Args:
:param group: communication group. inps: input tensors.
group: communication group.
""" """
for inp in inps: for inp in inps:
inp._reset(_bcast_param(inp, group)) inp._reset(_bcast_param(inp, group))
class AllreduceCallback: class AllreduceCallback:
""" r"""Allreduce Callback with tensor fusion optimization.
Allreduce Callback with tensor fusion optimization.
:param reduce_method: the method to reduce gradients. Args:
:param group: communication group. reduce_method: the method to reduce gradients.
:param backend: override distributed backend in allreduce group: communication group.
backend: override distributed backend in allreduce
""" """
def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None): def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
......
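A hedged sketch of how ``bcast_list_`` and ``AllreduceCallback`` are typically combined for data-parallel training; it assumes an already-initialized process group (for example inside a launcher-spawned worker) and that :class:`~.GradManager` accepts per-tensor callbacks, as in MegEngine's data-parallel examples:

.. code-block::

    import megengine.module as M
    from megengine.autodiff import GradManager
    from megengine.distributed.helper import AllreduceCallback, bcast_list_

    model = M.Linear(4, 2)

    # start every rank from identical parameters
    bcast_list_(list(model.parameters()))

    # average gradients across ranks after each backward()
    gm = GradManager()
    gm.attach(model.parameters(), callbacks=[AllreduceCallback("mean")])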
...@@ -39,7 +39,7 @@ def _run_wrapped( ...@@ -39,7 +39,7 @@ def _run_wrapped(
queue: mp.Queue, queue: mp.Queue,
machine_ranks: list, machine_ranks: list,
): ):
"""Init distributed process group and run wrapped function.""" r"""Init distributed process group and run wrapped function."""
_check_device_initialized(device_type, dev) _check_device_initialized(device_type, dev)
init_process_group( init_process_group(
master_ip=master_ip, master_ip=master_ip,
...@@ -64,15 +64,16 @@ def _run_wrapped( ...@@ -64,15 +64,16 @@ def _run_wrapped(
class launcher: class launcher:
"""Decorator for launching multiple processes in single-machine multi-gpu training. r"""Decorator for launching multiple processes in single-machine multi-gpu training.
:param func: the function you want to launch in distributed mode. Args:
:param n_gpus: how many devices on each node. func: the function you want to launch in distributed mode.
:param world_size: how many devices in total. n_gpus: how many devices on each node.
:param rank_start: start number for rank. world_size: how many devices in total.
:param master_ip: ip address of the master node (where rank 0 is). rank_start: start number for rank.
:param port: server port for distributed server. master_ip: ip address of the master node (where rank 0 is).
:param backend: set default collective communication backend. port: server port for distributed server.
backend: set default collective communication backend.
""" """
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
......
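A hedged sketch of the decorator form, assuming ``launcher`` is re-exported as ``megengine.distributed.launcher``; the GPU count is illustrative:

.. code-block::

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def main():
        # this body runs once per spawned process, each bound to its own GPU
        print("hello from rank", dist.get_rank(), "of", dist.get_world_size())

    if __name__ == "__main__":
        main()   # spawns 2 worker processes on the local machine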
...@@ -20,11 +20,11 @@ from ..utils.future import Future ...@@ -20,11 +20,11 @@ from ..utils.future import Future
class Methods: class Methods:
""" r"""Distributed Server Method.
Distributed Server Method.
Used for exchange information between distributed nodes. Used for exchange information between distributed nodes.
:param mm_server_port: multiple machine rpc server port. Args:
mm_server_port: multiple machine rpc server port.
""" """
def __init__(self, mm_server_port): def __init__(self, mm_server_port):
...@@ -39,19 +39,19 @@ class Methods: ...@@ -39,19 +39,19 @@ class Methods:
self.bcast_dict = {} self.bcast_dict = {}
def connect(self): def connect(self):
"""Method for checking connection success.""" r"""Method for checking connection success."""
return True return True
def get_mm_server_port(self): def get_mm_server_port(self):
"""Get multiple machine rpc server port.""" r"""Get multiple machine rpc server port."""
return self.mm_server_port return self.mm_server_port
def set_is_grad(self, key, is_grad): def set_is_grad(self, key, is_grad):
""" r"""Mark send/recv need gradiants by key.
Mark send/recv need gradiants by key.
:param key: key to match send/recv op. Args:
:param is_grad: whether this op need grad. key: key to match send/recv op.
is_grad: whether this op needs grad.
""" """
with self.lock: with self.lock:
future = self.dict_is_grad[key] future = self.dict_is_grad[key]
...@@ -59,10 +59,10 @@ class Methods: ...@@ -59,10 +59,10 @@ class Methods:
return True return True
def check_is_grad(self, key): def check_is_grad(self, key):
""" r"""Check whether send/recv need gradiants.
Check whether send/recv need gradiants.
:param key: key to match send/recv op. Args:
key: key to match send/recv op.
""" """
with self.lock: with self.lock:
future = self.dict_is_grad[key] future = self.dict_is_grad[key]
...@@ -72,11 +72,11 @@ class Methods: ...@@ -72,11 +72,11 @@ class Methods:
return ret return ret
def set_remote_tracer(self, key, tracer_set): def set_remote_tracer(self, key, tracer_set):
""" r"""Set tracer dict for tracing send/recv op.
Set tracer dict for tracing send/recv op.
:param key: key to match send/recv op. Args:
:param tracer_set: valid tracer set. key: key to match send/recv op.
tracer_set: valid tracer set.
""" """
with self.lock: with self.lock:
future = self.dict_remote_tracer[key] future = self.dict_remote_tracer[key]
...@@ -84,10 +84,10 @@ class Methods: ...@@ -84,10 +84,10 @@ class Methods:
return True return True
def check_remote_tracer(self, key): def check_remote_tracer(self, key):
""" r"""Get tracer dict for send/recv op.
Get tracer dict for send/recv op.
:param key: key to match send/recv op. Args:
key: key to match send/recv op.
""" """
with self.lock: with self.lock:
future = self.dict_remote_tracer[key] future = self.dict_remote_tracer[key]
...@@ -97,11 +97,11 @@ class Methods: ...@@ -97,11 +97,11 @@ class Methods:
return ret return ret
def group_barrier(self, key, size): def group_barrier(self, key, size):
""" r"""A barrier wait for all group member.
A barrier wait for all group member.
:param key: group key to match each other. Args:
:param size: group size. key: group key to match each other.
size: group size.
""" """
with self.lock: with self.lock:
self.dict_barrier_counter[key] += 1 self.dict_barrier_counter[key] += 1
...@@ -116,14 +116,14 @@ class Methods: ...@@ -116,14 +116,14 @@ class Methods:
return True return True
def user_set(self, key, val): def user_set(self, key, val):
"""Set user defined key-value pairs across processes.""" r"""Set user defined key-value pairs across processes."""
with self.lock: with self.lock:
future = self.user_dict[key] future = self.user_dict[key]
future.set(val) future.set(val)
return True return True
def user_get(self, key): def user_get(self, key):
"""Get user defined key-value pairs across processes.""" r"""Get user defined key-value pairs across processes."""
with self.lock: with self.lock:
future = self.user_dict[key] future = self.user_dict[key]
return future.get() return future.get()
...@@ -161,12 +161,12 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): ...@@ -161,12 +161,12 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
def _start_server(py_server_port, queue): def _start_server(py_server_port, queue):
""" r"""Start python distributed server and multiple machine server.
Start python distributed server and multiple machine server.
:param py_server_port: python server port. Args:
:param mm_server_port: multiple machine server port. py_server_port: python server port.
:param queue: server port will put in this queue, puts exception when process fails. mm_server_port: multiple machine server port.
queue: the server port will be put into this queue; an exception is put if the process fails.
""" """
try: try:
mm_server_port = create_mm_server("0.0.0.0", 0) mm_server_port = create_mm_server("0.0.0.0", 0)
...@@ -182,11 +182,11 @@ def _start_server(py_server_port, queue): ...@@ -182,11 +182,11 @@ def _start_server(py_server_port, queue):
class Server: class Server:
""" r"""Distributed Server for distributed training.
Distributed Server for distributed training.
Should be running at master node. Should be running at master node.
:param port: python server port. Args:
port: python server port.
""" """
def __init__(self, port=0): def __init__(self, port=0):
...@@ -204,11 +204,11 @@ class Server: ...@@ -204,11 +204,11 @@ class Server:
class Client: class Client:
""" r"""Distributed Client for distributed training.
Distributed Client for distributed training.
:param master_ip: ip address of master node. Args:
:param port: port of server at master node. master_ip: ip address of master node.
port: port of server at master node.
""" """
def __init__(self, master_ip, port): def __init__(self, master_ip, port):
...@@ -218,7 +218,7 @@ class Client: ...@@ -218,7 +218,7 @@ class Client:
self.bcast_dict = defaultdict(lambda: 0) self.bcast_dict = defaultdict(lambda: 0)
def connect(self): def connect(self):
"""Check connection success.""" r"""Check connection success."""
while True: while True:
try: try:
self.proxy = ServerProxy( self.proxy = ServerProxy(
...@@ -230,62 +230,62 @@ class Client: ...@@ -230,62 +230,62 @@ class Client:
time.sleep(1) time.sleep(1)
def get_mm_server_port(self): def get_mm_server_port(self):
"""Get multiple machine server port.""" r"""Get multiple machine server port."""
return self.proxy.get_mm_server_port() return self.proxy.get_mm_server_port()
def set_is_grad(self, key, is_grad): def set_is_grad(self, key, is_grad):
""" r"""Mark send/recv need gradiants by key.
Mark send/recv need gradiants by key.
:param key: key to match send/recv op. Args:
:param is_grad: whether this op need grad. key: key to match send/recv op.
is_grad: whether this op needs grad.
""" """
self.proxy.set_is_grad(key, is_grad) self.proxy.set_is_grad(key, is_grad)
def check_is_grad(self, key): def check_is_grad(self, key):
""" r"""Check whether send/recv need gradiants.
Check whether send/recv need gradiants.
:param key: key to match send/recv op. Args:
key: key to match send/recv op.
""" """
return self.proxy.check_is_grad(key) return self.proxy.check_is_grad(key)
def set_remote_tracer(self, key, tracer_set): def set_remote_tracer(self, key, tracer_set):
""" r"""Set tracer dict for tracing send/recv op.
Set tracer dict for tracing send/recv op.
:param key: key to match send/recv op. Args:
:param tracer_set: valid tracer set. key: key to match send/recv op.
tracer_set: valid tracer set.
""" """
self.proxy.set_remote_tracer(key, tracer_set) self.proxy.set_remote_tracer(key, tracer_set)
def check_remote_tracer(self, key): def check_remote_tracer(self, key):
""" r"""Get tracer dict for send/recv op.
Get tracer dict for send/recv op.
:param key: key to match send/recv op. Args:
key: key to match send/recv op.
""" """
return self.proxy.check_remote_tracer(key) return self.proxy.check_remote_tracer(key)
def group_barrier(self, key, size): def group_barrier(self, key, size):
""" r"""A barrier wait for all group member.
A barrier wait for all group member.
:param key: group key to match each other. Args:
:param size: group size. key: group key to match each other.
size: group size.
""" """
self.proxy.group_barrier(key, size) self.proxy.group_barrier(key, size)
def user_set(self, key, val): def user_set(self, key, val):
"""Set user defined key-value pairs across processes.""" r"""Set user defined key-value pairs across processes."""
return self.proxy.user_set(key, val) return self.proxy.user_set(key, val)
def user_get(self, key): def user_get(self, key):
"""Get user defined key-value pairs across processes.""" r"""Get user defined key-value pairs across processes."""
return self.proxy.user_get(key) return self.proxy.user_get(key)
def user_pop(self, key): def user_pop(self, key):
"""Get user defined key-value pairs and delete the resources when the get is done""" r"""Get user defined key-value pairs and delete the resources when the get is done"""
return self.proxy.user_pop(key) return self.proxy.user_pop(key)
def bcast_val(self, val, key, size): def bcast_val(self, val, key, size):
......
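A hedged end-to-end sketch (not part of the diff) of the ``Server``/``Client`` pair above: the master starts a ``Server``, workers connect with ``Client`` objects and exchange user-defined key-value pairs. The host address, port number and key names are placeholders.

.. code-block::

    from megengine.distributed.server import Client, Server

    server = Server(port=23456)          # runs on the master node (port is illustrative)
    client = Client("127.0.0.1", 23456)  # each worker connects to the master
    client.user_set("stage", "warmup")   # blocking key-value exchange between processes
    print(client.user_get("stage"))      # -> "warmup"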
...@@ -30,24 +30,20 @@ def _str2bytes(text: str) -> int: ...@@ -30,24 +30,20 @@ def _str2bytes(text: str) -> int:
@property @property
def eviction_threshold(mod): def eviction_threshold(mod):
r""" r"""Get or set the eviction threshold in bytes. It can also be set to a string,
Get or set the eviction threshold in bytes. It can also be set to a string,
whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and
gigabyte(GB) units. gigabyte(GB) units.
.. note:: Note:
When GPU memory usage exceeds this value, DTR will heuristically select When GPU memory usage exceeds this value, DTR will heuristically select
and evict resident tensors until the amount of used memory falls below and evict resident tensors until the amount of used memory falls below
this threshold. this threshold.
Examples: Examples:
.. code-block::
.. code-block:: import megengine as mge
mge.dtr.eviction_threshold = "2GB"
import megengine as mge
mge.dtr.eviction_threshold = "2GB"
""" """
return _eviction_threshold return _eviction_threshold
...@@ -66,24 +62,21 @@ def eviction_threshold(mod, value: Union[int, str]): ...@@ -66,24 +62,21 @@ def eviction_threshold(mod, value: Union[int, str]):
@property @property
def evictee_minimum_size(mod): def evictee_minimum_size(mod):
r""" r"""Get or set the memory threshold of tensors in bytes. It can also be set to a
Get or set the memory threshold of tensors in bytes. It can also be set to a
string, whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and string, whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and
gigabyte(GB) units. gigabyte(GB) units.
.. note:: Note:
Only tensors whose size exceeds this threshold will be added to the Only tensors whose size exceeds this threshold will be added to the
candidate set. A tensor that is not added to the candidate set will candidate set. A tensor that is not added to the candidate set will
never be evicted during its lifetime. never be evicted during its lifetime.
Examples: Examples:
.. code-block::
.. code-block:: import megengine as mge
mge.dtr.evictee_minimum_size = "2MB"
import megengine as mge
mge.dtr.evictee_minimum_size = "2MB"
""" """
return _evictee_minimum_size return _evictee_minimum_size
...@@ -102,19 +95,16 @@ def evictee_minimum_size(mod, value: Union[int, str]): ...@@ -102,19 +95,16 @@ def evictee_minimum_size(mod, value: Union[int, str]):
@property @property
def enable_sqrt_sampling(mod): def enable_sqrt_sampling(mod):
r""" r"""Get or set whether sqrt sampling is allowed. Sqrt sampling means that given
Get or set whether sqrt sampling is allowed. Sqrt sampling means that given
the size of the candidate set is N, only enumerate sqrt(N) tensors. When the size of the candidate set is N, only enumerate sqrt(N) tensors. When
the number of tensors is very high, enabling this optimization will speed the number of tensors is very high, enabling this optimization will speed
up the training. up the training.
Examples:
.. code-block::
Examples: import megengine as mge
mge.dtr.enable_sqrt_sampling = True
.. code-block::
import megengine as mge
mge.dtr.enable_sqrt_sampling = True
""" """
return _enable_sqrt_sampling return _enable_sqrt_sampling
...@@ -127,9 +117,7 @@ def enable_sqrt_sampling(mod, value: bool): ...@@ -127,9 +117,7 @@ def enable_sqrt_sampling(mod, value: bool):
def enable(): def enable():
r""" r"""Enable to record computing path of tensors and to perform DTR policy."""
Enable to record computing path of tensors and to perform DTR policy.
"""
_set_defrag(True) _set_defrag(True)
_set_option("enable_dtr_auto_drop", 1) _set_option("enable_dtr_auto_drop", 1)
_set_option("enable_drop", 1) _set_option("enable_drop", 1)
...@@ -138,9 +126,7 @@ def enable(): ...@@ -138,9 +126,7 @@ def enable():
def disable(): def disable():
r""" r"""Stop recording computing path of tensors and performing DTR policy."""
Stop recording computing path of tensors and performing DTR policy.
"""
_set_defrag(False) _set_defrag(False)
_set_option("enable_dtr_auto_drop", 0) _set_option("enable_dtr_auto_drop", 0)
_set_option("enable_drop", 0) _set_option("enable_drop", 0)
......
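Putting the three options and ``enable``/``disable`` together, a typical DTR setup looks like the following sketch (illustrative only; the threshold values are placeholders):

.. code-block::

    import megengine as mge

    mge.dtr.eviction_threshold = "5GB"     # start evicting above this GPU memory usage
    mge.dtr.evictee_minimum_size = "1MB"   # ignore very small tensors
    mge.dtr.enable()                       # record compute paths and apply the DTR policy
    # ... run training ...
    mge.dtr.disable()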
...@@ -23,8 +23,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None: ...@@ -23,8 +23,7 @@ if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
def get_execution_strategy() -> Strategy: def get_execution_strategy() -> Strategy:
""" r"""Returns the execution strategy of :class:`~module..Conv2d` and :func:`~.matmul`
Returns the execution strategy of :class:`~module..Conv2d` and :func:`~.matmul`
See :func:`~.set_execution_strategy` for possible return values See :func:`~.set_execution_strategy` for possible return values
""" """
...@@ -32,31 +31,32 @@ def get_execution_strategy() -> Strategy: ...@@ -32,31 +31,32 @@ def get_execution_strategy() -> Strategy:
def set_execution_strategy(option): def set_execution_strategy(option):
""" r"""Sets the execution strategy of :class:`~module.Conv2d` and :func:`~.matmul`
Sets the execution strategy of :class:`~module.Conv2d` and :func:`~.matmul`
Args:
option: Decides how :class:`~.module.Conv2d` and :func:`~.matmul` algorithms are chosen.
Available values of ``Strategy``:
:param option: Decides how :class:`~module.Conv2d`and :func:`~.matmul` algorithms are chosen. * HEURISTIC uses heuristic to choose the fastest algorithm.
Available value Strategy * PROFILE runs possible algorithms on real device to find the best one.
* HEURISTIC uses heuristic to choose the fastest algorithm. * REPRODUCIBLE uses the algorithms that is reproducible.
* PROFILE runs possible algorithms on real device to find the best one. * REPRODUCIBLE uses algorithms that are reproducible.
* REPRODUCIBLE uses algorithms that are reproducible. * OPTIMIZED uses algorithms that are optimized.
* OPTIMIZED uses algorithms that are optimized.
The default strategy is HEURISTIC, and these options can be combined to The default strategy is HEURISTIC, and these options can be combined to
form a combined option, e.g. PROFILE | REPRODUCIBLE form a combined option, e.g. PROFILE | REPRODUCIBLE
yields an option that uses the fastest profiled algorithm that is also reproducible. yields an option that uses the fastest profiled algorithm that is also reproducible.
Available values string: Available values string:
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
* 'PROFILE' runs possible algorithms on real device to find the best one. * 'PROFILE' runs possible algorithms on real device to find the best one.
* 'PROFILE_HEURISTIC' uses profiling result and heuristic to choose the fastest algorithm. * 'PROFILE_HEURISTIC' uses profiling result and heuristic to choose the fastest algorithm.
* 'PROFILE_REPRODUCIBLE' uses the fastest profiling result that is also reproducible. * 'PROFILE_REPRODUCIBLE' uses the fastest profiling result that is also reproducible.
* 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible.
The default strategy is 'HEURISTIC'. The default strategy is 'HEURISTIC'.
It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'. It can also be set through the environment variable 'MEGENGINE_EXECUTION_STRATEGY'.
""" """
valid_string_option = { valid_string_option = {
"REPRODUCIBLE": Strategy.REPRODUCIBLE, "REPRODUCIBLE": Strategy.REPRODUCIBLE,
......
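For illustration (not part of the diff), the string form documented above can be set and read back as follows; the ``megengine.functional.debug_param`` module path is an assumption about where these helpers live:

.. code-block::

    from megengine.functional.debug_param import (
        get_execution_strategy,
        set_execution_strategy,
    )

    # pick profiled algorithms but keep results reproducible
    set_execution_strategy("PROFILE_REPRODUCIBLE")
    print(get_execution_strategy())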
...@@ -78,182 +78,163 @@ def _elemwise_multi_type(*args, mode, **kwargs): ...@@ -78,182 +78,163 @@ def _elemwise_multi_type(*args, mode, **kwargs):
def add(x, y): def add(x, y):
""" r"""Element-wise `addition`.
Element-wise `addition`.
At least one operand should be tensor.
Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minmium.
:param x: input tensor.
:return: computed tensor.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np import numpy as np
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.add(x, y) out = F.add(x, y)
print(out.numpy()) print(out.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[[ 0. 2. 4.]
[ 6. 8. 10.]]
[[ 0. 2. 4.]
[ 6. 8. 10.]]
""" """
return _elwise(x, y, mode=Elemwise.Mode.ADD) return _elwise(x, y, mode=Elemwise.Mode.ADD)
def sub(x, y): def sub(x, y):
"""Element-wise `subtraction`.""" r"""Element-wise `subtraction`."""
return _elwise(x, y, mode=Elemwise.Mode.SUB) return _elwise(x, y, mode=Elemwise.Mode.SUB)
def mul(x, y): def mul(x, y):
"""Element-wise `multiplication`.""" r"""Element-wise `multiplication`."""
return _elwise(x, y, mode=Elemwise.Mode.MUL) return _elwise(x, y, mode=Elemwise.Mode.MUL)
def div(x, y): def div(x, y):
"""Element-wise `(x / y)`.""" r"""Element-wise `(x / y)`."""
return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV) return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)
def floor_div(x, y): def floor_div(x, y):
"""Element-wise `floor(x / y)`.""" r"""Element-wise `floor(x / y)`."""
return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV) return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)
def neg(x): def neg(x):
"""Element-wise `negation`.""" r"""Element-wise `negation`."""
return _elwise(x, mode=Elemwise.Mode.NEGATE) return _elwise(x, mode=Elemwise.Mode.NEGATE)
def pow(x, y): def pow(x, y):
"""Element-wise `power`.""" r"""Element-wise `power`."""
return _elwise(x, y, mode=Elemwise.Mode.POW) return _elwise(x, y, mode=Elemwise.Mode.POW)
def mod(x, y): def mod(x, y):
"""Element-wise `remainder of division`.""" r"""Element-wise `remainder of division`."""
return _elwise(x, y, mode=Elemwise.Mode.MOD) return _elwise(x, y, mode=Elemwise.Mode.MOD)
def abs(x): def abs(x):
"""Element-wise `absolute value`.""" r"""Element-wise `absolute value`."""
return _elwise(x, mode=Elemwise.Mode.ABS) return _elwise(x, mode=Elemwise.Mode.ABS)
def exp(x): def exp(x):
"""Element-wise `exponential`.""" r"""Element-wise `exponential`."""
return _elwise(x, mode=Elemwise.Mode.EXP) return _elwise(x, mode=Elemwise.Mode.EXP)
def expm1(x): def expm1(x):
"""Element-wise `exp(x)-1`.""" r"""Element-wise `exp(x)-1`."""
return _elwise(x, mode=Elemwise.Mode.EXPM1) return _elwise(x, mode=Elemwise.Mode.EXPM1)
def log(x): def log(x):
"""Element-wise `logarithm (base e)`.""" r"""Element-wise `logarithm (base e)`."""
return _elwise(x, mode=Elemwise.Mode.LOG) return _elwise(x, mode=Elemwise.Mode.LOG)
def log1p(x): def log1p(x):
"""Element-wise `log(x+1) (base e)`.""" r"""Element-wise `log(x+1) (base e)`."""
return _elwise(x, mode=Elemwise.Mode.LOG1P) return _elwise(x, mode=Elemwise.Mode.LOG1P)
def sqrt(x: Tensor) -> Tensor: def sqrt(x: Tensor) -> Tensor:
""" r"""Element-wise `sqrt`.
Element-wise `sqrt`.
Returns ``NaN`` for negative input value.
:param x: input tensor.
:return: computed tensor.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np import numpy as np
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.sqrt(x) out = F.sqrt(x)
print(out.numpy().round(decimals=4)) print(out.numpy().round(decimals=4))
Outputs: Outputs:
.. testoutput:: .. testoutput::
[[0. 1. 1.4142]
[1.7321 2. 2.2361]]
[[0. 1. 1.4142]
[1.7321 2. 2.2361]]
""" """
return x ** 0.5 return x ** 0.5
def square(x: Tensor) -> Tensor: def square(x: Tensor) -> Tensor:
""" r"""Element-wise `square`.
Returns a new tensor with the square of the elements of input tensor.
:param inp: input tensor.
:return: computed tensor.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np import numpy as np
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.square(data) out = F.square(data)
print(out.numpy().round(decimals=4)) print(out.numpy().round(decimals=4))
Outputs: Outputs:
.. testoutput:: .. testoutput::
[[ 0. 1. 4.]
[ 9. 16. 25.]]
[[ 0. 1. 4.]
[ 9. 16. 25.]]
""" """
return x ** 2 return x ** 2
def round(x): def round(x):
"""Element-wise `rounding to int`.""" r"""Element-wise `rounding to int`."""
return _elwise(x, mode=Elemwise.Mode.ROUND) return _elwise(x, mode=Elemwise.Mode.ROUND)
def ceil(x): def ceil(x):
"""Element-wise `ceiling`.""" r"""Element-wise `ceiling`."""
return _elwise(x, mode=Elemwise.Mode.CEIL) return _elwise(x, mode=Elemwise.Mode.CEIL)
def floor(x): def floor(x):
"""Element-wise `floor`.""" r"""Element-wise `floor`."""
return _elwise(x, mode=Elemwise.Mode.FLOOR) return _elwise(x, mode=Elemwise.Mode.FLOOR)
def maximum(x, y): def maximum(x, y):
"""Element-wise `maximum of array elements`.""" r"""Element-wise `maximum of array elements`."""
return _elwise(x, y, mode=Elemwise.Mode.MAX) return _elwise(x, y, mode=Elemwise.Mode.MAX)
def minimum(x, y): def minimum(x, y):
"""Element-wise `minimum of array elements`.""" r"""Element-wise `minimum of array elements`."""
return _elwise(x, y, mode=Elemwise.Mode.MIN) return _elwise(x, y, mode=Elemwise.Mode.MIN)
...@@ -261,62 +242,57 @@ def minimum(x, y): ...@@ -261,62 +242,57 @@ def minimum(x, y):
def cos(x): def cos(x):
""" r"""Element-wise `cosine`.
Element-wise `cosine`.
:param x: input tensor.
:return: computed tensor.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) import numpy as np
out = F.cos(x) from megengine import tensor
print(out.numpy().round(decimals=4)) import megengine.functional as F
Outputs: x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.cos(x)
print(out.numpy().round(decimals=4))
.. testoutput:: Outputs:
[[ 1. 0.5403 -0.4161] .. testoutput::
[-0.99 -0.6536 0.2837]]
[[ 1. 0.5403 -0.4161]
[-0.99 -0.6536 0.2837]]
""" """
return _elwise(x, mode=Elemwise.Mode.COS) return _elwise(x, mode=Elemwise.Mode.COS)
def sin(x): def sin(x):
"""Element-wise `sine`.""" r"""Element-wise `sine`."""
return _elwise(x, mode=Elemwise.Mode.SIN) return _elwise(x, mode=Elemwise.Mode.SIN)
def tan(x): def tan(x):
"""Element-wise `tangent`.""" r"""Element-wise `tangent`."""
return sin(x) / cos(x) return sin(x) / cos(x)
def acos(x): def acos(x):
"""Element-wise `inverse cosine`.""" r"""Element-wise `inverse cosine`."""
return _elwise(x, mode=Elemwise.Mode.ACOS) return _elwise(x, mode=Elemwise.Mode.ACOS)
def asin(x): def asin(x):
"""Element-wise `inverse sine`.""" r"""Element-wise `inverse sine`."""
return _elwise(x, mode=Elemwise.Mode.ASIN) return _elwise(x, mode=Elemwise.Mode.ASIN)
def atan(x): def atan(x):
"""Element-wise `inverse tangent`.""" r"""Element-wise `inverse tangent`."""
return _elwise(x, 1, mode=Elemwise.Mode.ATAN2) return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)
def atan2(y, x): def atan2(y, x):
"""Element-wise `2-argument arctangent`.""" r"""Element-wise `2-argument arctangent`."""
return _elwise(y, x, mode=Elemwise.Mode.ATAN2) return _elwise(y, x, mode=Elemwise.Mode.ATAN2)
...@@ -355,38 +331,33 @@ def atanh(x): ...@@ -355,38 +331,33 @@ def atanh(x):
def left_shift(x, y): def left_shift(x, y):
""" r"""Element-wise `bitwise binary: x << y`.
Element-wise `bitwise binary: x << y`.
:param x: input tensor, should be int. Examples:
:param y: how many bits to be left-shifted.
:return: computed tensor.
Examples: .. testcode::
.. testcode::
import numpy as np import numpy as np
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3)) x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
out = F.left_shift(x, 2) out = F.left_shift(x, 2)
print(out.numpy()) print(out.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[[ 0 4 8] [[ 0 4 8]
[12 16 20]] [12 16 20]]
""" """
return _elwise(x, y, mode=Elemwise.Mode.SHL) return _elwise(x, y, mode=Elemwise.Mode.SHL)
def right_shift(x, y): def right_shift(x, y):
"""Element-wise `bitwise binary: x >> y`.""" r"""Element-wise `bitwise binary: x >> y`."""
return _elwise(x, y, mode=Elemwise.Mode.SHR) return _elwise(x, y, mode=Elemwise.Mode.SHR)
...@@ -394,22 +365,22 @@ def right_shift(x, y): ...@@ -394,22 +365,22 @@ def right_shift(x, y):
def logical_and(x, y): def logical_and(x, y):
"""Element-wise `logical and: x && y`.""" r"""Element-wise `logical and: x && y`."""
return _elwise(x, y, mode=Elemwise.Mode.AND) return _elwise(x, y, mode=Elemwise.Mode.AND)
def logical_not(x): def logical_not(x):
"""Element-wise `logical not: ~x`.""" r"""Element-wise `logical not: ~x`."""
return _elwise(x, mode=Elemwise.Mode.NOT) return _elwise(x, mode=Elemwise.Mode.NOT)
def logical_or(x, y): def logical_or(x, y):
"""Element-wise `logical or: x || y`.""" r"""Element-wise `logical or: x || y`."""
return _elwise(x, y, mode=Elemwise.Mode.OR) return _elwise(x, y, mode=Elemwise.Mode.OR)
def logical_xor(x, y): def logical_xor(x, y):
"""Element-wise `logical xor: x ^ y`.""" r"""Element-wise `logical xor: x ^ y`."""
return _elwise(x, y, mode=Elemwise.Mode.XOR) return _elwise(x, y, mode=Elemwise.Mode.XOR)
...@@ -417,59 +388,53 @@ def logical_xor(x, y): ...@@ -417,59 +388,53 @@ def logical_xor(x, y):
def equal(x, y): def equal(x, y):
""" r"""Element-wise `(x == y)`.
Element-wise `(x == y)`.
:param x: input tensor 1.
:param y: input tensor 2.
:return: computed tensor.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) import numpy as np
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) from megengine import tensor
out = F.equal(x, y) import megengine.functional as F
print(out.numpy())
Outputs: x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.equal(x, y)
print(out.numpy())
.. testoutput:: Outputs:
[[1. 1. 1.] .. testoutput::
[1. 1. 1.]]
[[1. 1. 1.]
[1. 1. 1.]]
""" """
return _elwise(x, y, mode=Elemwise.Mode.EQ) return _elwise(x, y, mode=Elemwise.Mode.EQ)
def not_equal(x, y): def not_equal(x, y):
"""Element-wise `(x != y)`.""" r"""Element-wise `(x != y)`."""
return x != y return x != y
def less(x, y): def less(x, y):
"""Element-wise `(x < y)`.""" r"""Element-wise `(x < y)`."""
return _elwise(x, y, mode=Elemwise.Mode.LT) return _elwise(x, y, mode=Elemwise.Mode.LT)
def less_equal(x, y): def less_equal(x, y):
"""Element-wise `(x <= y)`.""" r"""Element-wise `(x <= y)`."""
return _elwise(x, y, mode=Elemwise.Mode.LEQ) return _elwise(x, y, mode=Elemwise.Mode.LEQ)
def greater(x, y): def greater(x, y):
"""Element-wise `(x > y)`.""" r"""Element-wise `(x > y)`."""
return _elwise(y, x, mode=Elemwise.Mode.LT) return _elwise(y, x, mode=Elemwise.Mode.LT)
def greater_equal(x, y): def greater_equal(x, y):
"""Element-wise `(x >= y)`.""" r"""Element-wise `(x >= y)`."""
return _elwise(y, x, mode=Elemwise.Mode.LEQ) return _elwise(y, x, mode=Elemwise.Mode.LEQ)
...@@ -477,43 +442,45 @@ def greater_equal(x, y): ...@@ -477,43 +442,45 @@ def greater_equal(x, y):
def clip(x: Tensor, lower=None, upper=None) -> Tensor: def clip(x: Tensor, lower=None, upper=None) -> Tensor:
r""" r"""Clamps all elements in input tensor into the range ``[ lower, upper ]`` and returns
Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns
a resulting tensor: a resulting tensor:
.. math:: .. math::
y_i = \begin{cases} y_i = \begin{cases}
\text{lower} & \text{if } x_i < \text{lower} \\ \text{lower} & \text{if } x_i < \text{lower} \\
x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\ x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
\text{upper} & \text{if } x_i > \text{upper} \text{upper} & \text{if } x_i > \text{upper}
\end{cases} \end{cases}
:param x: input tensor. Args:
:param lower: lower-bound of the range to be clamped to. x: input tensor.
:param upper: upper-bound of the range to be clamped to. lower: lower-bound of the range to be clamped to.
:return: output clamped tensor. upper: upper-bound of the range to be clamped to.
Examples: Returns:
output clamped tensor.
.. testcode:: Examples:
import numpy as np .. testcode::
from megengine import tensor
import megengine.functional as F
a = tensor(np.arange(5).astype(np.int32)) import numpy as np
print(F.clip(a, 2, 4).numpy()) from megengine import tensor
print(F.clip(a, lower=3).numpy()) import megengine.functional as F
print(F.clip(a, upper=3).numpy())
Outputs: a = tensor(np.arange(5).astype(np.int32))
print(F.clip(a, 2, 4).numpy())
print(F.clip(a, lower=3).numpy())
print(F.clip(a, upper=3).numpy())
.. testoutput:: Outputs:
[2 2 2 3 4] .. testoutput::
[3 3 3 3 4]
[0 1 2 3 3]
[2 2 2 3 4]
[3 3 3 3 4]
[0 1 2 3 3]
""" """
assert ( assert (
lower is not None or upper is not None lower is not None or upper is not None
......
...@@ -23,14 +23,14 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None): ...@@ -23,14 +23,14 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None):
def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable): def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
r""" r"""Load a serialized Cambricon model as a runtime operator in MegEngine.
Load a serialized Cambricon model as a runtime operator in MegEngine.
Args:
:param inputs: list of input tensors. inputs: list of input tensors.
:param data: the serialized Cambricon model. data: the serialized Cambricon model.
:param symbol: name of the function in Cambricon model. symbol: name of the function in Cambricon model.
:param tensor_dim_mutable: whether the input tensors' shapes are mutable tensor_dim_mutable: whether the input tensors' shapes are mutable
in ``cnrtModel_t``. in ``cnrtModel_t``.
""" """
op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable) op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable)
...@@ -38,11 +38,11 @@ def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable): ...@@ -38,11 +38,11 @@ def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
def atlas_runtime_opr(inputs, data): def atlas_runtime_opr(inputs, data):
r""" r"""Load a serialized Atlas model as a runtime operator in MegEngine.
Load a serialized Atlas model as a runtime operator in MegEngine.
:param inputs: list of input tensors. Args:
:param data: the serialized Atlas model. inputs: list of input tensors.
data: the serialized Atlas model.
""" """
op = builtin.AtlasRuntime(data, len(data)) op = builtin.AtlasRuntime(data, len(data))
......
...@@ -26,9 +26,7 @@ __all__ = [ ...@@ -26,9 +26,7 @@ __all__ = [
def _reduce_output(loss_fn): def _reduce_output(loss_fn):
r""" r"""Wrapper to apply canonical reductions to loss outputs."""
Wrapper to apply canonical reductions to loss outputs.
"""
@functools.wraps(loss_fn) @functools.wraps(loss_fn)
def reduced_loss_fn(*args, reduction="mean", **kwargs): def reduced_loss_fn(*args, reduction="mean", **kwargs):
...@@ -45,13 +43,14 @@ def _reduce_output(loss_fn): ...@@ -45,13 +43,14 @@ def _reduce_output(loss_fn):
@_reduce_output @_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor: def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
r""" r"""Calculates the mean absolute error (MAE) between
Calculates the mean absolute error (MAE) between
each element in the pred :math:`x` and label :math:`y`. each element in the pred :math:`x` and label :math:`y`.
The mean absolute error can be described as: The mean absolute error can be described as:
.. math:: \ell(x,y) = mean\left(L \right) .. math::
\ell(x,y) = mean\left(L \right)
where where
...@@ -63,30 +62,32 @@ def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor: ...@@ -63,30 +62,32 @@ def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
of :math:`N` elements each. :math:`N` is the batch size. of :math:`N` elements each. :math:`N` is the batch size.
:param pred: predicted result from model. Args:
:param label: ground truth to compare. pred: predicted result from model.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' label: ground truth to compare.
:return: loss value. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Examples: Returns:
loss value.
.. testcode:: Examples:
import numpy as np .. testcode::
import megengine as mge
import megengine.functional as F
ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) import numpy as np
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) import megengine as mge
loss = F.nn.l1_loss(ipt, tgt) import megengine.functional as F
print(loss.numpy())
Outputs: ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
loss = F.nn.l1_loss(ipt, tgt)
print(loss.numpy())
.. testoutput:: Outputs:
2.75 .. testoutput::
2.75
""" """
diff = pred - label diff = pred - label
return abs(diff) return abs(diff)
...@@ -94,53 +95,56 @@ def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor: ...@@ -94,53 +95,56 @@ def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
@_reduce_output @_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor: def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
r""" r"""Calculates the mean squared error (squared L2 norm) between
Calculates the mean squared error (squared L2 norm) between
each element in the pred :math:`x` and label :math:`y`. each element in the pred :math:`x` and label :math:`y`.
The mean squared error can be described as: The mean squared error can be described as:
.. math:: \ell(x, y) = mean\left( L \right) .. math::
\ell(x, y) = mean\left( L \right)
where where
.. math:: .. math::
L = \{l_1,\dots,l_N\}, \quad L = \{l_1,\dots,l_N\}, \quad
l_n = \left( x_n - y_n \right)^2, l_n = \left( x_n - y_n \right)^2,
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
of :math:`N` elements each. :math:`N` is the batch size. of :math:`N` elements each. :math:`N` is the batch size.
:param pred: predicted result from model. Args:
:param label: ground truth to compare. pred: predicted result from model.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' label: ground truth to compare.
:return: loss value. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Returns:
loss value.
Shape: Shape:
- pred: :math:`(N, *)` where :math:`*` means any number of additional * pred: :math:`(N, *)` where :math:`*` means any number of additional
dimensions. dimensions.
- label: :math:`(N, *)`. Same shape as ``pred``. * label: :math:`(N, *)`. Same shape as ``pred``.
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np
import megengine as mge
import megengine.functional as F
ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) import numpy as np
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) import megengine as mge
loss = F.nn.square_loss(ipt, tgt) import megengine.functional as F
print(loss.numpy())
Outputs: ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
loss = F.nn.square_loss(ipt, tgt)
print(loss.numpy())
.. testoutput:: Outputs:
9.75 .. testoutput::
9.75
""" """
diff = pred - label diff = pred - label
return diff ** 2 return diff ** 2
...@@ -155,8 +159,7 @@ def cross_entropy( ...@@ -155,8 +159,7 @@ def cross_entropy(
label_smooth: float = 0, label_smooth: float = 0,
reduction: str = "mean", reduction: str = "mean",
) -> Tensor: ) -> Tensor:
r""" r"""Computes the multi-class cross entropy loss (using logits by default).
Computes the multi-class cross entropy loss (using logits by default).
By default (``with_logits`` is True), ``pred`` is assumed to be logits, By default (``with_logits`` is True), ``pred`` is assumed to be logits,
class probabilities are given by softmax. class probabilities are given by softmax.
...@@ -170,35 +173,37 @@ def cross_entropy( ...@@ -170,35 +173,37 @@ def cross_entropy(
where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively. where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively.
k is the index of label distribution. :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes. k is the index of label distribution. :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.
:param pred: input tensor representing the predicted probability. Args:
:param label: input tensor representing the classification label. pred: input tensor representing the predicted probability.
:param axis: an axis along which softmax will be applied. Default: 1 label: input tensor representing the classification label.
:param with_logits: whether to apply softmax first. Default: True axis: an axis along which softmax will be applied. Default: 1
:param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0 with_logits: whether to apply softmax first. Default: True
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' label_smooth: a label smoothing parameter that can re-distribute the target distribution. Default: 0
:return: loss value. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Examples: Returns:
loss value.
.. testcode:: Examples:
import numpy as np .. testcode::
from megengine import tensor
import megengine.functional as F
data_shape = (1, 2) import numpy as np
label_shape = (1, ) from megengine import tensor
pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape)) import megengine.functional as F
label = tensor(np.ones(label_shape, dtype=np.int32))
loss = F.nn.cross_entropy(pred, label)
print(loss.numpy().round(decimals=4))
Outputs: data_shape = (1, 2)
label_shape = (1, )
pred = tensor(np.array([0, 0], dtype=np.float32).reshape(data_shape))
label = tensor(np.ones(label_shape, dtype=np.int32))
loss = F.nn.cross_entropy(pred, label)
print(loss.numpy().round(decimals=4))
.. testoutput:: Outputs:
0.6931 .. testoutput::
0.6931
""" """
n0 = pred.ndim n0 = pred.ndim
n1 = label.ndim n1 = label.ndim
...@@ -226,37 +231,38 @@ def cross_entropy( ...@@ -226,37 +231,38 @@ def cross_entropy(
def binary_cross_entropy( def binary_cross_entropy(
pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean", pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor: ) -> Tensor:
r""" r"""Computes the binary cross entropy loss (using logits by default).
Computes the binary cross entropy loss (using logits by default).
By default (``with_logits`` is True), ``pred`` is assumed to be logits, By default (``with_logits`` is True), ``pred`` is assumed to be logits,
class probabilities are given by sigmoid. class probabilities are given by sigmoid.
:param pred: `(N, *)`, where `*` means any number of additional dimensions. Args:
:param label: `(N, *)`, same shape as the input. pred: `(N, *)`, where `*` means any number of additional dimensions.
:param with_logits: bool, whether to apply sigmoid first. Default: True label: `(N, *)`, same shape as the input.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' with_logits: bool, whether to apply sigmoid first. Default: True
:return: loss value. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Examples: Returns:
loss value.
.. testcode:: Examples:
import numpy as np .. testcode::
from megengine import tensor
import megengine.functional as F
pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2)) import numpy as np
label = tensor(np.ones((1, 2), dtype=np.float32)) from megengine import tensor
loss = F.nn.binary_cross_entropy(pred, label) import megengine.functional as F
print(loss.numpy().round(decimals=4))
Outputs: pred = tensor(np.array([0, 0], dtype=np.float32).reshape(1, 2))
label = tensor(np.ones((1, 2), dtype=np.float32))
loss = F.nn.binary_cross_entropy(pred, label)
print(loss.numpy().round(decimals=4))
.. testoutput:: Outputs:
0.6931 .. testoutput::
0.6931
""" """
if not with_logits: if not with_logits:
return -(label * log(pred) + (1 - label) * log(1 - pred)) return -(label * log(pred) + (1 - label) * log(1 - pred))
...@@ -269,37 +275,38 @@ def binary_cross_entropy( ...@@ -269,37 +275,38 @@ def binary_cross_entropy(
def hinge_loss( def hinge_loss(
pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean" pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor: ) -> Tensor:
r""" r"""Caculates the hinge loss which is often used in SVM.
Caculates the hinge loss which is often used in SVM.
The hinge loss can be described as: The hinge loss can be described as:
.. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij})) .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij}))
:param pred: input tensor representing the predicted probability, shape is `(N, C)`. Args:
:param label: input tensor representing the binary classification label, shape is `(N, C)`. pred: input tensor representing the predicted probability, shape is `(N, C)`.
:param norm: specify the norm to caculate the loss, should be "L1" or "L2". label: input tensor representing the binary classification label, shape is `(N, C)`.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean' norm: specify the norm to calculate the loss, should be "L1" or "L2".
:return: loss value. reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Examples: Returns:
loss value.
.. testcode:: Examples:
from megengine import tensor .. testcode::
import megengine.functional as F
pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32") from megengine import tensor
label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32") import megengine.functional as F
loss = F.nn.hinge_loss(pred, label)
print(loss.numpy())
Outputs: pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32")
label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32")
loss = F.nn.hinge_loss(pred, label)
print(loss.numpy())
.. testoutput:: Outputs:
1.5 .. testoutput::
1.5
""" """
norm = norm.upper() norm = norm.upper()
assert norm in ["L1", "L2"], "norm must be L1 or L2" assert norm in ["L1", "L2"], "norm must be L1 or L2"
......
...@@ -19,33 +19,16 @@ from .tensor import broadcast_to, transpose ...@@ -19,33 +19,16 @@ from .tensor import broadcast_to, transpose
def topk_accuracy( def topk_accuracy(
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]: ) -> Union[Tensor, Iterable[Tensor]]:
r""" r"""Calculates the classification accuracy given predicted logits and ground-truth labels.
Calculates the classification accuracy given predicted logits and ground-truth labels.
:param logits: model predictions of shape `[batch_size, num_classes]`, Args:
representing the probability (likelyhood) of each class. logits: model predictions of shape `[batch_size, num_classes]`,
:param target: ground-truth labels, 1d tensor of int32. representing the probability (likelihood) of each class.
:param topk: specifies the topk values, could be an int or tuple of ints. Default: 1 target: ground-truth labels, 1d tensor of int32.
:return: tensor(s) of classification accuracy between 0.0 and 1.0. topk: specifies the topk values, could be an int or tuple of ints. Default: 1
Examples: Returns:
tensor(s) of classification accuracy between 0.0 and 1.0.
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
target = tensor(np.arange(8, dtype=np.int32))
top1, top5 = F.metric.topk_accuracy(logits, target, (1, 5))
print(top1.numpy(), top5.numpy())
Outputs:
.. testoutput::
0.0 0.375
""" """
if isinstance(topk, int): if isinstance(topk, int):
topk = (topk,) topk = (topk,)
......
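The Google-style docstring above drops the usage example that the old docstring carried; for reference, that example (reconstructed from the removed lines) was:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    logits = tensor(np.arange(80, dtype=np.int32).reshape(8, 10))
    target = tensor(np.arange(8, dtype=np.int32))
    top1, top5 = F.metric.topk_accuracy(logits, target, (1, 5))
    print(top1.numpy(), top5.numpy())   # prints: 0.0 0.375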
...@@ -28,32 +28,28 @@ def conv_bias_activation( ...@@ -28,32 +28,28 @@ def conv_bias_activation(
conv_mode="cross_correlation", conv_mode="cross_correlation",
compute_mode="default", compute_mode="default",
) -> Tensor: ) -> Tensor:
""" r"""Convolution bias with activation operation, only for inference.
Convolution bias with activation operation, only for inference.
:param inp: feature map of the convolution operation.
:param weight: convolution kernel.
:param bias: bias added to the result of convolution
:param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
Args:
inp: feature map of the convolution operation.
weight: convolution kernel.
bias: bias added to the result of convolution
stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
dtype: support for ``np.dtype``, Default: np.int8
compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
""" """
ph, pw = _pair(padding) ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride) sh, sw = _pair_nonzero(stride)
...@@ -91,32 +87,28 @@ def batch_conv_bias_activation( ...@@ -91,32 +87,28 @@ def batch_conv_bias_activation(
conv_mode="cross_correlation", conv_mode="cross_correlation",
compute_mode="default", compute_mode="default",
) -> Tensor: ) -> Tensor:
""" r"""Batch convolution bias with activation operation, only for inference.
Batch convolution bias with activation operation, only for inference.
:param inp: feature map of the convolution operation.
:param weight: convolution kernel in batched way.
:param bias: bias added to the result of convolution
:param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
Args:
inp: feature map of the convolution operation.
weight: convolution kernel in batched way.
bias: bias added to the result of convolution
stride: stride of the 2D convolution operation. Default: 1
padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
dilation: dilation of the 2D convolution operation. Default: 1
groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
dtype: support for ``np.dtype``, Default: np.int8
compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
""" """
ph, pw = _pair(padding) ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride) sh, sw = _pair_nonzero(stride)
......
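To make the grouped-convolution weight layout described in both docstrings concrete, here is a small shape-only sketch (illustrative; no MegEngine call is made and the channel counts are placeholders):

.. code-block::

    # weight layout for a grouped convolution, as documented above:
    # (groups, out_channels // groups, in_channels // groups, kh, kw)
    in_channels, out_channels, groups = 16, 32, 4
    kh = kw = 3
    weight_shape = (groups, out_channels // groups, in_channels // groups, kh, kw)
    print(weight_shape)   # (4, 8, 4, 3, 3)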
...@@ -19,37 +19,36 @@ __all__ = ["topk_accuracy"] ...@@ -19,37 +19,36 @@ __all__ = ["topk_accuracy"]
def _assert_equal( def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
): ):
r""" r"""Asserts two tensors equal and returns expected value (first input).
Asserts two tensors equal and returns expected value (first input).
It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``). It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
If we want to verify the correctness of model, just ``assert`` its states and outputs. If we want to verify the correctness of model, just ``assert`` its states and outputs.
While sometimes we need to verify the correctness at different backends for *dumped* model While sometimes we need to verify the correctness at different backends for *dumped* model
(or in :class:`~jit.trace` context), and no python code could be executed in that case. (or in :class:`~jit.trace` context), and no python code could be executed in that case.
Thus we have to use :func:`~functional.utils._assert_equal` instead. Thus we have to use :func:`~functional.utils._assert_equal` instead.
:param expect: expected tensor value Args:
:param actual: tensor to check value expect: expected tensor value
:param maxerr: max allowed error; error is defined as the minimal of absolute and relative error actual: tensor to check value
:param verbose: whether to print maxerr to stdout during opr exec maxerr: max allowed error; error is defined as the minimal of absolute and relative error
:return: expected tensor verbose: whether to print maxerr to stdout during opr exec
Examples: Examples:
.. testcode:: .. testcode::
import numpy as np import numpy as np
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
x = tensor([1, 2, 3], np.float32) x = tensor([1, 2, 3], np.float32)
y = tensor([1, 2, 3], np.float32) y = tensor([1, 2, 3], np.float32)
print(F.utils._assert_equal(x, y, maxerr=0).numpy()) print(F.utils._assert_equal(x, y, maxerr=0).numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[1. 2. 3.] [1. 2. 3.]
""" """
err = ( err = (
abs(expect - actual) abs(expect - actual)
......
...@@ -7,24 +7,24 @@ ...@@ -7,24 +7,24 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
class FetcherError(Exception): class FetcherError(Exception):
"""Base class for fetch related error.""" r"""Base class for fetch related error."""
class InvalidRepo(FetcherError): class InvalidRepo(FetcherError):
"""The repo provided was somehow invalid.""" r"""The repo provided was somehow invalid."""
class InvalidGitHost(FetcherError): class InvalidGitHost(FetcherError):
"""The git host provided was somehow invalid.""" r"""The git host provided was somehow invalid."""
class GitPullError(FetcherError): class GitPullError(FetcherError):
"""A git pull error occurred.""" r"""A git pull error occurred."""
class GitCheckoutError(FetcherError): class GitCheckoutError(FetcherError):
"""A git checkout error occurred.""" r"""A git checkout error occurred."""
class InvalidProtocol(FetcherError): class InvalidProtocol(FetcherError):
"""The protocol provided was somehow invalid.""" r"""The protocol provided was somehow invalid."""
...@@ -102,24 +102,18 @@ class GitSSHFetcher(RepoFetcherBase): ...@@ -102,24 +102,18 @@ class GitSSHFetcher(RepoFetcherBase):
commit: str = None, commit: str = None,
silent: bool = True, silent: bool = True,
) -> str: ) -> str:
""" """Fetches git repo by SSH protocol
Fetches git repo by SSH protocol
Args:
:param git_host: git_host: host address of git repo. Eg: github.com
host address of git repo. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
Example: github.com tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
:param repo_info: use_cache: whether to use locally fetched code or completely re-fetch.
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional commit: commit id on github or gitlab.
tag/branch. The default branch is ``master`` if not specified. silent: whether to accept the stdout and stderr of the subprocess with PIPE, instead of
Example: ``"brain_sdk/MegBrain[:hub]"`` displaying on the screen.
:param use_cache:
whether to use locally fetched code or completely re-fetch. Returns:
:param commit:
commit id on github or gitlab.
:param silent:
whether to accept the stdout and stderr of the subprocess with PIPE, instead of
displaying on the screen.
:return:
directory where the repo code is stored. directory where the repo code is stored.
""" """
if not cls._check_git_host(git_host): if not cls._check_git_host(git_host):
...@@ -217,24 +211,19 @@ class GitHTTPSFetcher(RepoFetcherBase): ...@@ -217,24 +211,19 @@ class GitHTTPSFetcher(RepoFetcherBase):
commit: str = None, commit: str = None,
silent: bool = True, silent: bool = True,
) -> str: ) -> str:
""" """Fetches git repo by HTTPS protocol.
Fetches git repo by HTTPS protocol.
Args:
:param git_host: git_host: host address of git repo. Eg: github.com
host address of git repo. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
Example: github.com tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
:param repo_info: use_cache: whether to use locally cached code or completely re-fetch.
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional commit: commit id on github or gitlab.
tag/branch. The default branch is ``master`` if not specified. silent: whether to accept the stdout and stderr of the subprocess with PIPE, instead of
Example: ``"brain_sdk/MegBrain[:hub]"`` displaying on the screen.
:param use_cache:
whether to use locally cached code or completely re-fetch.
:param commit: Returns:
commit id on github or gitlab.
:param silent:
whether to accept the stdout and stderr of the subprocess with PIPE, instead of
displaying on the screen.
:return:
directory where the repo code is stored. directory where the repo code is stored.
""" """
if not cls._check_git_host(git_host): if not cls._check_git_host(git_host):
......
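These fetchers are the backends of ``megengine.hub``; a hedged sketch of the user-facing call is shown below. The repo string follows the ``"repo_owner/repo_name[:branch]"`` format documented above, while ``"some_entry"`` is a hypothetical entry point, not a real model name:

.. code-block::

    import megengine.hub as hub

    # "some_entry" is a placeholder entry point exposed by the repo's hubconf
    model = hub.load("brain_sdk/MegBrain:hub", "some_entry")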
...@@ -9,12 +9,12 @@ ...@@ -9,12 +9,12 @@
class GraphOptimizationConfig: class GraphOptimizationConfig:
r""" r"""Configuration for graph optimization: False for OFF, True for ON. The default value
Configuration for graph optimization: False for OFF, True for ON. The default value
None means that opt_level will decide whether this optimization will be applied or not. None means that opt_level will decide whether this optimization will be applied or not.
:param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization Args:
:param jit_fuse_reduce: whether to fuse reduce in JIT optimization jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
jit_fuse_reduce: whether to fuse reduce in JIT optimization
""" """
def __init__(self): def __init__(self):
......
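A minimal sketch (not part of the diff) of using this configuration: construct it, flip one flag, and leave the other to ``opt_level``. Importing it from ``megengine.jit`` and passing it to ``trace`` through a ``graph_opt_config`` argument are assumptions about the surrounding codebase.

.. code-block::

    from megengine.jit import GraphOptimizationConfig, trace  # import path assumed

    cfg = GraphOptimizationConfig()
    cfg.jit_fuse_dimshuffle = True    # force-enable dimshuffle fusion in JIT
    cfg.jit_fuse_reduce = None        # None: let opt_level decide

    @trace(symbolic=True, graph_opt_config=cfg)   # kwarg name assumed
    def step(x):
        return x * 2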
...@@ -14,9 +14,7 @@ from .module import Module ...@@ -14,9 +14,7 @@ from .module import Module
class BatchMatMulActivation(Module): class BatchMatMulActivation(Module):
r""" r"""Batched :func:`~.matmul` with activation(only :func:`~.relu` supported), no transpose anywhere."""
Batched :func:`~.matmul` with activation(only :func:`~.relu` supported), no transpose anywhere.
"""
def __init__( def __init__(
self, self,
......