提交 992a90bb 编写于 作者: M Megvii Engine Team

docs(mge/quantization): add docstring for Observer

GitOrigin-RevId: 043be3886dc05205426fd60c9af3bc6172c70ce9
上级 c45f1eb2
...@@ -26,9 +26,10 @@ logger = get_logger(__name__) ...@@ -26,9 +26,10 @@ logger = get_logger(__name__)
class Observer(Module, QParamsModuleMixin): class Observer(Module, QParamsModuleMixin):
r""" r"""
A base class for Observer Module. A base class for Observer Module. Used to record input tensor's statistics for
quantization.
:param dtype: a string indicating to collect scale and zero_point of which dtype. :param dtype: a string indicating which dtype to collect scale and zero_point of.
""" """
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
...@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin): ...@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin):
class MinMaxObserver(Observer): class MinMaxObserver(Observer):
r"""
A Observer Module records input tensor's running min and max values to calc scale.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__( def __init__(
self, self,
mode: QuantMode = QuantMode.SYMMERTIC, mode: QuantMode = QuantMode.SYMMERTIC,
...@@ -119,6 +128,14 @@ class MinMaxObserver(Observer): ...@@ -119,6 +128,14 @@ class MinMaxObserver(Observer):
class SyncMinMaxObserver(MinMaxObserver): class SyncMinMaxObserver(MinMaxObserver):
r"""
A distributed version of :class:`~.MinMaxObserver`.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def forward(self, x_orig): def forward(self, x_orig):
if self.enable: if self.enable:
x = x_orig.detach() x = x_orig.detach()
...@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver): ...@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver):
class ExponentialMovingAverageObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` with momentum support for min/max updating.
:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__( def __init__(
self, self,
momentum: float = 0.9, momentum: float = 0.9,
...@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): ...@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
r"""
A distributed version of :class:`~.ExponentialMovingAverageObserver`.
:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:
x = x_orig.detach() x = x_orig.detach()
...@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): ...@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
class HistogramObserver(MinMaxObserver): class HistogramObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` using running histogram of tensor values
for min/max updating. Usually used for calibration quantization.
:param bins: number of bins to use for the histogram.
:param upsample_rate: which ratio to interpolate histograms in.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""
def __init__( def __init__(
self, self,
bins: int = 2048, bins: int = 2048,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册