Commit 10e942d9 authored by Megvii Engine Team

refactor(mge): polish api

- refactor(mge): add support for optimizer.step().clear_grad() idiom

- refactor(mge): rename some methods of GradManager

- refactor(mge): remove tensor_nn and TensorDict

- refactor(mge): remove Buffer

- refactor(mge): remove requires_grad flag

- refactor(mge): add a default grad=None attribute to Tensor

- refactor(mge): deprecation for 1.0

GitOrigin-RevId: 3b723d938747e7a7d765ba3716b2e08f5223f62a
Parent 9389a805
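Below is a minimal, hedged sketch of the polished API this commit introduces, adapted from the tests further down in the diff: `GradManager.attach` replaces `register`, recording is done with `with gm:` (or `gm.record()`), and `Optimizer.step()` now returns `self` so it can be chained with `clear_grad()`. The toy network, data, learning rate, and the `megengine.autodiff` import path (inferred from the `ad.GradManager()` usage in the tests) are assumptions for illustration.

```python
import numpy as np

import megengine.autodiff as ad
import megengine.functional as F
from megengine import optimizer, tensor
from megengine.module import Linear, Module


class Net(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear(28, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
gm = ad.GradManager().attach(net.parameters())   # formerly .register()
opt = optimizer.SGD(net.parameters(), lr=0.01)

data = tensor(np.random.random((2, 28)).astype(np.float32))
label = tensor(np.zeros((2,), dtype=np.int32))

with gm:                                          # start recording (same as gm.record())
    loss = F.cross_entropy_with_softmax(net(data), label)
    gm.backward(loss)                             # computes and accumulates param.grad

opt.step().clear_grad()                           # step() returns self, enabling the chained idiom
```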
...@@ -74,8 +74,7 @@ from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func ...@@ -74,8 +74,7 @@ from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .tensor_nn import Buffer, Parameter
from .version import __version__ from .version import __version__
_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
......
...@@ -22,7 +22,7 @@ class GradManager: ...@@ -22,7 +22,7 @@ class GradManager:
self._after_backward_callback = [] self._after_backward_callback = []
self._gradients = dict() self._gradients = dict()
def register(self, params, callbacks=None): def attach(self, params, callbacks=None):
if callbacks is None: if callbacks is None:
callbacks = [] callbacks = []
if isinstance(callbacks, Callable): if isinstance(callbacks, Callable):
...@@ -62,7 +62,7 @@ class GradManager: ...@@ -62,7 +62,7 @@ class GradManager:
if isinstance(grad, Future): if isinstance(grad, Future):
grad = grad.get() grad = grad.get()
param = self._param_dict[p] param = self._param_dict[p]
if getattr(param, "grad", None) is None: if param.grad is None:
param.grad = grad param.grad = grad
else: else:
param.grad += grad param.grad += grad
...@@ -70,9 +70,9 @@ class GradManager: ...@@ -70,9 +70,9 @@ class GradManager:
self._stop_record() self._stop_record()
backwarding_grad_manager = cache backwarding_grad_manager = cache
def __enter__(self): def record(self):
if self._recording: if self._recording:
return self raise RuntimeError("already recording")
grad = Grad() grad = Grad()
self._recording = True self._recording = True
self._grad = grad self._grad = grad
...@@ -88,16 +88,22 @@ class GradManager: ...@@ -88,16 +88,22 @@ class GradManager:
grad.wrt(param_wrapper, callback=callback) grad.wrt(param_wrapper, callback=callback)
grad.__enter__() grad.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb): def release(self):
if not self._recording:
raise RuntimeError("not recording")
self._stop_record() self._stop_record()
record = __enter__
def _stop_record(self): def _stop_record(self):
if self._grad is not None: if self._grad is not None:
self._grad.__exit__(None, None, None) self._grad.__exit__(None, None, None)
self._recording = False self._recording = False
self._grad = None self._grad = None
self._gradients = dict() self._gradients = dict()
def __enter__(self):
self.record()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_record()
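For reference, a short sketch of the explicit form of the renamed recording API shown above (equivalent to the `with gm:` block in the earlier sketch; `net`, `data`, and `ad` as before). Per this diff, `backward()` already stops recording, so `release()` is only needed when a recording is abandoned without a backward pass.

```python
gm = ad.GradManager().attach(net.parameters())

gm.record()                # raises RuntimeError if already recording
loss = net(data).sum()
gm.backward(loss)          # stops recording internally via _stop_record()

gm.record()                # a recording that is abandoned instead of backpropagated
_ = net(data)
gm.release()               # raises RuntimeError if not recording
```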
...@@ -70,7 +70,7 @@ class Dimshuffle(PodOpVisitor): ...@@ -70,7 +70,7 @@ class Dimshuffle(PodOpVisitor):
return bytes(ctypes.c_uint32(0)) + bytes(self) return bytes(ctypes.c_uint32(0)) + bytes(self)
def __init__(self, pattern, ndim=0): def __init__(self, pattern, ndim=0):
assert isinstance(pattern, collections.Iterable) assert isinstance(pattern, collections.abc.Iterable)
assert len(pattern) <= TensorShape.MAX_NDIM assert len(pattern) <= TensorShape.MAX_NDIM
pattern_array = Dimshuffle.Pattern.Pattern_Array() pattern_array = Dimshuffle.Pattern.Pattern_Array()
for idx, v in enumerate(pattern): for idx, v in enumerate(pattern):
......
...@@ -231,13 +231,13 @@ class OpNode: ...@@ -231,13 +231,13 @@ class OpNode:
def _wrap(x): def _wrap(x):
if isinstance(x, collections.Sequence): if isinstance(x, collections.abc.Sequence):
return type(x)(map(_wrap, x)) return type(x)(map(_wrap, x))
return x.graph._wrap(x) return x.graph._wrap(x)
def _unwrap(x): def _unwrap(x):
if isinstance(x, collections.Sequence): if isinstance(x, collections.abc.Sequence):
return type(x)(map(_unwrap, x)) return type(x)(map(_unwrap, x))
return x._node return x._node
......
...@@ -166,7 +166,7 @@ def _reduce(mode): ...@@ -166,7 +166,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=0) op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data) (result,) = apply(op, data)
elif isinstance(axis, collections.Iterable): elif isinstance(axis, collections.abc.Iterable):
axis = list(axis) axis = list(axis)
axis.sort(reverse=True) axis.sort(reverse=True)
...@@ -204,7 +204,9 @@ def _todo(*_): ...@@ -204,7 +204,9 @@ def _todo(*_):
def _expand_args(args): def _expand_args(args):
if len(args) == 1: if len(args) == 1:
if isinstance(args[0], (collections.Sequence, TensorBase, TensorWrapperBase)): if isinstance(
args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase)
):
args = args[0] args = args[0]
return args return args
......
...@@ -143,7 +143,7 @@ def astensor1d(x, *reference, dtype=None, device=None): ...@@ -143,7 +143,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
(x,) = Const(x, dtype=dtype, device=device)(*reference) (x,) = Const(x, dtype=dtype, device=device)(*reference)
return x return x
if not isinstance(x, collections.Sequence): if not isinstance(x, collections.abc.Sequence):
raise TypeError raise TypeError
if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x):
......
...@@ -432,7 +432,7 @@ def argmin( ...@@ -432,7 +432,7 @@ def argmin(
[0] [0]
""" """
if isinstance(axis, collections.Iterable): if isinstance(axis, collections.abc.Iterable):
axis = list(axis) axis = list(axis)
axis.sort(reverse=True) axis.sort(reverse=True)
...@@ -486,7 +486,7 @@ def argmax( ...@@ -486,7 +486,7 @@ def argmax(
[5] [5]
""" """
if isinstance(axis, collections.Iterable): if isinstance(axis, collections.abc.Iterable):
axis = list(axis) axis = list(axis)
axis.sort(reverse=True) axis.sort(reverse=True)
......
...@@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True): ...@@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True):
:type allow_zero: bool :type allow_zero: bool
:param allow_zero: whether to allow zero tuple value""" :param allow_zero: whether to allow zero tuple value"""
if not isinstance(value, collections.Iterable): if not isinstance(value, collections.abc.Iterable):
value = int(value) value = int(value)
value = tuple([value for i in range(n)]) value = tuple([value for i in range(n)])
else: else:
......
...@@ -502,7 +502,7 @@ class trace: ...@@ -502,7 +502,7 @@ class trace:
raise TypeError( raise TypeError(
"cannot specify output_names when output is already in dict format" "cannot specify output_names when output is already in dict format"
) )
if output_names and not isinstance(output_names, collections.Sequence): if output_names and not isinstance(output_names, collections.abc.Sequence):
output_names = (output_names,) output_names = (output_names,)
if output_names and len(output_names) != len(self._output_bindings): if output_names and len(output_names) != len(self._output_bindings):
raise ValueError( raise ValueError(
...@@ -510,7 +510,7 @@ class trace: ...@@ -510,7 +510,7 @@ class trace:
len(self._output_bindings) len(self._output_bindings)
) )
) )
if arg_names and not isinstance(arg_names, collections.Sequence): if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,) arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._arg_bindings): if arg_names and len(arg_names) != len(self._arg_bindings):
raise ValueError( raise ValueError(
...@@ -646,9 +646,9 @@ class trace: ...@@ -646,9 +646,9 @@ class trace:
def _process_outputs(self, outputs): def _process_outputs(self, outputs):
output_names = None output_names = None
if isinstance(outputs, collections.Mapping): if isinstance(outputs, collections.abc.Mapping):
output_names, outputs = zip(*sorted(outputs.items())) output_names, outputs = zip(*sorted(outputs.items()))
elif not isinstance(outputs, collections.Sequence): elif not isinstance(outputs, collections.abc.Sequence):
outputs = (outputs,) outputs = (outputs,)
if not self._untraced: if not self._untraced:
......
...@@ -18,7 +18,6 @@ from .embedding import Embedding ...@@ -18,7 +18,6 @@ from .embedding import Embedding
from .identity import Identity from .identity import Identity
from .linear import Linear from .linear import Linear
from .module import Module from .module import Module
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential from .sequential import Sequential
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
import numpy as np import numpy as np
from ..functional import leaky_relu, prelu, relu, sigmoid, softmax from ..functional import leaky_relu, prelu, relu, sigmoid, softmax
from ..tensor_nn import Parameter from ..tensor import Parameter
from .module import Module from .module import Module
......
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
from ..distributed.group import WORLD, Group from ..distributed.group import WORLD, Group
from ..functional import batch_norm2d, sync_batch_norm from ..functional import batch_norm2d, sync_batch_norm
from ..tensor_nn import Buffer, Parameter, Tensor from ..tensor import Parameter, Tensor
from . import init from . import init
from .module import Module from .module import Module
...@@ -45,8 +45,8 @@ class _BatchNorm(Module): ...@@ -45,8 +45,8 @@ class _BatchNorm(Module):
tshape = (1, self.num_features, 1, 1) tshape = (1, self.num_features, 1, 1)
if self.track_running_stats: if self.track_running_stats:
self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32)) self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
self.running_var = Buffer(np.ones(tshape, dtype=np.float32)) self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
else: else:
self.running_mean = None self.running_mean = None
self.running_var = None self.running_var = None
......
...@@ -13,7 +13,7 @@ import numpy as np ...@@ -13,7 +13,7 @@ import numpy as np
from ..core.ops._internal import param_defs as P from ..core.ops._internal import param_defs as P
from ..functional import conv2d, conv_transpose2d, local_conv2d, relu from ..functional import conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero from ..functional.types import _pair, _pair_nonzero
from ..tensor_nn import Parameter from ..tensor import Parameter
from . import init from . import init
from .module import Module from .module import Module
......
...@@ -11,7 +11,7 @@ from typing import Optional ...@@ -11,7 +11,7 @@ from typing import Optional
import numpy as np import numpy as np
from ..functional import embedding as embedding_func from ..functional import embedding as embedding_func
from ..tensor_nn import Parameter from ..tensor import Parameter
from . import init from . import init
from .module import Module from .module import Module
...@@ -72,6 +72,7 @@ class Embedding(Module): ...@@ -72,6 +72,7 @@ class Embedding(Module):
max_norm: Optional[float] = None, max_norm: Optional[float] = None,
norm_type: Optional[float] = None, norm_type: Optional[float] = None,
initial_weight: Parameter = None, initial_weight: Parameter = None,
freeze: bool = False,
): ):
super().__init__() super().__init__()
if padding_idx is not None: if padding_idx is not None:
...@@ -83,6 +84,7 @@ class Embedding(Module): ...@@ -83,6 +84,7 @@ class Embedding(Module):
self.norm_type = norm_type self.norm_type = norm_type
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.freeze = freeze
if initial_weight is None: if initial_weight is None:
self.weight = Parameter( self.weight = Parameter(
np.random.uniform( np.random.uniform(
...@@ -101,7 +103,11 @@ class Embedding(Module): ...@@ -101,7 +103,11 @@ class Embedding(Module):
init.normal_(self.weight) init.normal_(self.weight)
def forward(self, inputs): def forward(self, inputs):
return embedding_func(inputs, self.weight) if self.freeze:
weight = self.weight.detach()
else:
weight = self.weight
return embedding_func(inputs, weight)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
...@@ -166,6 +172,6 @@ class Embedding(Module): ...@@ -166,6 +172,6 @@ class Embedding(Module):
padding_idx=padding_idx, padding_idx=padding_idx,
max_norm=max_norm, max_norm=max_norm,
norm_type=norm_type, norm_type=norm_type,
freeze=freeze,
) )
embedding.weight.requires_grad = not freeze
return embedding return embedding
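A hedged sketch of the new `freeze` flag added to `Embedding` above: instead of setting `requires_grad = False` on the weight (the flag this commit removes), `forward()` detaches the weight when `freeze=True`. Shapes, values, and the positional-argument form of `from_pretrained` are assumptions for illustration.

```python
import numpy as np
from megengine import Parameter, tensor
from megengine.module import Embedding

# Constructed directly with a randomly initialized weight...
emb = Embedding(num_embeddings=10, embedding_dim=4, freeze=True)

# ...or from a pretrained weight, matching the from_pretrained() change above.
weight = Parameter(np.random.random((10, 4)).astype(np.float32))
emb2 = Embedding.from_pretrained(weight, freeze=True)

indices = tensor(np.array([[0, 1, 2]], dtype=np.int32))
out = emb2(indices)        # forward() uses a detached weight, so no grad flows into it
```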
...@@ -23,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None: ...@@ -23,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None:
:param tensor: An n-dimentional tensor to be initialized :param tensor: An n-dimentional tensor to be initialized
:param val: The value to be filled throughout the tensor :param val: The value to be filled throughout the tensor
""" """
tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) tensor._reset(full(shape=tensor.shape, value=val, dtype=tensor.dtype))
def zeros_(tensor: Tensor) -> None: def zeros_(tensor: Tensor) -> None:
...@@ -50,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: ...@@ -50,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval :param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval :param b: Upper bound of the sampling interval
""" """
tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype)) tensor._reset(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype))
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
...@@ -61,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: ...@@ -61,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution :param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution :param std: The standard deviation of the normal distribution
""" """
tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype)) tensor._reset(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype))
def calculate_gain( def calculate_gain(
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import numpy as np import numpy as np
from ..functional import linear from ..functional import linear
from ..tensor_nn import Parameter from ..tensor import Parameter
from . import init from . import init
from .module import Module from .module import Module
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
...@@ -14,8 +15,8 @@ import numpy as np ...@@ -14,8 +15,8 @@ import numpy as np
from ..core.tensor.dtype import is_quantize from ..core.tensor.dtype import is_quantize
from ..core.tensor.utils import make_shape_tuple from ..core.tensor.utils import make_shape_tuple
from ..logger import get_logger from ..logger import get_logger
from ..tensor import Tensor from ..tensor import Parameter, Tensor
from ..tensor_nn import Buffer, Parameter from ..utils.deprecation import deprecated
from ..utils.hook import HookHandler from ..utils.hook import HookHandler
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -48,7 +49,7 @@ def _is_parameter(obj): ...@@ -48,7 +49,7 @@ def _is_parameter(obj):
def _is_buffer(obj): def _is_buffer(obj):
return isinstance(obj, Buffer) return isinstance(obj, Tensor) and not isinstance(obj, Parameter)
def _is_module(obj): def _is_module(obj):
...@@ -163,49 +164,43 @@ class Module(metaclass=ABCMeta): ...@@ -163,49 +164,43 @@ class Module(metaclass=ABCMeta):
seen=seen, seen=seen,
) )
def parameters( def parameters(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
) -> Iterable[Parameter]:
r"""Returns an iterable for the :class:`~.Parameter` of the module. r"""Returns an iterable for the :class:`~.Parameter` of the module.
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`.Parameter`. ``None`` for no limitation.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this :param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct attributes module, else only returns :class:`~.Parameter` that are direct attributes
of this module. of this module.
""" """
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
warnings.warn("passing requires_grad has no effect currently")
def predicate(obj) -> bool: def predicate(obj) -> bool:
return _is_parameter(obj) and ( return _is_parameter(obj)
requires_grad is None or obj.requires_grad == requires_grad
)
yield from self._flatten( yield from self._flatten(
with_key=False, predicate=predicate, recursive=recursive, **kwargs with_key=False, predicate=predicate, recursive=recursive, **kwargs
) )
def named_parameters( def named_parameters(
self, self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
requires_grad: Optional[bool] = None,
prefix: Optional[str] = None,
recursive: bool = True,
**kwargs
) -> Iterable[Tuple[str, Parameter]]: ) -> Iterable[Tuple[str, Parameter]]:
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Parameter` . ``key`` is the dotted path from this module to the :class:`~.Parameter` .
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
:param prefix: The prefix prepended to the keys. :param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this :param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct attributes module, else only returns :class:`~.Parameter` that are direct attributes
of this module. of this module.
""" """
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
warnings.warn("passing requires_grad has no effect currently")
def predicate(obj) -> bool: def predicate(obj) -> bool:
return _is_parameter(obj) and ( return _is_parameter(obj)
requires_grad is None or obj.requires_grad == requires_grad
)
yield from self._flatten( yield from self._flatten(
with_key=True, with_key=True,
...@@ -215,11 +210,13 @@ class Module(metaclass=ABCMeta): ...@@ -215,11 +210,13 @@ class Module(metaclass=ABCMeta):
**kwargs, **kwargs,
) )
def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]:
"""Returns an iterable for the :class:`~.Buffer` of the module. """Returns an iterable for the buffers of the module.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.
module, else only returns :class:`~.Buffer` that are direct attributes
:param recursive: If ``True``, returns all buffers within this
module, else only returns buffers that are direct attributes
of this module. of this module.
""" """
yield from self._flatten( yield from self._flatten(
...@@ -228,13 +225,15 @@ class Module(metaclass=ABCMeta): ...@@ -228,13 +225,15 @@ class Module(metaclass=ABCMeta):
def named_buffers( def named_buffers(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]: ) -> Iterable[Tuple[str, Tensor]]:
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where """Returns an iterable for key buffer pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Buffer` . ``key`` is the dotted path from this module to the buffer.
Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.
:param prefix: The prefix prepended to the keys. :param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this :param recursive: If ``True``, returns all buffers within this
module, else only returns :class:`~.Buffer` that are direct attributes module, else only returns buffers that are direct attributes
of this module. of this module.
""" """
yield from self._flatten( yield from self._flatten(
...@@ -297,6 +296,7 @@ class Module(metaclass=ABCMeta): ...@@ -297,6 +296,7 @@ class Module(metaclass=ABCMeta):
for it in self.modules(): for it in self.modules():
fn(it) fn(it)
@deprecated(version="1.0")
def zero_grad(self) -> None: def zero_grad(self) -> None:
"""Set all parameters' grads to zero """Set all parameters' grads to zero
""" """
...@@ -505,7 +505,7 @@ class Module(metaclass=ABCMeta): ...@@ -505,7 +505,7 @@ class Module(metaclass=ABCMeta):
# scale/zero_points maybe invalid, use pretrained dtype instead. # scale/zero_points maybe invalid, use pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var = var.astype(to_be_load.dtype) var = var.astype(to_be_load.dtype)
var.set_value(to_be_load) var._reset(to_be_load)
loaded.append(k) loaded.append(k)
return set(loaded), set(skipped) return set(loaded), set(skipped)
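With the `Buffer` class removed, a buffer is now simply any `Tensor` attribute of a `Module` that is not a `Parameter` (see the new `_is_buffer` above). A small sketch with placeholder names and shapes:

```python
import numpy as np
from megengine import Parameter, Tensor
from megengine.module import Module


class M(Module):
    def __init__(self):
        super().__init__()
        self.w = Parameter(np.ones((3,), dtype=np.float32))   # picked up by parameters()
        self.stat = Tensor(np.zeros((3,), dtype=np.float32))  # picked up by buffers()

    def forward(self, x):
        return x * self.w + self.stat


m = M()
assert len(list(m.parameters())) == 1
assert len(list(m.buffers())) == 1
```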
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Callable, Iterable, Optional, Tuple
import numpy as np
from ..tensor_nn import Parameter, Tensor
from .module import Module
class ParamPack(Module):
r"""Pack module's parameters by gathering their memory to continuous address.
Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True),
parameters with same key will be packed togather.
It helps a lot for multimachine training by speeding up allreduce gradients.
:param model: the module you want to pack parameters.
:param nr_ignore_first: how many parameters will be unpacked at first.
:param max_size_per_group: upper bound of packed parameters' size in MB.
:param max_nr_params_per_group: upper bound of the number of parameters of each group.
"""
def __init__(
self,
model: Module,
nr_ignore_first: int = 8,
max_size_per_group: int = 10,
max_nr_params_per_group: int = 100,
group_func: Callable = lambda name, param: 0,
):
super().__init__()
self._model = model
self._nr_ignore_first = nr_ignore_first
self._max_size_per_group = max_size_per_group
self._max_nr_params_per_group = max_nr_params_per_group
self._group_func = group_func
self._grouped_params = []
self._packed_params = []
params = model.named_parameters()
self._pack_params(params)
def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
for param in self._packed_params:
if requires_grad is None or param.requires_grad == requires_grad:
yield param
def named_parameters(
self, requires_grad: Optional[bool] = None
) -> Iterable[Tuple[str, Parameter]]:
for idx, param in enumerate(self._packed_params):
if requires_grad is None or param.requires_grad == requires_grad:
yield "packed_param_" + str(idx), param
def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
groups = collections.defaultdict(list)
ignored = 0
param_id = 0
for name, param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
param.pack_group_key = self._group_func(name, param)
self._packed_params.append(param)
else:
key = (
param.dtype,
param.device,
param.requires_grad,
self._group_func(name, param),
)
groups[key].append({"tensor": param, "id": param_id})
param_id += 1
for (dtype, device, requires_grad, group_key) in groups.keys():
dtype_sz = np.dtype(dtype).itemsize
align = device.mem_align
if align < dtype_sz:
align = 1
else:
assert align % dtype_sz == 0
align //= dtype_sz
group = groups[(dtype, device, requires_grad, group_key)]
while group:
aligned_pos = []
offset = 0
params = []
idx = 0
while idx < len(group):
param = group[idx]
assert param["tensor"].device == device
padding = (align - (offset & (align - 1))) & (align - 1)
offset += padding
aligned_pos.append(offset)
params.append(param)
offset += int(np.prod(param["tensor"].shape))
idx += 1
if (
offset * dtype_sz >= self._max_size_per_group * 1024 * 1024
or idx >= self._max_nr_params_per_group
):
break
group = group[idx:]
if idx == 1:
# ignore param packs with only one item
params[0]["tensor"].pack_group_key = group_key
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
)
continue
packed_value = np.zeros((offset,), dtype=dtype)
for param, pos in zip(params, aligned_pos):
val = param["tensor"].numpy()
packed_value[pos : pos + val.size] = val.flatten()
new_param = Parameter(
value=packed_value,
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
new_param.pack_group_key = group_key
self._packed_params.append(new_param)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
)
def forward(self, *args, **kwargs):
replace_param = dict()
for i in range(len(self._packed_params)):
packed_param = self._packed_params[i]
grouped_params = self._grouped_params[i]
if len(grouped_params) == 1:
continue
split = param_pack_split(
packed_param._symvar, [i["shape"] for i in grouped_params]
)
split = [
Parameter(Tensor(i, requires_grad=packed_param.requires_grad))
for i in split
]
for j in range(len(split)):
replace_param[grouped_params[j]["id"]] = split[j]
self._model.replace_param(replace_param, 0)
return self._model.forward(*args, **kwargs)
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
from ... import module as Float from ... import module as Float
from ...core.tensor import dtype from ...core.tensor import dtype
from ...functional import conv_bias_activation from ...functional import conv_bias_activation
from ...tensor_nn import Parameter from ...tensor import Parameter
from ..qat import conv as QAT from ..qat import conv as QAT
from .module import QuantizedModule from .module import QuantizedModule
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...tensor_nn import Parameter from ...tensor import Parameter
from ..qat import conv_bn as QAT from ..qat import conv_bn as QAT
from .conv import Conv2d from .conv import Conv2d
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from ... import functional as F from ... import functional as F
from ...core.tensor import dtype from ...core.tensor import dtype
from ...tensor_nn import Parameter from ...tensor import Parameter
from ..qat import linear as QAT from ..qat import linear as QAT
from .module import QuantizedModule from .module import QuantizedModule
......
...@@ -11,7 +11,7 @@ from typing import Iterable, Union ...@@ -11,7 +11,7 @@ from typing import Iterable, Union
import numpy as np import numpy as np
from ..functional import sqrt from ..functional import sqrt
from ..tensor_nn import Parameter from ..tensor import Parameter
from .optimizer import Optimizer from .optimizer import Optimizer
...@@ -63,7 +63,7 @@ class Adadelta(Optimizer): ...@@ -63,7 +63,7 @@ class Adadelta(Optimizer):
for param in param_group["params"]: for param in param_group["params"]:
if not param.requires_grad or "grad" not in param.__dict__: if param.grad is None:
continue continue
states = self._state[param] states = self._state[param]
......
...@@ -11,7 +11,7 @@ from typing import Iterable, Union ...@@ -11,7 +11,7 @@ from typing import Iterable, Union
import numpy as np import numpy as np
from ..functional import sqrt from ..functional import sqrt
from ..tensor_nn import Parameter from ..tensor import Parameter
from .optimizer import Optimizer from .optimizer import Optimizer
...@@ -62,7 +62,7 @@ class Adagrad(Optimizer): ...@@ -62,7 +62,7 @@ class Adagrad(Optimizer):
for param in param_group["params"]: for param in param_group["params"]:
if not param.requires_grad or "grad" not in param.__dict__: if param.grad is None:
continue continue
states = self._state[param] states = self._state[param]
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union
from ..tensor_nn import Parameter from ..tensor import Parameter
from .optimizer import Optimizer from .optimizer import Optimizer
...@@ -59,7 +59,7 @@ class Adam(Optimizer): ...@@ -59,7 +59,7 @@ class Adam(Optimizer):
for param in param_group["params"]: for param in param_group["params"]:
if not param.requires_grad or "grad" not in param.__dict__: if param.grad is None:
continue continue
grad = param.grad grad = param.grad
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import Iterable from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict from typing import Dict
from typing import Iterable as Iter from typing import Iterable as Iter
...@@ -15,8 +15,7 @@ from typing import Union ...@@ -15,8 +15,7 @@ from typing import Union
import numpy as np import numpy as np
from ..tensor import Tensor, TensorDict from ..tensor import Parameter, Tensor
from ..tensor_nn import Buffer, Parameter
class _RequiredParameter: class _RequiredParameter:
...@@ -37,7 +36,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -37,7 +36,7 @@ class Optimizer(metaclass=ABCMeta):
def __init__( # pylint: disable=too-many-branches def __init__( # pylint: disable=too-many-branches
self, params: Union[Iter[Parameter], dict], defaults: dict, self, params: Union[Iter[Parameter], dict], defaults: dict,
): ):
self._state = TensorDict() self._state = dict()
self._defaults = defaults self._defaults = defaults
if isinstance(params, (Parameter, dict)): if isinstance(params, (Parameter, dict)):
...@@ -93,10 +92,6 @@ class Optimizer(metaclass=ABCMeta): ...@@ -93,10 +92,6 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is " "optimizer can only optimize Parameters, but one of the params is "
+ type(param) + type(param)
) )
if not param.requires_grad:
raise ValueError(
"optimizer can only optimize Parameters with requires_grad=True"
)
for name, default in self._defaults.items(): for name, default in self._defaults.items():
if default is required and name not in param_group: if default is required and name not in param_group:
...@@ -122,7 +117,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -122,7 +117,7 @@ class Optimizer(metaclass=ABCMeta):
initializer = np.zeros(param.shape, dtype=np.float32) initializer = np.zeros(param.shape, dtype=np.float32)
state_dict = self._state.setdefault(param, {}) state_dict = self._state.setdefault(param, {})
assert state_name not in state_dict assert state_name not in state_dict
state = Buffer(initializer) state = Tensor(initializer)
state_dict[state_name] = state state_dict[state_name] = state
@abstractmethod @abstractmethod
...@@ -140,7 +135,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -140,7 +135,7 @@ class Optimizer(metaclass=ABCMeta):
params.append(param) params.append(param)
return params return params
def step(self, clear_grad=False): def step(self):
r"""Performs a single optimization step. r"""Performs a single optimization step.
""" """
...@@ -152,8 +147,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -152,8 +147,7 @@ class Optimizer(metaclass=ABCMeta):
"Please use a list instead." "Please use a list instead."
) )
self._updates(group) self._updates(group)
if clear_grad: return self
self.clear_grad()
def clear_grad(self): def clear_grad(self):
r"""Clear the grad buffer. r"""Clear the grad buffer.
...@@ -161,8 +155,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -161,8 +155,7 @@ class Optimizer(metaclass=ABCMeta):
""" """
for param_group in self.param_groups: for param_group in self.param_groups:
for param in param_group["params"]: for param in param_group["params"]:
if getattr(param, "grad", None) is not None: param.grad = None
param.grad = None
def state_dict(self) -> Dict: def state_dict(self) -> Dict:
r"""Export the optimizer state. r"""Export the optimizer state.
...@@ -171,7 +164,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -171,7 +164,7 @@ class Optimizer(metaclass=ABCMeta):
""" """
param_groups = [] param_groups = []
state = dict() state = dict()
param2id = TensorDict() param2id = dict()
cur_id = 0 cur_id = 0
for group in self.param_groups: for group in self.param_groups:
...@@ -213,8 +206,9 @@ class Optimizer(metaclass=ABCMeta): ...@@ -213,8 +206,9 @@ class Optimizer(metaclass=ABCMeta):
p = param_new p = param_new
self._state[p] = state["state"][param_saved].copy() self._state[p] = state["state"][param_saved].copy()
for k, v in self._state[p].items(): for k, v in self._state[p].items():
if isinstance(v, Buffer): if isinstance(v, Tensor):
self._state[p][k] = Buffer(v.numpy()) # TODO: maybe a more efficient way?
self._state[p][k] = Tensor(v.numpy())
if set(group_new.keys()) != set(group_saved.keys()): if set(group_new.keys()) != set(group_saved.keys()):
raise ValueError( raise ValueError(
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union from typing import Iterable, Union
from ..tensor_nn import Parameter from ..tensor import Parameter
from .optimizer import Optimizer from .optimizer import Optimizer
...@@ -52,7 +52,7 @@ class SGD(Optimizer): ...@@ -52,7 +52,7 @@ class SGD(Optimizer):
momentum = param_group["momentum"] momentum = param_group["momentum"]
for param in param_group["params"]: for param in param_group["params"]:
if not param.requires_grad or "grad" not in param.__dict__: if param.grad is None:
continue continue
grad = param.grad grad = param.grad
......
...@@ -14,8 +14,7 @@ from .. import functional as F ...@@ -14,8 +14,7 @@ from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.function import Function from ..core.tensor.function import Function
from ..module import Module from ..module import Module
from ..tensor import Tensor from ..tensor import Parameter, Tensor
from ..tensor_nn import Parameter
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict from .utils import QuantMode, fake_quant_tensor, get_qparam_dict
......
...@@ -13,7 +13,7 @@ import numpy as np ...@@ -13,7 +13,7 @@ import numpy as np
from .. import functional as F from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..module import Module from ..module import Module
from ..tensor_nn import Buffer from ..tensor import Tensor
from .utils import QuantMode, Round, get_qparam_dict from .utils import QuantMode, Round, get_qparam_dict
...@@ -82,8 +82,8 @@ class MinMaxObserver(Observer): ...@@ -82,8 +82,8 @@ class MinMaxObserver(Observer):
): ):
super().__init__(dtype, narrow_range) super().__init__(dtype, narrow_range)
self.mode = mode self.mode = mode
self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
self.scale_limit = eps self.scale_limit = eps
def _calculate_qparams(self, inp_min_val, inp_max_val): def _calculate_qparams(self, inp_min_val, inp_max_val):
...@@ -118,8 +118,8 @@ class MinMaxObserver(Observer): ...@@ -118,8 +118,8 @@ class MinMaxObserver(Observer):
# stop gradient # stop gradient
x = x_orig.detach() x = x_orig.detach()
# find max and min # find max and min
self.min_val.set_value(F.minimum(self.min_val, x.min())) self.min_val._reset(F.minimum(self.min_val, x.min()))
self.max_val.set_value(F.maximum(self.max_val, x.max())) self.max_val._reset(F.maximum(self.max_val, x.max()))
return x_orig return x_orig
...@@ -133,22 +133,22 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -133,22 +133,22 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
narrow_range: bool = False, narrow_range: bool = False,
): ):
super().__init__(mode, eps, dtype, narrow_range) super().__init__(mode, eps, dtype, narrow_range)
self.momentum = Buffer(momentum) self.momentum = Tensor(momentum)
self.runtime_momentum = Buffer(0.0) self.runtime_momentum = Tensor(0.0)
def set_momentum(self, momentum): def set_momentum(self, momentum):
self.momentum.set_value(momentum) self.momentum._reset(momentum)
def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:
# stop gradient # stop gradient
x = x_orig.detach() x = x_orig.detach()
# Exponential Moving Average # Exponential Moving Average
self.min_val.set_value( self.min_val._reset(
self.min_val * self.runtime_momentum self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min() + (1 - self.runtime_momentum) * x.min()
) )
self.max_val.set_value( self.max_val._reset(
self.max_val * self.runtime_momentum self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max() + (1 - self.runtime_momentum) * x.max()
) )
...@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): ...@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver):
self.bins = bins self.bins = bins
self.upsample_rate = upsample_rate self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = Buffer([-1] + [0.0] * (bins - 1)) self.histogram = Tensor([-1] + [0.0] * (bins - 1))
def _non_linear_param_search(self): def _non_linear_param_search(self):
r"""Non-linear parameter search. r"""Non-linear parameter search.
...@@ -395,9 +395,9 @@ class HistogramObserver(MinMaxObserver): ...@@ -395,9 +395,9 @@ class HistogramObserver(MinMaxObserver):
self.bins, self.bins,
) )
self.histogram.set_value(new_histogram) self.histogram._reset(new_histogram)
self.min_val.set_value(new_min) self.min_val._reset(new_min)
self.max_val.set_value(new_max) self.max_val._reset(new_max)
def forward(self, x_orig): def forward(self, x_orig):
self.sideeffect_forward(x_orig) self.sideeffect_forward(x_orig)
......
...@@ -14,10 +14,11 @@ from .core import Tensor as _Tensor ...@@ -14,10 +14,11 @@ from .core import Tensor as _Tensor
from .core.ops.builtin import Copy from .core.ops.builtin import Copy
from .core.tensor.core import apply from .core.tensor.core import apply
from .device import get_default_device from .device import get_default_device
from .utils.deprecation import deprecated
class Tensor(_Tensor): class Tensor(_Tensor):
requires_grad = False grad = None
dmap_callback = None dmap_callback = None
def __init__(self, data, dtype=None, device=None): def __init__(self, data, dtype=None, device=None):
...@@ -26,15 +27,32 @@ class Tensor(_Tensor): ...@@ -26,15 +27,32 @@ class Tensor(_Tensor):
self.q_dict = {"mode": None, "scale": None, "zero_point": None} self.q_dict = {"mode": None, "scale": None, "zero_point": None}
super().__init__(data, dtype=dtype, device=device) super().__init__(data, dtype=dtype, device=device)
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value): def set_value(self, value):
self._reset(value) self._reset(value)
@deprecated(version="1.0", reason="use *= 0 instead")
def reset_zero(self): def reset_zero(self):
self *= 0 self *= 0
def to(self, cn): def to(self, cn):
return apply(Copy(comp_node=cn), self)[0] return apply(Copy(comp_node=cn), self)[0]
@property
def requires_grad(self):
raise AttributeError("requires_grad is reserved for future use")
@requires_grad.setter
def requires_grad(self, value):
raise AttributeError("requires_grad is reserved for future use")
@requires_grad.deleter
def requires_grad(self):
raise AttributeError("requires_grad is reserved for future use")
def __hash__(self):
return id(self)
def __getstate__(self): def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy r""" __getstate__ will be called for pickle serialization or deep copy
""" """
...@@ -73,53 +91,6 @@ class Tensor(_Tensor): ...@@ -73,53 +91,6 @@ class Tensor(_Tensor):
tensor = Tensor tensor = Tensor
class Dict(collections.MutableMapping): class Parameter(Tensor):
def __init__(self, *args, key=None, **kwargs): r"""A kind of Tensor that is to be considered a module parameter.
self.data = {} """
if key:
self.keyfn = key
for i in args:
self.update(i)
self.update(**kwargs)
@staticmethod
def keyfn(key): # pylint: disable=method-hidden
return key
def __getitem__(self, key):
_, v = self.data[self.keyfn(key)]
return v
def __setitem__(self, key, value):
self.data[self.keyfn(key)] = key, value
def __delitem__(self, key):
del self.data[self.keyfn(key)]
def __iter__(self):
for _, (k, _) in self.data.items():
yield k
def __len__(self):
return len(self.data)
class TensorDict(Dict): # pylint: disable=too-many-ancestors
class keyfn:
def __new__(cls, x: Tensor):
if not isinstance(x, Tensor):
return x
return super().__new__(cls)
def __init__(self, x: Tensor):
self._data = x # do not save id directly to make pickle work
def __hash__(self):
return id(self._data)
def __eq__(self, other):
# pylint: disable=undefined-variable
return isinstance(other, __class__) and id(self._data) == id(other._data)
def __init__(self, *args):
super().__init__(*args)
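A short sketch of the Tensor-level changes in this file: `grad` now defaults to `None`, `set_value`/`reset_zero` are deprecated in favour of `_reset`/`*= 0`, `Parameter` moves here from the removed `tensor_nn`, and `requires_grad` is reserved, so any access through the new property raises. Values and shapes are placeholders.

```python
import numpy as np
from megengine import Parameter, tensor

t = tensor(np.ones((2, 2), dtype=np.float32))
assert t.grad is None                                # new default attribute

p = Parameter(np.zeros((2, 2), dtype=np.float32))    # Parameter now lives in megengine/tensor.py

try:
    t.requires_grad = True                           # reserved for future use since this commit
except AttributeError:
    pass
```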
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import Tensor, tensor
class Buffer(Tensor):
r"""A kind of Tensor with ``requires_grad=False``.
"""
class Parameter(Tensor):
r"""A kind of Tensor that is to be considered a module parameter.
"""
requires_grad = True
from deprecated.sphinx import deprecated
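The new `utils/deprecation.py` helper (its whole body is the single import above) re-exports `deprecated` from the third-party Deprecated package, which is why `deprecated` is added to the requirements below. A sketch of how the commit applies it, assuming the package's usual behaviour of emitting a `DeprecationWarning` at call time:

```python
import warnings

from deprecated.sphinx import deprecated   # what megengine.utils.deprecation re-exports


@deprecated(version="1.0", reason="use the new API instead")
def old_helper():
    return 42


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper()                            # warns on call and documents the deprecation
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```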
...@@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True): ...@@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True):
:type allow_zero: bool :type allow_zero: bool
:param allow_zero: whether to allow zero tuple value""" :param allow_zero: whether to allow zero tuple value"""
if not isinstance(value, collections.Iterable): if not isinstance(value, collections.abc.Iterable):
value = int(value) value = int(value)
value = tuple([value for i in range(n)]) value = tuple([value for i in range(n)])
else: else:
......
...@@ -5,3 +5,4 @@ requests ...@@ -5,3 +5,4 @@ requests
tabulate tabulate
tqdm tqdm
redispy redispy
deprecated
...@@ -38,7 +38,7 @@ class Simple2(Module): ...@@ -38,7 +38,7 @@ class Simple2(Module):
def test_advance_indexing(): def test_advance_indexing():
net = Simple() net = Simple()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
...@@ -48,7 +48,7 @@ def test_advance_indexing(): ...@@ -48,7 +48,7 @@ def test_advance_indexing():
data = tensor(raw_data) data = tensor(raw_data)
mask = tensor(raw_mask) mask = tensor(raw_mask)
answer = 1.0 - raw_data[raw_mask].sum() answer = 1.0 - raw_data[raw_mask].sum()
with gm.record(): with gm:
loss = net(data, mask).sum() loss = net(data, mask).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -58,7 +58,7 @@ def test_advance_indexing(): ...@@ -58,7 +58,7 @@ def test_advance_indexing():
def test_advance_indexing_with_subtensor(): def test_advance_indexing_with_subtensor():
net = Simple2() net = Simple2()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
...@@ -66,7 +66,7 @@ def test_advance_indexing_with_subtensor(): ...@@ -66,7 +66,7 @@ def test_advance_indexing_with_subtensor():
raw_data = np.arange(576).reshape(dshape).astype(np.float32) raw_data = np.arange(576).reshape(dshape).astype(np.float32)
data = tensor(raw_data) data = tensor(raw_data)
answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum() answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum()
with gm.record(): with gm:
loss = net(data).sum() loss = net(data).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
......
...@@ -28,13 +28,13 @@ class Simple(Module): ...@@ -28,13 +28,13 @@ class Simple(Module):
def test_ai(): def test_ai():
net = Simple() net = Simple()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
dshape = (10, 10) dshape = (10, 10)
data = tensor(np.ones(dshape).astype(np.float32)) data = tensor(np.ones(dshape).astype(np.float32))
with gm.record(): with gm:
loss = net(data).sum() loss = net(data).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
......
...@@ -25,12 +25,12 @@ def test_frozen_bn(): ...@@ -25,12 +25,12 @@ def test_frozen_bn():
saved_wt = m.weight.numpy() saved_wt = m.weight.numpy()
saved_bias = m.bias.numpy() saved_bias = m.bias.numpy()
gm = ad.GradManager().register(m.parameters()) gm = ad.GradManager().attach(m.parameters())
optim = optimizer.SGD(m.parameters(), lr=1.0) optim = optimizer.SGD(m.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
data = np.random.random((6, nchannel, 2, 2)).astype("float32") data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with gm.record(): with gm:
loss = m(data).mean() loss = m(data).mean()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -46,12 +46,12 @@ def test_bn_no_track_stat(): ...@@ -46,12 +46,12 @@ def test_bn_no_track_stat():
nchannel = 3 nchannel = 3
m = BatchNorm2d(nchannel, track_running_stats=False) m = BatchNorm2d(nchannel, track_running_stats=False)
gm = ad.GradManager().register(m.parameters()) gm = ad.GradManager().attach(m.parameters())
optim = optimizer.SGD(m.parameters(), lr=1.0) optim = optimizer.SGD(m.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
data = np.random.random((6, nchannel, 2, 2)).astype("float32") data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with gm.record(): with gm:
loss = m(data).sum() loss = m(data).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -68,12 +68,12 @@ def test_bn_no_track_stat2(): ...@@ -68,12 +68,12 @@ def test_bn_no_track_stat2():
saved_mean = m.running_mean.numpy() saved_mean = m.running_mean.numpy()
assert saved_mean is not None assert saved_mean is not None
gm = ad.GradManager().register(m.parameters()) gm = ad.GradManager().attach(m.parameters())
optim = optimizer.SGD(m.parameters(), lr=1.0) optim = optimizer.SGD(m.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
data = np.random.random((6, nchannel, 2, 2)).astype("float32") data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with gm.record(): with gm:
loss = m(data).sum() loss = m(data).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
......
...@@ -74,13 +74,11 @@ class XORNet(Module): ...@@ -74,13 +74,11 @@ class XORNet(Module):
def test_training_converge(): def test_training_converge():
net = XORNet() net = XORNet()
opt = SGD( opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 gm = ad.GradManager().attach(net.parameters())
)
gm = ad.GradManager().register(net.parameters())
def train(data, label): def train(data, label):
with gm.record(): with gm:
pred = net(data) pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label) loss = F.cross_entropy_with_softmax(pred, label)
gm.backward(loss) gm.backward(loss)
......
...@@ -91,7 +91,7 @@ class MnistNet(Module): ...@@ -91,7 +91,7 @@ class MnistNet(Module):
def train(data, label, net, opt, gm): def train(data, label, net, opt, gm):
with gm.record(): with gm:
pred = net(data) pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label) loss = F.cross_entropy_with_softmax(pred, label)
gm.backward(loss) gm.backward(loss)
...@@ -117,7 +117,7 @@ def update_model(model_path): ...@@ -117,7 +117,7 @@ def update_model(model_path):
net.load_state_dict(checkpoint["net_init"]) net.load_state_dict(checkpoint["net_init"])
lr = checkpoint["sgd_lr"] lr = checkpoint["sgd_lr"]
opt = SGD(net.parameters(), lr=lr) opt = SGD(net.parameters(), lr=lr)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = Tensor(checkpoint["data"], dtype=np.float32) data = Tensor(checkpoint["data"], dtype=np.float32)
label = Tensor(checkpoint["label"], dtype=np.int32) label = Tensor(checkpoint["label"], dtype=np.int32)
...@@ -152,7 +152,7 @@ def run_train( ...@@ -152,7 +152,7 @@ def run_train(
net.load_state_dict(checkpoint["net_init"]) net.load_state_dict(checkpoint["net_init"])
lr = checkpoint["sgd_lr"] lr = checkpoint["sgd_lr"]
opt = SGD(net.parameters(), lr=lr) opt = SGD(net.parameters(), lr=lr)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = Tensor(checkpoint["data"], dtype=np.float32) data = Tensor(checkpoint["data"], dtype=np.float32)
label = Tensor(checkpoint["label"], dtype=np.int32) label = Tensor(checkpoint["label"], dtype=np.int32)
......
...@@ -32,11 +32,11 @@ def test_detach(): ...@@ -32,11 +32,11 @@ def test_detach():
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
dshape = (10, 10) dshape = (10, 10)
data = tensor(np.ones(dshape).astype(np.float32)) data = tensor(np.ones(dshape).astype(np.float32))
with gm.record(): with gm:
loss = net(data).sum() loss = net(data).sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
......
...@@ -97,7 +97,7 @@ class MnistNet(Module): ...@@ -97,7 +97,7 @@ class MnistNet(Module):
def train(data, label, net, opt, gm): def train(data, label, net, opt, gm):
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
pred = net(data) pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label) loss = F.cross_entropy_with_softmax(pred, label)
gm.backward(loss) gm.backward(loss)
...@@ -125,8 +125,7 @@ def update_model(model_path): ...@@ -125,8 +125,7 @@ def update_model(model_path):
lr = checkpoint["sgd_lr"] lr = checkpoint["sgd_lr"]
opt = SGD(net.parameters(), lr=lr) opt = SGD(net.parameters(), lr=lr)
gm = ad.GradManager() gm = ad.GradManager().attach(
gm.register(
net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
) )
...@@ -171,8 +170,7 @@ def run_test( ...@@ -171,8 +170,7 @@ def run_test(
lr = checkpoint["sgd_lr"] lr = checkpoint["sgd_lr"]
opt = SGD(net.parameters(), lr=lr) opt = SGD(net.parameters(), lr=lr)
gm = ad.GradManager() gm = ad.GradManager().attach(
gm.register(
net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
) )
......
...@@ -33,10 +33,10 @@ def test_hello_world(): ...@@ -33,10 +33,10 @@ def test_hello_world():
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.clear_grad() optim.clear_grad()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = tensor([2.34]) data = tensor([2.34])
with gm.record(): with gm:
loss = net(data) loss = net(data)
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
......
...@@ -13,7 +13,7 @@ import megengine.functional as F ...@@ -13,7 +13,7 @@ import megengine.functional as F
from megengine import Parameter, optimizer from megengine import Parameter, optimizer
from megengine.jit import trace from megengine.jit import trace
from megengine.module import Linear, Module from megengine.module import Linear, Module
from megengine.tensor import TensorDict, tensor from megengine.tensor import tensor
class MLP(Module): class MLP(Module):
...@@ -44,7 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -44,7 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
net = Simple() net = Simple()
opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
check_func = check_class(net, **test_case) check_func = check_class(net, **test_case)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
step = 0 step = 0
data_shape = (2, 28) data_shape = (2, 28)
...@@ -57,12 +57,12 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -57,12 +57,12 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
data = tensor(np.random.random(data_shape).astype(np.float32)) data = tensor(np.random.random(data_shape).astype(np.float32))
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
pred = net(data) pred = net(data)
loss = pred.sum() loss = pred.sum()
gm.backward(loss) gm.backward(loss)
ori_params = TensorDict() ori_params = {}
for param in net.parameters(): for param in net.parameters():
ori_params[param] = np.copy(param.numpy()) ori_params[param] = np.copy(param.numpy())
opt.step() opt.step()
...@@ -75,7 +75,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -75,7 +75,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def train_func(data, *, opt=None, gm=None): def train_func(data, *, opt=None, gm=None):
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
pred = net(data) pred = net(data)
loss = pred.sum() loss = pred.sum()
gm.backward(loss) gm.backward(loss)
...@@ -84,7 +84,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -84,7 +84,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
# reset net and opt # reset net and opt
net = Simple() net = Simple()
opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
check_func = check_class(net, **test_case) check_func = check_class(net, **test_case)
step = 0 step = 0
for i in range(iter_num): for i in range(iter_num):
...@@ -93,7 +93,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -93,7 +93,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
group["lr"] += 0.01 group["lr"] += 0.01
check_func.lr += 0.01 check_func.lr += 0.01
ori_params = TensorDict() ori_params = {}
for param in net.parameters(): for param in net.parameters():
ori_params[param] = np.copy(param.numpy()) ori_params[param] = np.copy(param.numpy())
...@@ -105,7 +105,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ...@@ -105,7 +105,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
def test_sgd(): def test_sgd():
class CheckValue: class CheckValue:
def __init__(self, net, **kwarg): def __init__(self, net, **kwarg):
self.slots = TensorDict() self.slots = {}
for param in net.parameters(): for param in net.parameters():
self.slots[param] = np.zeros(param.shape).astype(np.float32) self.slots[param] = np.zeros(param.shape).astype(np.float32)
for k, v in kwarg.items(): for k, v in kwarg.items():
...@@ -134,8 +134,8 @@ def test_sgd(): ...@@ -134,8 +134,8 @@ def test_sgd():
def test_adam(): def test_adam():
class CheckValue: class CheckValue:
def __init__(self, net, **kwarg): def __init__(self, net, **kwarg):
self.m_slots = TensorDict() self.m_slots = {}
self.v_slots = TensorDict() self.v_slots = {}
for param in net.parameters(): for param in net.parameters():
self.m_slots[param] = np.zeros(param.shape).astype(np.float32) self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
self.v_slots[param] = np.zeros(param.shape).astype(np.float32) self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
...@@ -175,7 +175,7 @@ def test_adam(): ...@@ -175,7 +175,7 @@ def test_adam():
def test_adagrad(): def test_adagrad():
class CheckValue: class CheckValue:
def __init__(self, net, **kwarg): def __init__(self, net, **kwarg):
self.s_slots = TensorDict() self.s_slots = {}
for param in net.parameters(): for param in net.parameters():
self.s_slots[param] = np.zeros(param.shape).astype(np.float32) self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
for k, v in kwarg.items(): for k, v in kwarg.items():
...@@ -207,8 +207,8 @@ def test_adagrad(): ...@@ -207,8 +207,8 @@ def test_adagrad():
def test_adadelta(): def test_adadelta():
class CheckValue: class CheckValue:
def __init__(self, net, **kwarg): def __init__(self, net, **kwarg):
self.s_slots = TensorDict() self.s_slots = {}
self.a_slots = TensorDict() self.a_slots = {}
for param in net.parameters(): for param in net.parameters():
self.s_slots[param] = np.zeros(param.shape).astype(np.float32) self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
self.a_slots[param] = np.zeros(param.shape).astype(np.float32) self.a_slots[param] = np.zeros(param.shape).astype(np.float32)
......
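The optimizer tests above drop `TensorDict` in favour of an ordinary Python dict keyed directly by the `Parameter` objects. A small sketch of the bookkeeping pattern used there (the module is an illustrative stand-in; this assumes parameters hash by identity, as the tests rely on):

```python
import numpy as np
import megengine.module as M

net = M.Linear(4, 2)          # stand-in for the Simple net in the tests

# Snapshot every parameter before the optimizer step.
ori_params = {}
for param in net.parameters():
    ori_params[param] = np.copy(param.numpy())

# ...after opt.step(), each parameter can be compared with its snapshot:
for param in net.parameters():
    delta = param.numpy() - ori_params[param]
```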
...@@ -23,11 +23,11 @@ def test_save_load(): ...@@ -23,11 +23,11 @@ def test_save_load():
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
optim.clear_grad() optim.clear_grad()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = tensor([2.34]) data = tensor([2.34])
with gm.record(): with gm:
loss = net(data) loss = net(data)
gm.backward(loss) gm.backward(loss)
...@@ -55,7 +55,7 @@ def test_save_load(): ...@@ -55,7 +55,7 @@ def test_save_load():
optim.load_state_dict(checkpoint["opt_state"]) optim.load_state_dict(checkpoint["opt_state"])
print("load done") print("load done")
with gm.record(): with gm:
loss = net([1.23]) loss = net([1.23])
gm.backward(loss) gm.backward(loss)
......
...@@ -31,12 +31,12 @@ def test_sgd_momentum(): ...@@ -31,12 +31,12 @@ def test_sgd_momentum():
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
optim.clear_grad() optim.clear_grad()
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = tensor([2.34]) data = tensor([2.34])
# do a step of train # do a step of train
with gm.record(): with gm:
loss = net(data) loss = net(data)
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -51,7 +51,7 @@ def test_sgd_momentum(): ...@@ -51,7 +51,7 @@ def test_sgd_momentum():
# do a step of train # do a step of train
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss = net(data) loss = net(data)
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -69,7 +69,7 @@ def test_sgd_momentum_trace(): ...@@ -69,7 +69,7 @@ def test_sgd_momentum_trace():
@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def train_func(data, *, model=None, optim=None, gm=None): def train_func(data, *, model=None, optim=None, gm=None):
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss = net(data) loss = net(data)
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -82,7 +82,7 @@ def test_sgd_momentum_trace(): ...@@ -82,7 +82,7 @@ def test_sgd_momentum_trace():
net = Simple() net = Simple()
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
data = tensor([2.34]) data = tensor([2.34])
train_func(data, model=net, optim=optim, gm=gm) train_func(data, model=net, optim=optim, gm=gm)
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
......
...@@ -61,15 +61,15 @@ class XORNet(M.Module): ...@@ -61,15 +61,15 @@ class XORNet(M.Module):
def test_xornet_trace_dump(): def test_xornet_trace_dump():
net = XORNet() net = XORNet()
opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9) opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
gm = GradManager().register(net.parameters(requires_grad=True)) gm = GradManager().attach(net.parameters())
batch_size = 64 batch_size = 64
train_dataset = minibatch_generator(batch_size) train_dataset = minibatch_generator(batch_size)
val_dataset = minibatch_generator(batch_size) val_dataset = minibatch_generator(batch_size)
@trace @trace
def train_fun(data, label): def train_fun(data, label):
with gm.record(): with gm:
net.train() net.train()
pred = net(data) pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label) loss = F.cross_entropy_with_softmax(pred, label)
......
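In the traced XOR example, `parameters()` no longer takes a `requires_grad` flag: the optimizer and the GradManager both receive the full parameter list, and the traced step records through the manager as a context. A condensed sketch of that setup (the dataset plumbing and the real XORNet are omitted; the stand-in module is illustrative):

```python
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.jit import trace

net = M.Linear(2, 2)                                       # stand-in for XORNet
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)   # no requires_grad flag
gm = ad.GradManager().attach(net.parameters())

@trace
def train_fun(data, label):
    opt.clear_grad()
    with gm:
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        gm.backward(loss)
    opt.step()
    return pred, loss
```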
...@@ -14,7 +14,7 @@ import pytest ...@@ -14,7 +14,7 @@ import pytest
import megengine.core.ops.builtin as builtin import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype import megengine.core.tensor.dtype as dtype
import megengine.functional as F import megengine.functional as F
from megengine import Buffer, Parameter, is_cuda_available, tensor from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_tensor_shape
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
...@@ -330,7 +330,7 @@ def test_roi_pooling(): ...@@ -330,7 +330,7 @@ def test_roi_pooling():
def test_add_update(): def test_add_update():
shape = (2, 3) shape = (2, 3)
v = np.random.random(shape).astype(np.float32) v = np.random.random(shape).astype(np.float32)
b = Buffer(v) b = Tensor(v)
u = F.add_update(b, 1) u = F.add_update(b, 1)
assertTensorClose(u.numpy(), v + 1) assertTensorClose(u.numpy(), v + 1)
...@@ -347,7 +347,7 @@ def test_add_update(): ...@@ -347,7 +347,7 @@ def test_add_update():
def test_add_update_params(): def test_add_update_params():
b = np.random.random((2, 3)).astype(np.float32) b = np.random.random((2, 3)).astype(np.float32)
y = Buffer(b) y = Tensor(b)
# @jit.trace # @jit.trace
def f(x): def f(x):
...@@ -355,7 +355,7 @@ def test_add_update_params(): ...@@ -355,7 +355,7 @@ def test_add_update_params():
f(np.zeros((2, 3)).astype(np.float32)) f(np.zeros((2, 3)).astype(np.float32))
z = Buffer(np.zeros((2, 3)).astype(np.float32)) z = Tensor(np.zeros((2, 3)).astype(np.float32))
F.add_update(y, z, beta=0.1) F.add_update(y, z, beta=0.1)
res = f(np.ones((2, 3)).astype(np.float32)) res = f(np.ones((2, 3)).astype(np.float32))
......
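With `Buffer` removed, non-parameter state in these tests is held in plain `Tensor` objects and updated in place through `F.add_update`, exactly as the hunks above exercise it. A minimal sketch, assuming `add_update` keeps the `(dest, delta, *, beta=...)` call used here:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

v = np.random.random((2, 3)).astype(np.float32)
b = Tensor(v)                      # a plain Tensor now plays the Buffer role

u = F.add_update(b, 1)             # b <- b + 1, also returned as u
z = Tensor(np.zeros((2, 3), dtype=np.float32))
w = F.add_update(b, z, beta=0.1)   # b <- b + 0.1 * z
```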
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
import pytest import pytest
import megengine.functional as F import megengine.functional as F
from megengine import Buffer, Parameter, is_cuda_available, tensor from megengine import tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_tensor_shape
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
......
...@@ -14,10 +14,9 @@ import pytest ...@@ -14,10 +14,9 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine import tensor from megengine import Tensor
from megengine.core._trace_option import use_tensor_shape from megengine.core._trace_option import use_tensor_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from megengine.tensor import Tensor
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -45,10 +44,8 @@ def test_syncbn(): ...@@ -45,10 +44,8 @@ def test_syncbn():
return return
dist.init_process_group("localhost", port, nr_ranks, rank, rank) dist.init_process_group("localhost", port, nr_ranks, rank, rank)
bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps) bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
data_tensor = tensor([])
for i in range(steps): for i in range(steps):
data_tensor.set_value(data[i]) yv = bn(Tensor(data[i]))
yv = bn(data_tensor)
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
...@@ -105,7 +102,6 @@ def test_batchnorm(): ...@@ -105,7 +102,6 @@ def test_batchnorm():
bn = BatchNorm1d(nr_chan, momentum=momentum) bn = BatchNorm1d(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1), dtype=np.float32) running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
data = tensor([])
for i in range(3): for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
...@@ -120,8 +116,7 @@ def test_batchnorm(): ...@@ -120,8 +116,7 @@ def test_batchnorm():
running_mean = running_mean * momentum + mean * (1 - momentum) running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum) running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -137,7 +132,7 @@ def test_batchnorm(): ...@@ -137,7 +132,7 @@ def test_batchnorm():
var_backup = bn.running_var.numpy() var_backup = bn.running_var.numpy()
bn.training = False bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv) data = Tensor(xv)
yv1 = bn(data) yv1 = bn(data)
yv2 = bn(data) yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
...@@ -161,7 +156,6 @@ def test_syncbn1d(): ...@@ -161,7 +156,6 @@ def test_syncbn1d():
bn = SyncBatchNorm(nr_chan, momentum=momentum) bn = SyncBatchNorm(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1), dtype=np.float32) running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
data = tensor([])
for i in range(3): for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
...@@ -176,8 +170,7 @@ def test_syncbn1d(): ...@@ -176,8 +170,7 @@ def test_syncbn1d():
running_mean = running_mean * momentum + mean * (1 - momentum) running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum) running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -193,7 +186,7 @@ def test_syncbn1d(): ...@@ -193,7 +186,7 @@ def test_syncbn1d():
var_backup = bn.running_var.numpy() var_backup = bn.running_var.numpy()
bn.training = False bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv) data = Tensor(xv)
yv1 = bn(data) yv1 = bn(data)
yv2 = bn(data) yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
...@@ -210,7 +203,6 @@ def test_batchnorm2d(): ...@@ -210,7 +203,6 @@ def test_batchnorm2d():
bn = BatchNorm2d(nr_chan, momentum=momentum) bn = BatchNorm2d(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
data = tensor([])
for i in range(3): for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
...@@ -226,8 +218,7 @@ def test_batchnorm2d(): ...@@ -226,8 +218,7 @@ def test_batchnorm2d():
running_mean = running_mean * momentum + mean * (1 - momentum) running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum) running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -239,7 +230,7 @@ def test_batchnorm2d(): ...@@ -239,7 +230,7 @@ def test_batchnorm2d():
var_backup = bn.running_var.numpy() var_backup = bn.running_var.numpy()
bn.training = False bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv) data = Tensor(xv)
yv1 = bn(data) yv1 = bn(data)
yv2 = bn(data) yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
...@@ -263,7 +254,6 @@ def test_syncbn2d(): ...@@ -263,7 +254,6 @@ def test_syncbn2d():
bn = SyncBatchNorm(nr_chan, momentum=momentum) bn = SyncBatchNorm(nr_chan, momentum=momentum)
running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
data = tensor([])
for i in range(3): for i in range(3):
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
...@@ -279,8 +269,7 @@ def test_syncbn2d(): ...@@ -279,8 +269,7 @@ def test_syncbn2d():
running_mean = running_mean * momentum + mean * (1 - momentum) running_mean = running_mean * momentum + mean * (1 - momentum)
running_var = running_var * momentum + var_unbiased * (1 - momentum) running_var = running_var * momentum + var_unbiased * (1 - momentum)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -292,7 +281,7 @@ def test_syncbn2d(): ...@@ -292,7 +281,7 @@ def test_syncbn2d():
var_backup = bn.running_var.numpy() var_backup = bn.running_var.numpy()
bn.training = False bn.training = False
xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
data.set_value(xv) data = Tensor(xv)
yv1 = bn(data) yv1 = bn(data)
yv2 = bn(data) yv2 = bn(data)
assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
...@@ -306,7 +295,6 @@ def test_batchnorm_no_stats(): ...@@ -306,7 +295,6 @@ def test_batchnorm_no_stats():
nr_chan = 8 nr_chan = 8
data_shape = (3, nr_chan, 4) data_shape = (3, nr_chan, 4)
bn = BatchNorm1d(8, track_running_stats=False) bn = BatchNorm1d(8, track_running_stats=False)
data = tensor([])
for i in range(4): for i in range(4):
if i == 2: if i == 2:
bn.training = False bn.training = False
...@@ -320,8 +308,7 @@ def test_batchnorm_no_stats(): ...@@ -320,8 +308,7 @@ def test_batchnorm_no_stats():
).reshape((1, nr_chan, 1)) ).reshape((1, nr_chan, 1))
sd = np.sqrt(var + bn.eps) sd = np.sqrt(var + bn.eps)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -338,7 +325,6 @@ def test_syncbn_no_stats(): ...@@ -338,7 +325,6 @@ def test_syncbn_no_stats():
nr_chan = 8 nr_chan = 8
data_shape = (3, nr_chan, 4) data_shape = (3, nr_chan, 4)
bn = SyncBatchNorm(8, track_running_stats=False) bn = SyncBatchNorm(8, track_running_stats=False)
data = tensor([])
for i in range(4): for i in range(4):
if i == 2: if i == 2:
bn.training = False bn.training = False
...@@ -352,8 +338,7 @@ def test_syncbn_no_stats(): ...@@ -352,8 +338,7 @@ def test_syncbn_no_stats():
).reshape((1, nr_chan, 1)) ).reshape((1, nr_chan, 1))
sd = np.sqrt(var + bn.eps) sd = np.sqrt(var + bn.eps)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -363,7 +348,6 @@ def test_batchnorm2d_no_stats(): ...@@ -363,7 +348,6 @@ def test_batchnorm2d_no_stats():
nr_chan = 8 nr_chan = 8
data_shape = (3, nr_chan, 16, 16) data_shape = (3, nr_chan, 16, 16)
bn = BatchNorm2d(8, track_running_stats=False) bn = BatchNorm2d(8, track_running_stats=False)
data = tensor([])
for i in range(4): for i in range(4):
if i == 2: if i == 2:
bn.training = False bn.training = False
...@@ -376,8 +360,7 @@ def test_batchnorm2d_no_stats(): ...@@ -376,8 +360,7 @@ def test_batchnorm2d_no_stats():
var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
sd = np.sqrt(var + bn.eps) sd = np.sqrt(var + bn.eps)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
...@@ -394,7 +377,6 @@ def test_syncbn2d_no_stats(): ...@@ -394,7 +377,6 @@ def test_syncbn2d_no_stats():
nr_chan = 8 nr_chan = 8
data_shape = (3, nr_chan, 16, 16) data_shape = (3, nr_chan, 16, 16)
bn = SyncBatchNorm(8, track_running_stats=False) bn = SyncBatchNorm(8, track_running_stats=False)
data = tensor([])
for i in range(4): for i in range(4):
if i == 2: if i == 2:
bn.training = False bn.training = False
...@@ -407,8 +389,7 @@ def test_syncbn2d_no_stats(): ...@@ -407,8 +389,7 @@ def test_syncbn2d_no_stats():
var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
sd = np.sqrt(var + bn.eps) sd = np.sqrt(var + bn.eps)
data.set_value(xv) yv = bn(Tensor(xv))
yv = bn(data)
yv_expect = (xv - mean) / sd yv_expect = (xv - mean) / sd
assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
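Across the batch-norm tests, the preallocated empty tensor plus `set_value` pattern is replaced by constructing a new `Tensor` from each numpy batch. A small sketch of the new per-iteration flow (shapes and module are illustrative):

```python
import numpy as np
from megengine import Tensor
from megengine.module import BatchNorm1d

bn = BatchNorm1d(8)
for _ in range(3):
    xv = np.random.normal(loc=2.3, size=(3, 8, 4)).astype(np.float32)
    yv = bn(Tensor(xv))     # wrap the fresh batch directly; no set_value
```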
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
import pytest import pytest
import megengine as mge import megengine as mge
from megengine import tensor from megengine import Tensor
from megengine.module import Module from megengine.module import Module
...@@ -35,12 +35,12 @@ def test_cambricon_module(): ...@@ -35,12 +35,12 @@ def test_cambricon_module():
with open(model, "rb") as f: with open(model, "rb") as f:
data = f.read() data = f.read()
m = MyModule(data) m = MyModule(data)
inputs = [] inp = Tensor(
inputs.append(tensor(data=[], dtype=np.float16, device="cambricon0")) np.random.normal((1, 64, 32, 32)).astype(np.float16), device="cambricon0"
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) )
def inference(inps): def inference(inps):
pred = m(inps) pred = m(inps)
return pred return pred
pred = inference(inputs) pred = inference([inp])
...@@ -16,7 +16,7 @@ import pytest ...@@ -16,7 +16,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine import Buffer, Parameter, Tensor, tensor from megengine import Parameter, Tensor, tensor
from megengine.module import ( from megengine.module import (
BatchNorm1d, BatchNorm1d,
BatchNorm2d, BatchNorm2d,
...@@ -196,7 +196,7 @@ class MyModule(Module): ...@@ -196,7 +196,7 @@ class MyModule(Module):
self.i = self.InnerModule() self.i = self.InnerModule()
self.bn = BatchNorm2d(4) self.bn = BatchNorm2d(4)
self.param = Parameter(np.ones(1, dtype=np.float32)) self.param = Parameter(np.ones(1, dtype=np.float32))
self.buff = Buffer(np.ones(1, dtype=np.float32)) self.buff = Tensor(np.ones(1, dtype=np.float32))
def forward(self, x): def forward(self, x):
x = self.i(x) x = self.i(x)
...@@ -464,8 +464,7 @@ def test_sequential_named_children(): ...@@ -464,8 +464,7 @@ def test_sequential_named_children():
def test_state_dict(): def test_state_dict():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor([]) data = tensor(np.random.random(data_shape))
data.set_value(np.random.random(data_shape))
mlp = MLP() mlp = MLP()
pred0 = mlp(data) pred0 = mlp(data)
...@@ -542,8 +541,7 @@ def test_shared_param(): ...@@ -542,8 +541,7 @@ def test_shared_param():
def test_pickle_module(): def test_pickle_module():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor([]) data = tensor(np.random.random(data_shape))
data.set_value(np.random.random(data_shape))
mlp = MLP() mlp = MLP()
# pickle before forward # pickle before forward
with BytesIO() as fout: with BytesIO() as fout:
...@@ -568,8 +566,7 @@ def test_pickle_module(): ...@@ -568,8 +566,7 @@ def test_pickle_module():
@pytest.mark.skip(reason="under development") @pytest.mark.skip(reason="under development")
def test_dump_model(): def test_dump_model():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor([]) data = Tensor(np.random.random(data_shape))
data.set_value(np.random.random(data_shape))
mlp = MLP() mlp = MLP()
pred = mlp(data) pred = mlp(data)
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
......
...@@ -13,7 +13,7 @@ import pytest ...@@ -13,7 +13,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine import Buffer, Parameter from megengine import Parameter, Tensor
from megengine.module import Conv2d from megengine.module import Conv2d
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -33,7 +33,7 @@ def test_set_value(): ...@@ -33,7 +33,7 @@ def test_set_value():
@pytest.mark.skip(reason="fill unsupported") @pytest.mark.skip(reason="fill unsupported")
def test_fill(): def test_fill():
a = Buffer(np.zeros((2, 3), dtype=np.float32)) a = Tensor(np.zeros((2, 3), dtype=np.float32))
a.fill(3) a.fill(3)
assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32)) assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))
a.fill(124.568) a.fill(124.568)
...@@ -80,7 +80,7 @@ def test_fill(): ...@@ -80,7 +80,7 @@ def test_fill():
# def test_shape_warning(): # def test_shape_warning():
# with Graph() as cg: # with Graph() as cg:
# cg.set_option("eager_evaluation", False) # cg.set_option("eager_evaluation", False)
# b = Buffer(np.ones((2, 3)).astype(np.float32)) # b = Tensor(np.ones((2, 3)).astype(np.float32))
# with pytest.warns(None) as record: # with pytest.warns(None) as record:
# print(b.shape) # print(b.shape)
# if len(record) != 0: # if len(record) != 0:
......
...@@ -42,11 +42,11 @@ def test_single_input(): ...@@ -42,11 +42,11 @@ def test_single_input():
return x return x
net = Simple(av) net = Simple(av)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
opt = optimizer.SGD(net.parameters(), lr=1.0) opt = optimizer.SGD(net.parameters(), lr=1.0)
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
loss = net() loss = net()
gm.backward(loss.sum()) gm.backward(loss.sum())
opt.step() opt.step()
...@@ -81,11 +81,11 @@ def test_multi_input(): ...@@ -81,11 +81,11 @@ def test_multi_input():
return x return x
net = Simple(av, bv) net = Simple(av, bv)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
opt = optimizer.SGD(net.parameters(), lr=1.0) opt = optimizer.SGD(net.parameters(), lr=1.0)
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
loss = net() loss = net()
gm.backward(loss.sum()) gm.backward(loss.sum())
opt.step() opt.step()
...@@ -121,11 +121,11 @@ def test_multi_output(): ...@@ -121,11 +121,11 @@ def test_multi_output():
return x + y return x + y
net = Simple(av, bv) net = Simple(av, bv)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
opt = optimizer.SGD(net.parameters(), lr=1.0) opt = optimizer.SGD(net.parameters(), lr=1.0)
opt.clear_grad() opt.clear_grad()
with gm.record(): with gm:
loss = net() loss = net()
gm.backward(loss.sum()) gm.backward(loss.sum())
opt.step() opt.step()
...@@ -163,9 +163,9 @@ def test_skip_invalid_grad(): ...@@ -163,9 +163,9 @@ def test_skip_invalid_grad():
net = Simple(av, bv) net = Simple(av, bv)
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss = net().sum() loss = net().sum()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -198,10 +198,10 @@ def test_ste(): ...@@ -198,10 +198,10 @@ def test_ste():
av = np.random.random(data_shape).astype(np.float32) av = np.random.random(data_shape).astype(np.float32)
net = Simple(av) net = Simple(av)
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss = net() loss = net()
gm.backward(loss.sum()) gm.backward(loss.sum())
optim.step() optim.step()
...@@ -256,9 +256,9 @@ def test_none_in_out_grad(): ...@@ -256,9 +256,9 @@ def test_none_in_out_grad():
b = tensor(np.array([2.0], dtype=np.float32)) b = tensor(np.array([2.0], dtype=np.float32))
net = Simple(a, b) net = Simple(a, b)
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss, _ = net() loss, _ = net()
gm.backward(loss) gm.backward(loss)
optim.step() optim.step()
...@@ -293,10 +293,10 @@ def test_zero_grad(): ...@@ -293,10 +293,10 @@ def test_zero_grad():
a = tensor(np.array([1.0], dtype=np.float32)) a = tensor(np.array([1.0], dtype=np.float32))
net = Simple(a) net = Simple(a)
optim = optimizer.SGD(net.parameters(), lr=1.0) optim = optimizer.SGD(net.parameters(), lr=1.0)
gm = ad.GradManager().register(net.parameters()) gm = ad.GradManager().attach(net.parameters())
optim.clear_grad() optim.clear_grad()
with gm.record(): with gm:
loss = net() loss = net()
gm.backward(loss.sum()) gm.backward(loss.sum())
optim.step() optim.step()
......
...@@ -38,7 +38,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): ...@@ -38,7 +38,7 @@ def cvt_to_shape_desc(val, inpvar, config=None):
if isinstance(val, RawTensor): if isinstance(val, RawTensor):
return as_tensor(val, device) return as_tensor(val, device)
if not isinstance(val, collections.Iterable): if not isinstance(val, collections.abc.Iterable):
val = [val] val = [val]
components = [] components = []
......
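The shape-descriptor helper now tests against `collections.abc.Iterable`; the bare `collections.Iterable` alias has been deprecated since Python 3.3 and is removed in 3.10. The check reduces to something like this sketch:

```python
import collections.abc

def wrap_scalar(val):
    # Scalars become a one-element list; iterables pass through unchanged.
    if not isinstance(val, collections.abc.Iterable):
        val = [val]
    return val
```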
...@@ -12,19 +12,18 @@ from tempfile import TemporaryFile ...@@ -12,19 +12,18 @@ from tempfile import TemporaryFile
import numpy as np import numpy as np
import megengine as mge import megengine as mge
from megengine import Buffer, Parameter, tensor from megengine import Parameter, Tensor
def test_tensor_serialization(): def test_tensor_serialization():
def tensor_eq(a, b): def tensor_eq(a, b):
assert a.dtype == b.dtype assert a.dtype == b.dtype
assert a.device == b.device assert a.device == b.device
assert a.requires_grad == b.requires_grad
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
with TemporaryFile() as f: with TemporaryFile() as f:
data = np.random.randint(low=0, high=7, size=[233]) data = np.random.randint(low=0, high=7, size=[233])
a = tensor(data, device="xpux", dtype=np.int32) a = Tensor(data, device="xpux", dtype=np.int32)
pickle.dump(a, f) pickle.dump(a, f)
f.seek(0) f.seek(0)
b = pickle.load(f) b = pickle.load(f)
...@@ -39,19 +38,19 @@ def test_tensor_serialization(): ...@@ -39,19 +38,19 @@ def test_tensor_serialization():
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
with TemporaryFile() as f: with TemporaryFile() as f:
a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) a = Tensor(np.random.random(size=(2, 233)).astype(np.float32))
pickle.dump(a, f) pickle.dump(a, f)
f.seek(0) f.seek(0)
b = pickle.load(f) b = pickle.load(f)
assert isinstance(b, Buffer) assert type(b) is Tensor
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
with TemporaryFile() as f: with TemporaryFile() as f:
a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) a = Tensor(np.random.random(size=(2, 233)).astype(np.float32))
mge.save(a, f) mge.save(a, f)
f.seek(0) f.seek(0)
b = mge.load(f, map_location="cpux") b = mge.load(f, map_location="cpux")
assert isinstance(b, Buffer) assert type(b) is Tensor
assert "cpu" in str(b.device) assert "cpu" in str(b.device)
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
...@@ -59,12 +58,12 @@ def test_tensor_serialization(): ...@@ -59,12 +58,12 @@ def test_tensor_serialization():
if mge.is_cuda_available(): if mge.is_cuda_available():
device_org = mge.get_default_device() device_org = mge.get_default_device()
mge.set_default_device("gpu0") mge.set_default_device("gpu0")
a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) a = Tensor(np.random.random(size=(2, 233)).astype(np.float32))
mge.save(a, f) mge.save(a, f)
f.seek(0) f.seek(0)
mge.set_default_device("cpux") mge.set_default_device("cpux")
b = mge.load(f, map_location={"gpu0": "cpu0"}) b = mge.load(f, map_location={"gpu0": "cpu0"})
assert isinstance(b, Buffer) assert type(b) is Tensor
assert "cpu0" in str(b.device) assert "cpu0" in str(b.device)
np.testing.assert_equal(a.numpy(), b.numpy()) np.testing.assert_equal(a.numpy(), b.numpy())
mge.set_default_device(device_org) mge.set_default_device(device_org)
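With `Buffer` gone, tensors round-trip through `pickle` and `mge.save`/`mge.load` as plain `Tensor` objects, and `map_location` remaps the device at load time. A short self-contained sketch (device names follow the ones used in the test above):

```python
from tempfile import TemporaryFile

import numpy as np
import megengine as mge
from megengine import Tensor

a = Tensor(np.random.random((2, 233)).astype(np.float32))

with TemporaryFile() as f:
    mge.save(a, f)
    f.seek(0)
    b = mge.load(f, map_location="cpux")   # force the loaded copy onto CPU

assert type(b) is Tensor
assert "cpu" in str(b.device)
np.testing.assert_equal(a.numpy(), b.numpy())
```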
...@@ -66,7 +66,7 @@ def main(): ...@@ -66,7 +66,7 @@ def main():
mge.set_default_device("cpux") mge.set_default_device("cpux")
net = XORNet() net = XORNet()
opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9) opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
batch_size = 64 batch_size = 64
train_dataset = minibatch_generator(batch_size) train_dataset = minibatch_generator(batch_size)
val_dataset = minibatch_generator(batch_size) val_dataset = minibatch_generator(batch_size)
......