Commit ab9f44f1 authored by Megvii Engine Team

feat(mge/quantization): add support for easyquant

GitOrigin-RevId: 060d908349ca6bdcee293be5a2e47a5bee98af5e
Parent fc0fcd2f
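For context, EasyQuant post-training quantization searches each layer's weight and activation scales to maximize the cosine similarity between the float output and the fake-quantized output. Below is a minimal, hedged usage sketch of the API added by this commit; `FloatNet` and the calibration batch are hypothetical placeholders, not part of this diff.

import numpy as np
import megengine as mge
from megengine.quantization import min_max_fakequant_qconfig, passive_qconfig
from megengine.quantization.quantize import (
    apply_easy_quant,
    disable_fake_quant,
    enable_fake_quant,
    quantize_qat,
    reset_qconfig,
)

# `FloatNet` is a hypothetical float model; quantize_qat converts it to QAT modules.
qat_net = quantize_qat(FloatNet(), inplace=False, qconfig=min_max_fakequant_qconfig)

# Calibrate the observers first so every q_dict has an initialized scale.
calib = mge.tensor(np.random.rand(8, 3, 224, 224), dtype=np.float32)
disable_fake_quant(qat_net)
qat_net(calib)
enable_fake_quant(qat_net)

# Swap in PassiveObserver / FakeQuantize, then grid-search each scale over
# 40 candidates in [0.8, 1.2] * original_scale, layer by layer.
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, calib, start=0.8, stop=1.2, num=40)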
......@@ -17,9 +17,7 @@ from .module import QuantizedModule
class Linear(QuantizedModule):
r"""Quantized version of :class:`~.qat.linear.Linear`."""
def __init__(
self, dtype: np.dtype = None,
):
def __init__(self, dtype: np.dtype = None):
super().__init__()
self.weight = None
self.bias = None
......
......@@ -15,7 +15,8 @@ from .qconfig import (
ema_fakequant_qconfig,
ema_lowbit_fakequant_qconfig,
min_max_fakequant_qconfig,
passive_qconfig,
sync_ema_fakequant_qconfig,
tqt_quant_qconfig,
tqt_qconfig,
)
from .utils import QuantMode
......@@ -28,7 +28,9 @@ class _FakeQuantize(Module):
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
):
super().__init__()
if not dtype in _metadata_dict.keys():
raise ValueError(
......@@ -114,24 +116,28 @@ class TQT(_FakeQuantize):
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
"""
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
super().__init__(dtype, narrow_range, enable)
self.scale = Parameter([0.0], dtype=np.float32)
def __init__(
self,
q_dict,
dtype: str,
narrow_range: bool = False,
enable: bool = True,
**kwargs
):
super().__init__(dtype, narrow_range, enable, **kwargs)
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.scale = F.log(q_dict["scale"]) / math.log(2)
def fake_quant_forward(self, inp, q_dict=None):
# when enabled, TQT does the fake-quant forward and finetunes the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict=None):
if q_dict["enable_observer"]:
# when disabled, TQT does the normal forward and initializes the scale parameter
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / math.log(2)
self.scale[...] = tmp_scale
return inp
def get_qparams(self):
q_dict = get_qparam_dict(QuantMode.TQT)
q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
q_dict["scale"] = 2 ** self.scale
return q_dict
......
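For reference, the trained parameter above is the log2 of the scale, and the fake-quant forward amounts to quantize, clip, and dequantize with a power-of-two-parameterized step. A minimal numpy sketch of that forward follows (the real TQT_Function additionally defines the gradients w.r.t. the input and the scale from the TQT paper; the min/max values are illustrative):

import math
import numpy as np

def tqt_fake_quant(x, log2_scale, qmin=-127, qmax=127):
    # quantize with step 2 ** log2_scale, clip to [qmin, qmax], then dequantize
    scale = 2.0 ** log2_scale
    return np.clip(np.round(x / scale), qmin, qmax) * scale

# normal_foward above initializes the parameter as
#   log2(max(|min_val|, |max_val|) / 127)
# which this sketch mirrors:
init_log2_scale = math.log(max(abs(-0.8), abs(1.5)) / 127, 2)  # illustrative min/max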
......@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
from copy import deepcopy
import numpy as np
......@@ -28,7 +29,7 @@ class Observer(Module):
instead of 1 greater. Usually True for weight and False for activation.
"""
def __init__(self, dtype: str, narrow_range: bool = False):
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
......@@ -81,8 +82,9 @@ class MinMaxObserver(Observer):
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
**kwargs
):
super().__init__(dtype, narrow_range)
super().__init__(dtype, narrow_range, **kwargs)
self.mode = mode
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
......@@ -105,7 +107,7 @@ class MinMaxObserver(Observer):
else:
# use maximum to avoid the scale being too small at the beginning
q_dict["scale"] = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
......@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range)
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.momentum = Tensor(momentum)
self.runtime_momentum = Tensor(0.0)
......@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver):
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range)
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
......@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver):
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
(new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
new_min, new_max, self.upsample_rate
)
......@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver):
def forward(self, x_orig):
self.sideeffect_forward(x_orig)
return x_orig
class PassiveObserver(Observer):
r"""
An observer whose :attr:`scale` can be set directly.
"""
def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__(dtype, narrow_range, **kwargs)
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
@property
def scale(self):
return self.q_dict["scale"]
@scale.setter
def scale(self, value):
assert value > 0
self.q_dict["scale"].set_value(value)
def get_qparams(self):
return self.q_dict
def forward(self, x):
r"""
Just return the input, because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
"""
return x
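A short usage sketch of the observer above (the values are illustrative): the setter writes through to the tensor stored in ``q_dict``, which is what ``apply_easy_quant`` relies on during its scale search.

import megengine as mge
from megengine.quantization.observer import PassiveObserver

q_dict = {"scale": mge.tensor(0.05)}   # illustrative pre-computed scale
ob = PassiveObserver(q_dict, dtype="qint8")
ob.scale = 0.04                        # setter asserts value > 0 and updates q_dict["scale"] in place
print(ob.get_qparams()["scale"])       # tensor holding 0.04; forward(x) just returns x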
......@@ -13,6 +13,7 @@ from .observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
......@@ -66,17 +67,22 @@ class QConfig:
self.weight_fake_quant = weight_fake_quant
self.act_fake_quant = act_fake_quant
def __eq__(self, other):
def eq(a, b):
if isinstance(a, partial) and isinstance(b, partial):
return all(
[a.func == b.func, a.args == b.args, a.keywords == b.keywords]
)
else:
return a == b
return (
eq(self.weight_observer, other.weight_observer)
and eq(self.act_observer, other.act_observer)
and eq(self.weight_fake_quant, other.weight_fake_quant)
and eq(self.act_fake_quant, other.act_fake_quant)
)
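The custom ``eq`` helper above is needed because two separately constructed ``functools.partial`` objects never compare equal by default; a standalone illustration:

from functools import partial

def fq(dtype, narrow_range=False):
    pass

a = partial(fq, dtype="qint8", narrow_range=True)
b = partial(fq, dtype="qint8", narrow_range=True)

assert a != b  # default partial comparison falls back to identity
assert (a.func, a.args, a.keywords) == (b.func, b.args, b.keywords)  # what QConfig.__eq__ compares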
tqt_quant_qconfig = QConfig(
weight_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True
),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
......@@ -118,3 +124,17 @@ calibration_qconfig = QConfig(
weight_fake_quant=None,
act_fake_quant=None,
)
tqt_qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
passive_qconfig = QConfig(
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
......@@ -6,15 +6,18 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import copy, deepcopy
from functools import partial
from typing import Callable, Dict, Tuple
import numpy as np
from .. import module as Float
from ..functional import concat, norm
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig
......@@ -32,9 +35,7 @@ def _get_quantable_module_names():
return quantable_module_names
def _get_convert_dict() -> Tuple[
Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
def _get_convert_dict():
quantable_module_names = _get_quantable_module_names()
quantable_modules = [getattr(Float, key) for key in quantable_module_names]
......@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
......@@ -133,6 +139,34 @@ def quantize_qat(
return module
def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
r"""
Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``.
:param module: root module to reset recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
:param inplace: whether to reset submodules in-place.
"""
if not inplace:
module = deepcopy(module)
def safe_call(func, q_dict):
return func(q_dict=q_dict) if func is not None else None
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
if m.with_act:
act_q_dict = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
return module
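A hedged usage sketch of ``reset_qconfig`` (``qat_net`` is assumed to be a QAT network whose observers have already seen calibration data): because each new observer/fake-quant is constructed from the submodule's current qparams, the calibrated scales carry over, and constructing ``TQT`` or ``PassiveObserver`` without an initialized scale raises the assertion error mentioned above.

from megengine.quantization import passive_qconfig, tqt_qconfig
from megengine.quantization.quantize import reset_qconfig

tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)     # TQT starts from the calibrated scales
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)  # PassiveObserver reuses the same q_dict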
def _propagate(module: Module, func_str: str, *args, **kargs):
def fn(mod: Module):
if isinstance(mod, QATModule):
......@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
_propagate(module, "set_qconfig", qconfig)
def hook_qat_module(module: Module, func: Callable):
r"""
Add a forward hook to every :class:`~.QATModule` submodule.
"""
hooks = []
for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func))
return hooks
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
r"""
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
Search for optimal scales.
:param module: root module.
:param data: input tensor used to search for the optimal scales.
:param start: lower bound of the search interval.
:param stop: upper bound of the search interval.
:param num: number of scale candidates to search over.
"""
batch_size = data.shape[0]
def get_cosine(x, y):
ndim = len(x.shape)
axis = tuple(range(1, ndim))
up = (x * y).sum(axis=axis)
down = norm(x, axis=axis) * norm(y, axis=axis)
sim = up / down
return sim.mean(axis=0)
def search(mod, inputs, outputs, where):
mod._forward_hooks.clear()
fp32_in = [_[:batch_size] for _ in inputs]
int8_in = [_[batch_size:] for _ in inputs]
disable_fake_quant(mod)
fp32_out = mod(*fp32_in)
enable_fake_quant(mod)
ob = getattr(mod, where)
if ob is None:
return
orig_scale = ob.orig_scale
distance = 0
best_scale = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale
int8_out = mod(*int8_in)
dis = get_cosine(fp32_out, int8_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if where == "act_observer":
int8_out = mod(*int8_in)
return concat([fp32_out, int8_out])
else:
int8_out = outputs[batch_size:]
return concat([fp32_out, int8_out])
data = concat([data, data])
hook_qat_module(module, partial(search, where="weight_observer"))
module(data)
hook_qat_module(module, partial(search, where="act_observer"))
module(data)
return module
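For clarity, the search objective above is the batch-averaged cosine similarity between a layer's float output and its fake-quantized output; a numpy sketch equivalent to ``get_cosine``:

import numpy as np

def cosine_score(fp32_out, int8_out):
    # flatten everything except the batch dimension, then average per-sample cosine similarity
    x = fp32_out.reshape(fp32_out.shape[0], -1)
    y = int8_out.reshape(int8_out.shape[0], -1)
    sim = (x * y).sum(axis=1) / (np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1))
    return sim.mean()

# apply_easy_quant evaluates this score for `num` scales in [start, stop] * original_scale,
# first for every weight_observer and then for every act_observer, keeping the best scale.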
def disable_fake_quant(module: Module):
r"""
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
......
......@@ -54,17 +54,15 @@ class QuantMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2
TQT = 3
qparam_dict = {
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,},
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
QuantMode.ASYMMERTIC: {
"mode": QuantMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
},
QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
}
......
......@@ -6,17 +6,53 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.quantization.observer as ob
from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization.observer import (
ExponentialMovingAverageObserver,
MinMaxObserver,
Observer,
PassiveObserver,
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
def test_observer():
with pytest.raises(TypeError):
Observer("qint8")
def test_min_max_observer():
x = np.random.rand(3, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max()
x = mge.tensor(x)
m = ob.MinMaxObserver()
m = MinMaxObserver()
m(x)
assert m.min_val == np_min and m.max_val == np_max
np.testing.assert_allclose(m.min_val.numpy(), np_min)
np.testing.assert_allclose(m.max_val.numpy(), np_max)
def test_exponential_moving_average_observer():
t = np.random.rand()
x1 = np.random.rand(3, 3, 3, 3).astype("float32")
x2 = np.random.rand(3, 3, 3, 3).astype("float32")
expected_min = x1.min() * t + x2.min() * (1 - t)
expected_max = x1.max() * t + x2.max() * (1 - t)
m = ExponentialMovingAverageObserver(momentum=t)
m(mge.tensor(x1, dtype=np.float32))
m(mge.tensor(x2, dtype=np.float32))
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8")
assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0
assert m.scale == 2.0
assert m.get_qparams() == {"scale": mge.tensor(2.0)}
@pytest.mark.skipif(
......@@ -35,9 +71,39 @@ def test_sync_min_max_observer():
@dist.launcher
def worker():
rank = dist.get_rank()
m = ob.SyncMinMaxObserver()
m = SyncMinMaxObserver()
y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
m(y)
assert m.min_val == np_min and m.max_val == np_max
worker()
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
word_size = get_device_count_by_fork("gpu")
t = np.random.rand()
x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
expected_min = x1.min() * t + x2.min() * (1 - t)
expected_max = x1.max() * t + x2.max() * (1 - t)
@dist.launcher
def worker():
rank = dist.get_rank()
m = SyncExponentialMovingAverageObserver(momentum=t)
y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3])
y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
m(y1)
m(y2)
np.testing.assert_allclose(m.min_val.numpy(), expected_min)
np.testing.assert_allclose(m.max_val.numpy(), expected_max)
worker()
from functools import partial
from megengine.quantization import QConfig, tqt_qconfig
from megengine.quantization.fake_quant import TQT
def test_equal():
qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
assert qconfig == tqt_qconfig
......@@ -8,17 +8,194 @@
import numpy as np
import pytest
from megengine import functional
from megengine import module as Float
from megengine import tensor
from megengine.module import qat as QAT
from megengine.quantization import min_max_fakequant_qconfig
from megengine.module import quantized as Q
from megengine.quantization import (
min_max_fakequant_qconfig,
passive_qconfig,
tqt_qconfig,
)
from megengine.quantization.fake_quant import TQT, FakeQuantize
from megengine.quantization.observer import MinMaxObserver, PassiveObserver
from megengine.quantization.quantize import (
_get_quantable_module_names,
apply_easy_quant,
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
propagate_qconfig,
quantize,
quantize_qat,
reset_qconfig,
)
class Net(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
class QATNet(Float.Module):
def __init__(self):
super().__init__()
self.quant = QAT.QuantStub()
self.linear = QAT.Linear(3, 3)
self.dequant = QAT.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
def test_propagate_qconfig():
net = QATNet()
propagate_qconfig(net, min_max_fakequant_qconfig)
assert all(
[
net.quant.weight_observer is None,
net.quant.weight_fake_quant is None,
isinstance(net.quant.act_observer, MinMaxObserver),
isinstance(net.quant.act_fake_quant, FakeQuantize),
isinstance(net.linear.weight_observer, MinMaxObserver),
isinstance(net.linear.weight_fake_quant, FakeQuantize),
isinstance(net.linear.act_observer, MinMaxObserver),
isinstance(net.linear.act_fake_quant, FakeQuantize),
net.dequant.weight_observer is None,
net.dequant.weight_fake_quant is None,
net.dequant.act_observer is None,
net.dequant.act_observer is None,
]
)
def init_qat_net():
net = QATNet()
propagate_qconfig(net, min_max_fakequant_qconfig)
min_val = np.random.randint(-127, 0, size=(2,))
max_val = np.random.randint(1, 127, size=(2,))
net.linear.weight_observer.min_val.set_value(min_val[0])
net.linear.weight_observer.max_val.set_value(max_val[0])
net.linear.act_observer.min_val.set_value(min_val[1])
net.linear.act_observer.max_val.set_value(max_val[1])
return net
def test_reset_qconfig():
qat_net = init_qat_net()
new_qat_net = reset_qconfig(qat_net, passive_qconfig)
assert (
new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams()
)
assert (
new_qat_net.linear.get_activation_qparams()
== qat_net.linear.get_activation_qparams()
)
def test_enable_and_disable_observer():
net = init_qat_net()
enable_observer(net)
assert net.quant.act_observer.enabled == True
assert net.linear.weight_observer.enabled == True
assert net.linear.act_observer.enabled == True
disable_observer(net)
assert net.quant.act_observer.enabled == False
assert net.linear.weight_observer.enabled == False
assert net.linear.act_observer.enabled == False
def test_enable_and_disable_fake_quant():
net = init_qat_net()
disable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == False
assert net.linear.weight_fake_quant.enabled == False
assert net.linear.act_fake_quant.enabled == False
enable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == True
assert net.linear.weight_fake_quant.enabled == True
assert net.linear.act_fake_quant.enabled == True
def init_observer(module, data):
enable_observer(module)
disable_fake_quant(module)
module(data)
disable_observer(module)
enable_fake_quant(module)
def test_enable_and_disable_all():
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
init_observer(net, x)
y2 = net(x).numpy()
disable_fake_quant(net)
y3 = net(x).numpy()
enable_fake_quant(net)
y4 = net(x).numpy()
np.testing.assert_allclose(y1, y3)
np.testing.assert_allclose(y2, y4)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y2, y3)
def test_quantize_qat():
net = Net()
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
assert isinstance(qat_net.quant, QAT.QuantStub)
assert isinstance(qat_net.linear, QAT.Linear)
assert isinstance(qat_net.dequant, QAT.DequantStub)
def test_quantize():
qat_net = init_qat_net()
q_net = quantize(qat_net, inplace=False)
assert isinstance(q_net.quant, Q.QuantStub)
assert isinstance(q_net.linear, Q.Linear)
assert isinstance(q_net.dequant, Q.DequantStub)
def test_apply_easy_quant():
qat_net = init_qat_net()
data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
assert isinstance(eq_net.quant.act_observer, PassiveObserver)
assert isinstance(eq_net.linear.weight_observer, PassiveObserver)
assert isinstance(eq_net.linear.act_observer, PassiveObserver)
assert eq_net.dequant.act_observer is None
def test_apply_tqt():
qat_net = init_qat_net()
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
assert isinstance(tqt_net.quant.act_fake_quant, TQT)
assert isinstance(tqt_net.linear.weight_fake_quant, TQT)
assert isinstance(tqt_net.linear.act_fake_quant, TQT)
assert tqt_net.dequant.act_fake_quant is None
def test_get_quantable_module_names():
# need to make sure names from Quantized and QAT are the same
def _get_qat_module_names():
......@@ -87,30 +264,3 @@ def test_convert_with_custom_mapping():
net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample)
def test_disable_fake_quant():
class Net(Float.Module):
def __init__(self):
super().__init__()
self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3)
self.dequant = Float.DequantStub()
self.linear.bias.set_value(np.random.rand(3))
def forward(self, x):
x = self.quant(x)
x = self.linear(x)
x = self.dequant(x)
return x
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
y1 = net(x).numpy()
net = quantize_qat(net, min_max_fakequant_qconfig)
y2 = net(x).numpy()
disable_fake_quant(net)
y3 = net(x).numpy()
np.testing.assert_allclose(y1, y3)
with pytest.raises(AssertionError):
np.testing.assert_allclose(y2, y3)