Commit 205a39b0 authored by Megvii Engine Team

feat(mge): remove add_update

GitOrigin-RevId: 3068593eebb6f83884602651b21c94d623e78079
Parent 67d33303
@@ -410,7 +410,7 @@ class ArrayMethodMixin(abc.ABC):
    def sum(self, axis=None, keepdims: bool = False):
        r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
        If ``axis`` is a list of axes, reduce over all of them.

        If ``keepdims`` is ``True``, the shape of the output tensor is the same as the input tensor, except in the dimension(s) ``axis``, where it is of size 1. Otherwise, ``axis`` is squeezed (see :meth:`~.functional.tensor.remove_axis`).

        Same for prod/mean/max/min.
......
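A minimal sketch of the reduction semantics described in that docstring (illustrative values, assuming a standard MegEngine Tensor):

    import numpy as np
    from megengine import Tensor

    x = Tensor(np.arange(6, dtype="float32").reshape(2, 3))  # [[0., 1., 2.], [3., 4., 5.]]
    x.sum(axis=1)                 # shape (2,): [3., 12.]
    x.sum(axis=1, keepdims=True)  # shape (2, 1): [[3.], [12.]]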
@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from .elemwise import *
-from .graph import add_update
from .loss import *
from .math import *
from .nn import *
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Iterable, Optional, Union
from ..tensor import Tensor
def add_update(
    dest: Tensor,
    delta: Tensor,
    *,
    alpha: Union[Tensor, float, int] = 1.0,
    beta: Union[Tensor, float, int] = 1.0,
    bias: Union[Tensor, float, int] = 0.0
):
    r"""Modify ``dest`` inplace as follows:

    .. math::

        dest = alpha * dest + beta * delta + bias

    :param dest: input data that will be modified inplace.
    :param delta: update value that will be added to ``dest``.
    :param alpha: weight ratio of ``dest``. Default: 1.0
    :param beta: weight ratio of ``delta``. Default: 1.0
    :param bias: bias value appended to the result. Default: 0.0
    """
    if beta is not None and beta != 1.0:
        delta = delta * beta
    if bias is not None and bias != 0.0:
        delta = delta + bias
    if alpha is not None and alpha != 1.0:
        dest *= alpha
    dest += delta
    return dest
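For readers tracking the migration, a minimal sketch of what the removed function computed and its in-place equivalent after this commit (illustrative values, not from the codebase):

    import numpy as np
    from megengine import Tensor

    dest = Tensor(np.ones((2, 2), dtype=np.float32))
    delta = Tensor(np.full((2, 2), 0.5, dtype=np.float32))

    # before: add_update(dest, delta, alpha=0.9, beta=0.1, bias=0.1)
    # i.e. dest = 0.9 * dest + 0.1 * delta + 0.1
    dest *= 0.9
    dest += 0.1 * delta + 0.1  # each entry: 1.0 * 0.9 + 0.5 * 0.1 + 0.1 == 1.05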
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ...functional import add_update, ones, relu, sqrt, sum, zeros
+from ...functional import ones, relu, sqrt, sum, zeros
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule
@@ -76,18 +76,10 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
-        add_update(
-            self.bn.running_mean,
-            delta=bn_mean,
-            alpha=1 - exponential_average_factor,
-            beta=exponential_average_factor,
-        )
-        add_update(
-            self.bn.running_var,
-            delta=bn_var,
-            alpha=1 - exponential_average_factor,
-            beta=exponential_average_factor,
-        )
+        self.bn.running_mean *= self.bn.momentum
+        self.bn.running_mean += exponential_average_factor * bn_mean
+        self.bn.running_var *= self.bn.momentum
+        self.bn.running_var += exponential_average_factor * bn_var

    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
......
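The hunk above swaps `add_update` for two in-place statements per running statistic; with `exponential_average_factor = 1 - momentum`, both forms compute the same exponential moving average. A NumPy sketch of the equivalence (names are illustrative):

    import numpy as np

    momentum = 0.9
    running_mean = np.zeros(3, dtype=np.float32)
    bn_mean = np.ones(3, dtype=np.float32)

    # removed form: add_update(running_mean, delta=bn_mean,
    #                          alpha=momentum, beta=1 - momentum)
    expected = momentum * running_mean + (1 - momentum) * bn_mean

    # new in-place form, as in the hunk
    running_mean *= momentum
    running_mean += (1 - momentum) * bn_mean

    assert np.allclose(running_mean, expected)  # both yield 0.1 everywhere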
@@ -127,7 +127,7 @@ class TQT(_FakeQuantize):
            # when disabled, TQT does a normal forward and initializes the scale weight
            tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
            tmp_scale = F.log(tmp_scale / 127) / math.log(2)
-            F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
+            self.scale[...] = tmp_scale
            return inp

    def get_qparams(self):
......
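With `alpha=0.0, beta=1.0, bias=0.0`, the removed call reduced to `scale = 0 * scale + 1 * tmp_scale + 0`, i.e. a plain overwrite, which the slice assignment now expresses in place. A minimal sketch (illustrative tensors, not the module's real state):

    import numpy as np
    from megengine import Tensor

    scale = Tensor(np.zeros(1, dtype=np.float32))
    tmp_scale = Tensor(np.array([0.5], dtype=np.float32))

    scale[...] = tmp_scale  # in-place overwrite; scale is now [0.5]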
@@ -290,41 +290,6 @@ def test_one_hot():
    onehot_high_dimension()


-def test_add_update():
-    shape = (2, 3)
-    v = np.random.random(shape).astype(np.float32)
-    b = Tensor(v)
-    u = F.add_update(b, 1)
-    np.testing.assert_allclose(u.numpy(), v + 1, atol=1e-6)
-    u = F.add_update(b, 1)
-    np.testing.assert_allclose(u.numpy(), v + 2, atol=1e-6)
-    x = np.ones((2, 2), dtype=np.float32)
-    y = x * 0.5
-    dest = tensor(x)
-    delta = tensor(y)
-    r = F.add_update(dest, delta, alpha=0.9, beta=0.1, bias=0.1)
-    np.testing.assert_allclose(r.numpy(), x * 0.9 + y * 0.1 + 0.1, atol=1e-6)
-
-
-def test_add_update_params():
-    b = np.random.random((2, 3)).astype(np.float32)
-    y = Tensor(b)
-
-    # @jit.trace
-    def f(x):
-        return F.add_update(y, x)
-
-    f(np.zeros((2, 3)).astype(np.float32))
-    z = Tensor(np.zeros((2, 3)).astype(np.float32))
-    F.add_update(y, z, beta=0.1)
-    res = f(np.ones((2, 3)).astype(np.float32))
-    np.testing.assert_allclose(res.numpy(), b + 1)


def test_binary_cross_entropy():
    data1_shape = (2, 2)
    label1_shape = (2, 2)
......
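If equivalent coverage were still wanted after this removal, the first deleted check could be rewritten with plain in-place arithmetic; a hedged sketch (not part of the commit):

    import numpy as np
    from megengine import Tensor

    v = np.random.random((2, 3)).astype(np.float32)
    b = Tensor(v)
    b += 1  # replaces F.add_update(b, 1)
    np.testing.assert_allclose(b.numpy(), v + 1, atol=1e-6)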