提交 d8ac6c70 编写于 作者: M Megvii Engine Team

fix(quantize): fix quantize calibration dtype issue

GitOrigin-RevId: 667c99f469054134efd006aa8f54fca22c4c85b6
上级 4afa4b72
...@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): ...@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver):
self.bins = bins self.bins = bins
self.upsample_rate = upsample_rate self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = Tensor([-1] + [0.0] * (bins - 1)) self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
def _non_linear_param_search(self): def _non_linear_param_search(self):
r"""Non-linear parameter search. r"""Non-linear parameter search.
...@@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver): ...@@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver):
start_bin = next_start_bin start_bin = next_start_bin
end_bin = next_end_bin end_bin = next_end_bin
new_min = self.min_val + bin_width * start_bin new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
new_max = self.min_val + bin_width * (end_bin + 1) new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
return new_min, new_max return new_min, new_max
def get_qparams(self): def get_qparams(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册