提交 44bafd3f 编写于 作者: M Megvii Engine Team

fix(imperative/quantization): fix zero scale bug of easy quant

GitOrigin-RevId: f45e19b3e4d1988330b137386863bc9b80ffab48
上级 8aced67b
......@@ -13,6 +13,7 @@ import numpy as np
from .. import module as Float
from ..functional import concat, norm
from ..logger import get_logger
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
......@@ -22,6 +23,8 @@ from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig
logger = get_logger(__name__)
def _get_quantable_module_names():
def is_quantable(key: str):
......@@ -236,16 +239,18 @@ def apply_easy_quant(
return
orig_scale = ob.orig_scale
distance = 0
best_scale = 0
cosine = optimal = 0
for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
ob.scale = scale
fakequant_out = mod(*fakequant_in)
dis = get_cosine(normal_out, fakequant_out)
if dis > distance:
distance = dis
best_scale = scale
ob.scale = best_scale
if dis > cosine:
cosine = dis
optimal = scale
if optimal == 0:
logger.warning("EasyQuant finds no better scale")
else:
ob.scale = optimal
fakequant_out = outputs[batch_size:]
return concat([normal_out, fakequant_out])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册