提交 992a90bb 编写于 作者: M Megvii Engine Team

docs(mge/quantization): add docstring for Observer

GitOrigin-RevId: 043be3886dc05205426fd60c9af3bc6172c70ce9
上级 c45f1eb2
......@@ -26,9 +26,10 @@ logger = get_logger(__name__)
class Observer(Module, QParamsModuleMixin):
r"""
A base class for Observer Module.
A base class for Observer Module. Used to record input tensor's statistics for
quantization.
:param dtype: a string indicating to collect scale and zero_point of which dtype.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
......@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin):
class MinMaxObserver(Observer):
r"""
A Observer Module records input tensor's running min and max values to calc scale.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__(
self,
mode: QuantMode = QuantMode.SYMMERTIC,
......@@ -119,6 +128,14 @@ class MinMaxObserver(Observer):
class SyncMinMaxObserver(MinMaxObserver):
r"""
A distributed version of :class:`~.MinMaxObserver`.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def forward(self, x_orig):
if self.enable:
x = x_orig.detach()
......@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver):
class ExponentialMovingAverageObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` with momentum support for min/max updating.
:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__(
self,
momentum: float = 0.9,
......@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
r"""
A distributed version of :class:`~.ExponentialMovingAverageObserver`.
:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def forward(self, x_orig):
if self.enabled:
x = x_orig.detach()
......@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
class HistogramObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` using running histogram of tensor values
for min/max updating. Usually used for calibration quantization.
:param bins: number of bins to use for the histogram.
:param upsample_rate: which ratio to interpolate histograms in.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__(
self,
bins: int = 2048,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册