提交 dc96f6aa 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix quantized concat forward problem

GitOrigin-RevId: dc21b340d1732fa5c1c186904d2a6c1b13e10121
上级 33da8de1
......@@ -23,7 +23,7 @@ class Concat(QuantizedModule):
self.output_dtype = dtype
def forward(self, inps: Iterable[Tensor], axis: int = 0):
new_inps = (x.astype(self.output_dtype) for x in inps)
new_inps = tuple(x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)
@classmethod
......
......@@ -6,8 +6,16 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .fake_quant import TQT, FakeQuantize
from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
Observer,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
from .qconfig import (
QConfig,
calibration_qconfig,
......
......@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig(
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)
inp_scale = np.float32(np.random.rand() + 1)
def gen_inp_scale():
return np.float32(np.random.rand() + 1)
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32")
max_val = np.random.randint(1, 127, size=(2,)).astype("float32")
......@@ -116,6 +119,7 @@ def test_dequant_stub():
q_net.eval()
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.scale = inp_scale
......@@ -192,6 +196,7 @@ def test_linear():
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
......@@ -235,6 +240,7 @@ def test_conv(module):
init_qat_net(qat_net)
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
inp_scale = gen_inp_scale()
x = fake_quant_act(x, inp_scale)
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
......@@ -269,3 +275,41 @@ def test_conv(module):
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5)
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale)
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale)
def test_concat():
normal_net = Float.Concat()
normal_net.eval()
qat_net = QAT.Concat()
qat_net.eval()
disable_observer(qat_net)
propagate_qconfig(qat_net, min_max_fakequant_qconfig)
init_qat_net(qat_net)
inps = []
inps_int8 = []
for i in range(3):
inp_scale = gen_inp_scale()
inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32")))
inps[i] = fake_quant_act(inps[i], inp_scale)
inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))
inps_int8.append(quant(inps[i], inp_scale))
qat_from_float = QAT.Concat.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)
q_net = Q.Concat.from_qat_module(qat_net)
q_net.eval()
normal = normal_net(inps)
qat_without_fakequant = qat_from_float(inps)
fake_quant_normal = fake_quant_act(normal_net(inps), act_scale)
qat = qat_net(inps)
q = q_net(inps_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal.numpy())
np.testing.assert_allclose(q, fake_quant_normal.numpy())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册