Commit 1bf18252 authored by Megvii Engine Team, committed by huangxinda

feat(mge/amp): add mixed precision autocast support

GitOrigin-RevId: 6fbffc484511854feb1b83e297836cd321811bf3
Parent f12355f7
......@@ -117,6 +117,7 @@ def _atexit(handler):
# subpackages
import megengine.amp
import megengine.autodiff
import megengine.data
import megengine.distributed
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import mprop
from ..core.tensor.amp import *
from .autocast import autocast
mprop.init()
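Editor's note: mprop.init() turns the module-level properties defined in core.tensor.amp into attributes of megengine.amp itself. A minimal usage sketch, assuming MegEngine with this patch installed:

import megengine as mge

mge.amp.enabled = True             # globally enable autocast
print(mge.amp.low_prec_dtype)      # "float16" by default
print(mge.amp.high_prec_dtype)     # "float32" by default
mge.amp.enabled = False            # restore the default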
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
from ..core.tensor import amp
class autocast:
r"""
A class to control autocast mode for amp as a context manager or a decorator.
:param enabled: whether autocast mode is enabled.
:param low_prec_dtype: set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
:param high_prec_dtype: set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::
# used as decorator
@autocast()
def train_step(image, label):
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
# used as context manager
def train_step(image, label):
with autocast():
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
"""
def __init__(
self,
enabled: bool = True,
low_prec_dtype: str = "float16",
high_prec_dtype: str = "float32",
):
self.enabled = enabled
self.high_prec_dtype = high_prec_dtype
self.low_prec_dtype = low_prec_dtype
self._origin_enabled = None
self._origin_high = None
self._origin_low = None
def __enter__(self):
self._origin_enabled, amp._enabled = amp._enabled, self.enabled
self._origin_high = amp._high_prec_dtype
amp._high_prec_dtype = self.high_prec_dtype
self._origin_low = amp._low_prec_dtype
amp._low_prec_dtype = self.low_prec_dtype
def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._high_prec_dtype = self._origin_high
amp._low_prec_dtype = self._origin_low
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapper
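Editor's note: because __enter__ saves the previous global amp state and __exit__ restores it, autocast blocks nest cleanly and never leak settings past the with statement. A sketch under the same assumption that this patch is installed:

from megengine import amp

assert not amp.enabled                  # global default is off
with amp.autocast(low_prec_dtype="float16"):
    assert amp.enabled
    with amp.autocast(enabled=False):   # temporarily disable inside
        assert not amp.enabled
    assert amp.enabled                  # inner __exit__ restored the outer state
assert not amp.enabled                  # outer __exit__ restored the default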
......@@ -49,7 +49,7 @@ class Device:
return self._cn == rhs._cn
def device(obj):
def as_device(obj):
if isinstance(obj, Device):
return obj
return Device(obj)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
_enabled = False
_high_prec_dtype = "float32"
_low_prec_dtype = "float16"
@property
def enabled(mod):
r"""
Get or set whether amp autocast mode is enabled.
Examples:
.. code-block::
import megengine as mge
mge.amp.enabled = True
"""
return _enabled
@enabled.setter
def enabled(mod, enabled: bool):
global _enabled
_enabled = enabled
@property
def high_prec_dtype(mod):
r"""
Get or set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::
import megengine as mge
mge.amp.high_prec_dtype = "float32"
"""
return _high_prec_dtype
@high_prec_dtype.setter
def high_prec_dtype(mod, dtype: str):
global _high_prec_dtype
_high_prec_dtype = dtype
@property
def low_prec_dtype(mod):
r"""
Get or set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and memory. Default: float16.
Examples:
.. code-block::
import megengine as mge
mge.amp.low_prec_dtype = "float16"
"""
return _low_prec_dtype
@low_prec_dtype.setter
def low_prec_dtype(mod, dtype: str):
global _low_prec_dtype
_low_prec_dtype = dtype
......@@ -15,15 +15,20 @@ import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from . import utils
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar
_ElwMod = Elemwise.Mode
from . import amp
from .indexing import getitem, setitem
from .utils import (
_normalize_axis,
astensor1d,
astype,
cast_tensors,
convert_inputs,
isscalar,
make_shape_tuple,
setscalar,
)
_ElwMod = builtin.Elemwise.Mode
def _elwise_apply(args, mode):
......@@ -40,47 +45,59 @@ def _elwise_apply(args, mode):
def _elwise(*args, mode):
args = convert_inputs(*args)
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.CEIL,
_ElwMod.FLOOR,
_ElwMod.ROUND,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.TANH,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
args = tuple(
map(
lambda x: x.astype("float32")
if hasattr(x, "dtype") and x.dtype != np.float32
else x,
args,
)
)
args = utils.convert_inputs(*args)
# autocast to FP32 to maintain precision
# or to avoid ops that do not support int args
args = cast_tensors(*args, promote=True)
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
return _elwise_apply(args, mode)
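Editor's note: the effect of the promote branch is observable from Python. For modes like TRUE_DIV, all-integer inputs are promoted to the high precision dtype (float32 by default) before the op runs. An illustrative check, assuming this patch:

import numpy as np
import megengine as mge

a = mge.tensor(np.array([1, 2, 3], dtype=np.int32))
b = mge.tensor(np.array([2, 2, 2], dtype=np.int32))
out = a / b                       # TRUE_DIV on integer inputs
assert out.dtype == np.float32    # promoted to float32, not integer division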
def _matmul(inp1, inp2):
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
compute_mode = "default"
inp1, inp2 = convert_inputs(inp1, inp2)
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode="default", format="default"
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
)
inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2)
return result
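Editor's note: under autocast, _matmul casts both operands to the low precision dtype and accumulates in float32 via compute_mode, so results come back in float16. An illustrative sketch (assumes a device that supports float16 matmul):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(4, 8), dtype="float32")
y = mge.tensor(np.random.randn(8, 2), dtype="float32")
with mge.amp.autocast():
    out = F.matmul(x, y)          # inputs cast to float16, fp32 accumulation
assert out.dtype == np.float16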
def _transpose(data, axes):
op = builtin.Dimshuffle(axes)
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
(result,) = apply(op, data)
return result
def _broadcast(inp, shape):
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
......@@ -88,7 +105,7 @@ def _broadcast(inp, shape):
def _reshape(x, shape):
unspec_axis = None
try:
shape_tuple = _make_shape_tuple(shape)
shape_tuple = make_shape_tuple(shape)
except ValueError:
pass
else:
......@@ -102,7 +119,7 @@ def _reshape(x, shape):
"multiple -1 in shape: {} & {}".format(unspec_axis, i)
)
unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
shape = astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
else:
......@@ -171,7 +188,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
return list(map(int, axis))
axis = get_axes()
axis = utils._normalize_axis(inp.ndim, axis)
axis = _normalize_axis(inp.ndim, axis)
axis = [a - i for i, a in enumerate(axis)]
op = builtin.RemoveAxis(axis=axis)
......@@ -184,7 +201,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
(data,) = convert_inputs(data)
if mode == "mean":
data = data.astype("float32")
elif self.dtype == np.bool_:
......@@ -196,7 +213,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
elif isinstance(axis, collections.abc.Iterable):
axis = utils._normalize_axis(self.ndim, axis, reverse=True)
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
......@@ -359,11 +376,11 @@ class ArrayMethodMixin(abc.ABC):
yield self[i]
def __getitem__(self, index):
return _getitem(self, index)
return getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
value = setitem(self, index, value)
self._reset(value)
__contains__ = _todo
......@@ -422,7 +439,7 @@ class ArrayMethodMixin(abc.ABC):
Returns a :class:`Tensor` with the same data and number of elements,
but with the specified :attr:`~.Tensor.dtype`.
"""
return utils.astype(self, dtype)
return astype(self, dtype)
def reshape(self, *args):
r"""
......
......@@ -18,7 +18,7 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops.builtin import OpDef
from .core import TensorBase
......
......@@ -13,9 +13,10 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._wrap import device as as_device
from .._wrap import as_device
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
_enable_convert_inputs = True
......@@ -98,6 +99,14 @@ def convert_inputs(*args, device=None):
return tuple(map(convert, args))
def cast_tensors(*args, promote=False):
if promote:
dtype = _high_prec_dtype
else:
dtype = _low_prec_dtype
return tuple(arg.astype(dtype) if arg is not None else None for arg in args)
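Editor's note: a small sketch of the new helper's contract, with the import path taken from this diff. promote=True selects the high precision dtype, otherwise the low precision dtype is used, and None entries (e.g. an absent bias) pass through untouched:

import numpy as np
from megengine import tensor
from megengine.core.tensor.utils import cast_tensors

inp = tensor(np.ones((2, 2)), dtype="float32")
weight = tensor(np.ones((2, 2)), dtype="float32")

low_inp, low_w, none_bias = cast_tensors(inp, weight, None)
assert low_inp.dtype == np.float16 and none_bias is None   # low precision by default

high_inp, high_w = cast_tensors(inp, weight, promote=True)
assert high_inp.dtype == np.float32                        # high precision when promoted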
def result_type(*args):
dtypes = []
for i in args:
......
......@@ -12,10 +12,8 @@ import numpy as np
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype
from ..device import get_default_device
from ..core.tensor.array_method import _elwise
from ..core.tensor.utils import astype, convert_inputs
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
......@@ -69,46 +67,9 @@ __all__ = [
]
def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)
def _elemwise_multi_type(*args, mode, **kwargs):
op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
args = utils.convert_inputs(*args)
args = convert_inputs(*args)
(result,) = apply(op, *args)
return result
......
......@@ -14,7 +14,8 @@ from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
from ..tensor import Tensor
from .debug_param import get_execution_strategy
from .elemwise import clip, exp, log, log1p
......@@ -471,7 +472,7 @@ def argmin(
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
for ai in axis:
......@@ -528,7 +529,7 @@ def argmax(
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
axis = _normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
......@@ -807,8 +808,13 @@ def matmul(
[28. 40.]]
"""
if amp._enabled:
compute_mode = "float32"
inp1, inp2 = cast_tensors(inp1, inp2)
else:
inp1, inp2 = convert_inputs(inp1, inp2)
remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)
dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
......@@ -921,12 +927,12 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
inp1, inp2 = convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
utils.setscalar(result)
setscalar(result)
return result
......
......@@ -15,9 +15,16 @@ from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astensor1d, astype, setscalar
from ..core.tensor.utils import (
astensor1d,
astype,
cast_tensors,
convert_inputs,
convert_single_value,
setscalar,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
......@@ -91,7 +98,9 @@ def expand_hw(x):
return int(h), int(w)
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
def linear(
inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
) -> Tensor:
"""
Applies a linear transformation to the input tensor.
......@@ -102,8 +111,10 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
:param bias: bias with shape `(out_features,)`.
Default: None
"""
ret = matmul(inp, weight, transpose_b=True)
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
if bias is not None:
if amp._enabled:
bias = bias.astype("float16")
ret += bias
return ret
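Editor's note: with autocast on, linear runs the matmul on float16 operands with fp32 accumulation and casts the bias to float16 before the add, so the output dtype is float16. A hedged sketch (the F.nn.linear path is assumed from this file):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(4, 8), dtype="float32")
w = mge.tensor(np.random.randn(2, 8), dtype="float32")
b = mge.tensor(np.zeros(2), dtype="float32")
with mge.amp.autocast():
    y = F.nn.linear(x, w, b)
assert y.dtype == np.float16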
......@@ -153,6 +164,11 @@ def conv1d(
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
......@@ -177,7 +193,6 @@ def conv1d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -228,7 +243,11 @@ def conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -247,7 +266,6 @@ def conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -286,6 +304,7 @@ def conv3d(
:return: output tensor.
"""
assert conv_mode.lower() == "cross_correlation"
inp, weight = convert_inputs(inp, weight)
D, H, W = 0, 1, 2
......@@ -308,7 +327,6 @@ def conv3d(
mode=conv_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -358,7 +376,11 @@ def conv_transpose2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
inp, weight = convert_inputs(inp, weight)
if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
......@@ -375,8 +397,8 @@ def conv_transpose2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
strategy=get_execution_strategy(),
compute_mode=compute_mode,
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
......@@ -428,7 +450,11 @@ def deformable_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if amp._enabled:
compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else:
inp, weight, offset, mask = convert_inputs(inp, weight, offset, mask)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -447,7 +473,6 @@ def deformable_conv2d(
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
(output,) = apply(op, inp, weight, offset, mask)
if bias is not None:
output += bias
......@@ -468,6 +493,7 @@ def local_conv2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
inp, weight = convert_inputs(inp, weight)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -481,10 +507,8 @@ def local_conv2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
mode=conv_mode,
compute_mode="default",
sparse="dense",
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
......@@ -515,8 +539,9 @@ def conv_transpose3d(
:param dilation: dilation of the 3D convolution operation. Default: 1
:return: output tensor.
"""
D, H, W = 0, 1, 2
inp, weight = convert_inputs(inp, weight)
D, H, W = 0, 1, 2
pad = _triple(padding)
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
......@@ -533,7 +558,6 @@ def conv_transpose3d(
dilate_w=dilate[W],
strategy=get_execution_strategy(),
)
weight, inp = utils.convert_inputs(weight, inp)
(output,) = apply(op, weight, inp)
if bias is not None:
output += bias
......@@ -994,7 +1018,8 @@ def batch_norm(
training: bool = False,
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True
inplace: bool = True,
compute_mode="default",
):
r"""
Applies batch normalization to the input.
......@@ -1027,15 +1052,11 @@ def batch_norm(
def make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape)
return result
elif x.ndim == 1:
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Reshape(), x, shape)
return result
return x
......@@ -1052,10 +1073,15 @@ def batch_norm(
if has_var and running_var.ndim != 4:
raise ValueError
inp, weight, bias, running_mean, running_var = utils.convert_inputs(
inp, weight, bias, running_mean, running_var
)
if amp._enabled:
inp = inp.astype("float16")
weight, bias, running_mean, running_var = cast_tensors(
weight, bias, running_mean, running_var, promote=True
)
elif compute_mode != "float32":
inp, weight, bias, running_mean, running_var = convert_inputs(
inp, weight, bias, running_mean, running_var
)
weight = make_full_if_none(weight, 1)
bias = make_full_if_none(bias, 0)
......@@ -1352,7 +1378,7 @@ def indexing_one_hot(
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, dtype="int32", device=src.device)
index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = squeeze(result, axis)
......
......@@ -13,7 +13,7 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
......
......@@ -33,7 +33,7 @@ from ..core._imperative_rt.ops import (
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
......
......@@ -26,6 +26,7 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
compute_mode="default",
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
......@@ -36,6 +37,7 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.compute_mode = compute_mode
if self.freeze:
assert (
self._track_running_stats_saved
......@@ -123,6 +125,7 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
compute_mode=self.compute_mode,
)
if _ndims != 4:
......
......@@ -51,7 +51,12 @@ class Linear(Module):
"""
def __init__(
self, in_features: int, out_features: int, bias: bool = True, **kwargs
self,
in_features: int,
out_features: int,
bias: bool = True,
compute_mode: str = "default",
**kwargs
):
super().__init__(**kwargs)
self.out_features = out_features
......@@ -62,6 +67,7 @@ class Linear(Module):
if bias:
b_shape = (out_features,)
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
self.compute_mode = compute_mode
self.reset_parameters()
def _get_fanin(self):
......@@ -75,7 +81,7 @@ class Linear(Module):
init.zeros_(self.bias)
def _calc_linear(self, x, weight, bias):
return linear(x, weight, bias)
return linear(x, weight, bias, compute_mode=self.compute_mode)
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
......
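Editor's note: usage sketch for the new constructor argument; the value is stored on the module and forwarded to the functional linear on every forward call:

import numpy as np
import megengine as mge
import megengine.module as M

fc = M.Linear(8, 2, compute_mode="float32")              # accumulate the matmul in fp32
out = fc(mge.tensor(np.ones((4, 8), dtype=np.float32)))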
......@@ -5,8 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
from functools import partial
from .. import functional as F
......
......@@ -14,7 +14,7 @@ from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core._wrap import as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import amp
from megengine.core.tensor import amp as origin_amp
def test_grad_scaler():
def check(enabled, low, high):
assert amp.enabled == enabled
assert origin_amp._enabled == enabled
assert amp.low_prec_dtype == low
assert origin_amp._low_prec_dtype == low
assert amp.high_prec_dtype == high
assert origin_amp._high_prec_dtype == high
origin_enabled = amp.enabled
origin_high = amp.high_prec_dtype
origin_low = amp.low_prec_dtype
with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"):
check(True, "low", "high")
check(origin_enabled, origin_low, origin_high)
amp.enabled = True
amp.high_prec_dtype = "high"
amp.low_prec_dtype = "low"
check(True, "low", "high")
amp.enabled = origin_enabled
amp.high_prec_dtype = origin_high
amp.low_prec_dtype = origin_low
check(origin_enabled, origin_low, origin_high)
......@@ -14,6 +14,7 @@ import numpy as np
import pytest
from utils import opr_test
import megengine.amp as amp
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
......@@ -767,6 +768,27 @@ def test_batch_conv_bias():
run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
def test_conv2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
amp.enabled = False
expected = F.conv2d(
inp.astype("float16"),
weight.astype("float16"),
None,
(2, 2),
(3, 3),
(1, 1),
1,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())
def test_conv2d_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]
......@@ -787,8 +809,8 @@ def test_conv3d_zero_stride_numpy_array():
def test_conv1d():
inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
np.testing.assert_equal(
out.numpy(),
......@@ -798,9 +820,31 @@ def test_conv1d():
)
def test_batchnorm2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
amp.enabled = False
expected = F.batch_norm(
inp.astype("float16"),
weight=weight,
bias=bias,
training=True,
inplace=False,
compute_mode="float32",
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
np.testing.assert_allclose(out.numpy(), expected.numpy())
def test_conv3d():
inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
print(out.numpy().shape)
np.testing.assert_equal(
......
......@@ -473,39 +473,6 @@ def test_pickle_module():
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
def test_load_quantized():
from megengine.core.tensor import dtype
data_shape = (2, 28)
data = tensor(np.random.random(data_shape), dtype="float32")
data = data.astype(dtype.qint8(0.1))
mlp = MLP()
quantize_qat(mlp)
quantize(mlp)
mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
mlp.eval()
pred0 = mlp(data)
with BytesIO() as fout:
mge.save(mlp.state_dict(), fout)
fout.seek(0)
checkpoint = mge.load(fout)
# change mlp weight.
mlp.dense0.weight = Parameter(
mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
)
mlp.dense1.weight = Parameter(
mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
)
mlp.load_state_dict(checkpoint)
pred1 = mlp(data)
np.testing.assert_allclose(
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), atol=5e-6
)
def test_repr_basic():
# test whether __repr__ can output correct information
class ConvModel(Module):
......