提交 44bafd3f 编写于 作者: M Megvii Engine Team

fix(imperative/quantization): fix zero scale bug of easy quant

GitOrigin-RevId: f45e19b3e4d1988330b137386863bc9b80ffab48
上级 8aced67b
@@ -13,6 +13,7 @@ import numpy as np
 
 from .. import module as Float
 from ..functional import concat, norm
+from ..logger import get_logger
 from ..module import Module
 from ..module import qat as QAT
 from ..module import quantized as Quantized
@@ -22,6 +23,8 @@ from ..tensor import Tensor
 from ..utils.module_utils import set_expand_structure
 from .qconfig import QConfig, ema_fakequant_qconfig
 
+logger = get_logger(__name__)
+
 
 def _get_quantable_module_names():
     def is_quantable(key: str):
@@ -236,16 +239,18 @@ def apply_easy_quant(
             return
 
         orig_scale = ob.orig_scale
-        distance = 0
-        best_scale = 0
+        cosine = optimal = 0
         for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
             ob.scale = scale
             fakequant_out = mod(*fakequant_in)
             dis = get_cosine(normal_out, fakequant_out)
-            if dis > distance:
-                distance = dis
-                best_scale = scale
-        ob.scale = best_scale
+            if dis > cosine:
+                cosine = dis
+                optimal = scale
+        if optimal == 0:
+            logger.warning("EasyQuant finds no better scale")
+        else:
+            ob.scale = optimal
 
         fakequant_out = outputs[batch_size:]
         return concat([normal_out, fakequant_out])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册