Commit caf1fac2 authored by: Megvii Engine Team

refactor(mge/quantization): split `QATModule` and refactor convert api

GitOrigin-RevId: 80cfb12d10590bbc88fd98370f5e3cf5d196d586
Parent ad3c9315
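
This commit splits quantization support into three parallel module sets — float (megengine.module), QAT (megengine.module.qat) and quantized (megengine.module.quantized) — converted by quantize_qat and quantize. A minimal end-to-end sketch of the new convert API (Net is a hypothetical toy network; the other names appear in this diff):

    import numpy as np
    from megengine import tensor
    import megengine.module as M
    from megengine.quantization.quantize import quantize, quantize_qat

    class Net(M.Module):  # hypothetical toy network, for illustration only
        def __init__(self):
            super().__init__()
            self.quant = M.QuantStub()
            self.conv_bn = M.ConvBnRelu2d(3, 8, 3)
            self.dequant = M.DequantStub()

        def forward(self, x):
            return self.dequant(self.conv_bn(self.quant(x)))

    net = Net()
    quantize_qat(net)  # Float -> QAT submodules, in place, default ema qconfig
    # ... QAT fine-tuning / calibration here so observers collect statistics ...
    net.eval()
    quantize(net)      # QAT -> Quantized submodules, in place
    inp = tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))
    out = net(inp)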
......@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
"""Applies a linear transformation to the input.
Refer to :class:`~.Linear` for more information.
Refer to :class:`~.module.linear.Linear` for more information.
:param inp: the input tensor with shape `(N, in_features)`.
:param weight: the weight with shape `(out_features, in_features)`.
:param bias: the bias with shape `(out_features,)`.
Default: ``None``
"""
......@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
r"""
Performs the elementwise function:
.. math::
\mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta.
For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
......
......@@ -16,7 +16,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module, QATModule
from .module import Module
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
......
......@@ -9,19 +9,14 @@ from typing import Iterable
from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule
from .module import Module
class Concat(QATModule):
class Concat(Module):
r"""
A :class:`~.QATModule` to do functional concat, should replace concat with this module,
supporting ``qat`` mode and ``quantized`` mode.
A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inps: Iterable[Tensor], axis: int = 0):
return F.concat(inps, axis)
def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_fakequant_with_observer(
self.forward(inps, axis), self.act_fake_quant, self.act_observer
)
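
Replacing bare F.concat calls with this module is what makes the op visible to quantize_qat, which can then swap in the QAT version. A minimal usage sketch, assuming Concat is exported from megengine.module:

    import numpy as np
    from megengine import tensor
    from megengine.module import Concat

    cat = Concat()
    x = tensor(np.ones((2, 3), dtype=np.float32))
    y = tensor(np.zeros((2, 3), dtype=np.float32))
    out = cat([x, y], axis=0)  # same result as F.concat([x, y], 0) in float mode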
......@@ -7,14 +7,13 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union
from ..core import ones, zeros
from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule
from .module import Module
class _ConvBn2d(QATModule):
class _ConvBnActivation2d(Module):
def __init__(
self,
in_channels: int,
......@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv
# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv
class ConvBn2d(_ConvBn2d):
class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
and ``normal`` mode.
A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
:func:`~.quantize.quantize_qat`.
"""
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
)
def forward(self, inp):
return self.bn(self.conv(inp))
class ConvBnRelu2d(_ConvBn2d):
class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
mode and ``normal`` mode.
A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
:func:`~.quantize.quantize_qat`.
"""
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
)
def forward(self, inp):
return relu(self.bn(self.conv(inp)))
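
The folding in fold_weight_bias above is pure per-channel algebra: gamma/std is absorbed into the convolution weight and the remainder into the bias. A quick numpy check of the identity, independent of MegEngine:

    import numpy as np

    np.random.seed(0)
    x = np.random.randn(16, 4)                  # stand-in for per-channel conv outputs
    gamma, beta = np.random.randn(4), np.random.randn(4)
    mean, var, eps = x.mean(0), x.var(0), 1e-5

    istd = 1.0 / np.sqrt(var + eps)
    bn_out = gamma * (x - mean) * istd + beta   # BatchNorm applied after the conv
    folded = x * (gamma * istd) + (beta - gamma * mean * istd)  # w_fold / b_fold form
    assert np.allclose(bn_out, folded)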
......@@ -8,7 +8,7 @@
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule
from .module import Module
@wrap_io_tensor
......@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)
class Elemwise(QATModule):
class Elemwise(Module):
r"""
A :class:`~.QATModule` to do elemwise operator; should replace the functional operator with this module,
supporting ``qat`` mode and ``normal`` mode.
A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.
:param method: the elemwise method, support the following string.
It will do the normal elemwise operator for float.
......@@ -88,8 +88,3 @@ class Elemwise(QATModule):
def forward(self, *inps):
return _elemwise_func(self.method, *inps)
def forward_qat(self, *inps):
return self.apply_fakequant_with_observer(
self.forward(*inps), self.act_fake_quant, self.act_observer,
)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -11,10 +10,10 @@ import numpy as np
from .. import functional as F
from ..core import Parameter
from . import init
from .module import QATModule
from .module import Module
class Linear(QATModule):
class Linear(Module):
r"""Applies a linear transformation to the input. For instance, if input
is x, then output y is:
......@@ -60,13 +59,3 @@ class Linear(QATModule):
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
def forward_qat(self, x):
w_qat = self.apply_fakequant_with_observer(
self.weight, self.weight_fake_quant, self.weight_observer
)
return self.apply_fakequant_with_observer(
self._calc_linear(x, w_qat, self.bias),
self.act_fake_quant,
self.act_observer,
)
......@@ -7,7 +7,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
import numpy as np
......@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
loaded.append(k)
return set(loaded), set(skipped)
class QATModule(Module):
r"""
Base class of quantization related Module. Add extra forward methods
:meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
``qat``(quantization aware training) mode and ``quantized`` mode respectively.
Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
which is irreversible.
If you want to recursively switch mode for all QATModule in network, use
functions in :mod:`~.quantization.quantize`.
"""
class QATMode(Enum):
DISABLED = 1
QAT = 2
CALIBRATION = 3
def __init__(self):
from ..quantization import (
QConfig,
FakeQuantize,
Observer,
) # pylint: disable=all
super().__init__()
self.quantizing = self.QATMode.DISABLED
self.scale = None
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: "QConfig"):
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
self.weight_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.weight_observer.dtype)
)
self.act_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.act_observer.dtype)
)
def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)
def apply_fakequant_with_observer(
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
if fq is not None:
q_dict = obs.get_qparams()
oup = fq(oup, q_dict)
return oup
def set_qat_mode(self, mode: QATMode):
r"""
Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``,
``QAT``,``CALIBRATION``.
"""
if not isinstance(mode, self.QATMode):
raise TypeError("mode must be QATMode Enum type")
self.quantizing = mode
def to_quantized(self):
r"""
Return a new :class:`~.Module` with quantized parameters of ``self``
according to scale and zero_point in ``self.xxx_observer``.
"""
raise NotImplementedError(
"Use megengine.quantization.quantize to register the method."
)
@abstractmethod
def forward_qat(self, *args, **kwargs):
r"""
Forward method for ``qat`` mode.
"""
def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.DISABLED:
return self.forward(*args, **kwargs)
else:
return self.forward_qat(*args, **kwargs)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
from ...core.tensor import Tensor
from .. import concat as Float
from .module import QATModule
class Concat(Float.Concat, QATModule):
r"""
A :class:`~.QATModule` to do functional concat with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_quant_activation(super().forward(inps, axis))
@classmethod
def from_float_module(cls, float_module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
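
from_float_module is the per-module hook that quantize_qat dispatches to; the manual equivalent looks like this (qconfig attachment normally also happens inside quantize_qat):

    import megengine.module as Float
    from megengine.module import qat as QAT
    from megengine.quantization import ema_fakequant_qconfig

    float_cat = Float.Concat()
    qat_cat = QAT.Concat.from_float_module(float_cat)
    qat_cat.set_qconfig(ema_fakequant_qconfig)  # attach observers / fake quant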
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .module import QATModule
class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out
sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var
def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold
def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
w_qat = self.apply_quant_weight(w_fold)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv
# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv
@classmethod
def from_float_module(cls, float_module: Float._ConvBnActivation2d):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qat_module = cls(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
bool(float_module.conv.bias),
float_module.conv.conv_mode.name,
float_module.conv.compute_mode.name,
)
qat_module.conv.weight = float_module.conv.weight
qat_module.conv.bias = float_module.conv.bias
qat_module.bn = float_module.bn
return qat_module
class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_bn_qat(inp))
class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
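
Mirroring the updated test_convbn2d at the bottom of this commit: with fake quantization disabled, the QAT module should match the float module up to numeric tolerance. A condensed sketch:

    import numpy as np
    from megengine import tensor
    from megengine.module import ConvBn2d
    from megengine.quantization.quantize import disable_fake_quant, quantize_qat

    m = ConvBn2d(3, 8, 3)
    m.train()
    qat_m = quantize_qat(m, inplace=False)  # leaves m untouched, returns a qat.ConvBn2d
    disable_fake_quant(qat_m)               # observers still run; fake quant bypassed
    x = tensor(np.random.randn(4, 3, 32, 32).astype(np.float32))
    y_float, y_qat = m(x), qat_m(x)         # should agree closely (see test_convbn2d)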
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import elemwise as Float
from .module import QATModule
class Elemwise(Float.Elemwise, QATModule):
r"""
A :class:`~.QATModule` to do elemwise operator with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""
def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps))
@classmethod
def from_float_module(cls, float_module: Float.Elemwise):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls(float_module.method.name)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import linear as Float
from .module import QATModule
class Linear(Float.Linear, QATModule):
r"""
A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
:param in_features: size of each input sample.
:param out_features: size of each output sample.
:param bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
"""
def forward(self, x):
w_qat = self.apply_quant_weight(self.weight)
return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
@classmethod
def from_float_module(cls, float_module: Float.Linear):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qmod = cls(float_module.in_features, float_module.out_features)
qmod.weight = float_module.weight
qmod.bias = float_module.bias
return qmod
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from ...core import Tensor
from ...quantization import FakeQuantize, Observer, QConfig
from ..module import Module
class QATModule(Module):
r"""
Base class of quantized-float related Modules, mainly for QAT and calibration.
Use :meth:`~.QATModule.from_float_module` to generate an instance from a float :class:`~.Module`.
Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.
Can also be converted to :class:`~.QuantizedModule` for deployment using
:func:`~.quantize.quantize` further.
"""
def __init__(self):
super().__init__()
self.scale = None
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize
def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
observer and fake_quant for weight and activation.
"""
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()
if qconfig.fake_quant is None:
self.weight_fake_quant = None
self.act_fake_quant = None
else:
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
oup = observer(target)
if fake_quant is None:
return oup
else:
q_dict = observer.get_qparams()
return fake_quant(oup, q_dict)
def apply_quant_weight(self, target: Tensor):
r"""
Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.weight_fake_quant, self.weight_observer
)
def apply_quant_activation(self, target: Tensor):
r"""
Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.act_fake_quant, self.act_observer
)
def get_weight_dtype(self):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
return self.weight_observer.get_dtype()
def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
return self.act_observer.get_dtype()
@classmethod
@abstractmethod
def from_float_module(cls, float_module: Module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import quant_dequant as Float
from .module import QATModule
class QuantStub(Float.QuantStub, QATModule):
r"""
A helper QATModule that simply returns its input, but will quantize the
input after being converted to :class:`~.QuantizedModule`.
"""
def forward(self, inp):
return self.apply_quant_activation(inp)
@classmethod
def from_float_module(cls, float_module: Float.QuantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
class DequantStub(Float.DequantStub, QATModule):
r"""
A helper QATModule that simply returns its input, but will de-quantize the
input after being converted to :class:`~.QuantizedModule`.
"""
def forward(self, inp):
return inp
@classmethod
def from_float_module(cls, float_module: Float.DequantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()
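
Note the asymmetry: during QAT the whole graph still runs in float, so QuantStub fake-quantizes its input while DequantStub is a pass-through; actual dtype changes only appear after conversion to the quantized versions. Schematically:

    from megengine.module import qat as QAT

    q, dq = QAT.QuantStub(), QAT.DequantStub()
    # q(x)  -> apply_quant_activation(x): observed / fake-quantized, still float32
    # dq(x) -> x: identity during QAT; real dequantize happens after quantize()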
......@@ -5,30 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule
from .module import Module
class QuantStub(QATModule):
class QuantStub(Module):
r"""
A helper QATModule doing quantize operation on input.
A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return inp
def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
inp, self.act_fake_quant, self.act_observer
)
class DequantStub(QATModule):
class DequantStub(Module):
r"""
A helper QATModule doing de-quantize operation on input.
A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return inp
def forward_qat(self, inp):
return inp
......@@ -9,4 +9,5 @@ from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QuantizedModule
from .quant_dequant import DequantStub, QuantStub
......@@ -7,17 +7,15 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import concat as QAT
from .module import QuantizedModule
class Concat(Module):
class Concat(QuantizedModule):
r"""
A :class:`~.Module` to do quantized concat, inference only.
A :class:`~.QuantizedModule` to do quantized concat, inference only.
"""
def __init__(self, dtype=None):
......@@ -25,16 +23,13 @@ class Concat(Module):
self.output_dtype = dtype
def forward(self, inps: Iterable[Tensor], axis: int = 0):
if self.training:
raise ValueError("quantized module only support inference.")
new_inps = (x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)
@register_method_to_class(Float.Concat)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return Concat(float_module.act_observer.get_dtype())
@classmethod
def from_qat_module(cls, qat_module: QAT.Concat):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.get_activation_dtype())
......@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union
import megengine._internal as mgb
......@@ -13,11 +12,11 @@ import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class
from ..qat import conv_bn_relu as QAT
from .module import QuantizedModule
class _ConvBnActivation2d(Conv2d):
class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
r"""Applies a 2D convolution over an quantized input tensor, inference only.
The parameter is same with :class: `~.Conv2d`
......@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
nonlinear_mode=nonlinear_mode,
)
@classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qat_module.conv.in_channels,
qat_module.conv.out_channels,
qat_module.conv.kernel_size,
qat_module.conv.stride,
qat_module.conv.padding,
qat_module.conv.dilation,
qat_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv
class ConvBn2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
class ConvBnRelu2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
def to_quantized(quantized_class, float_module):
output_dtype = float_module.act_observer.get_dtype()
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv
# replace :class:`~.module.QATModule`'s ``to_quantized`` method.
# implemented here to avoid circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
......@@ -6,11 +6,10 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import elemwise as QAT
from .module import QuantizedModule
@wrap_io_tensor
......@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)
class Elemwise(Module):
r"""
quantized module for elemwise operator, inference only.
:param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
it will do quantized operator with specified output quantized dtype.
"""
class Elemwise(QuantizedModule):
r"""quantized version of :class:`~.qat.elemwise.Elemwise`."""
_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode
......@@ -44,11 +38,10 @@ class Elemwise(Module):
raise ValueError("quantized module only support inference.")
return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)
@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
@classmethod
def from_qat_module(cls, qat_module: QAT.Elemwise):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.method.name, qat_module.get_activation_dtype())
......@@ -10,19 +10,13 @@ import numpy as np
import megengine._internal as mgb
from ... import functional as F
from ... import module as Float
from ...core import Parameter
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import linear as QAT
from .module import QuantizedModule
class Linear(Module):
r"""Applies a quantized linear transformation to the input. The module
usually convert from QAT module by to_quantized method.
:param dtype: output data type.
"""
class Linear(QuantizedModule):
r"""quantized version of :class:`~.qat.linear.Linear`."""
def __init__(
self, dtype: np.dtype = None,
......@@ -44,17 +38,16 @@ class Linear(Module):
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)
@register_method_to_class(Float.Linear)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
output_dtype = float_module.act_observer.get_dtype()
qmod = Linear(dtype=output_dtype,)
weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
qmod.weight = Parameter(weight.numpy())
if float_module.bias is not None:
qmod.bias = Parameter(float_module.bias.numpy())
return qmod
@classmethod
def from_qat_module(cls, qat_module: QAT.Linear):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qmod = cls(dtype=output_dtype)
weight = qat_module.weight.astype(qat_module.get_weight_dtype())
qmod.weight = Parameter(weight.numpy())
if qat_module.bias is not None:
qmod.bias = Parameter(qat_module.bias.numpy())
return qmod
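
from_qat_module is the hook that quantize dispatches to per submodule; the manual equivalent, assuming the QAT module's observers have already seen calibration data:

    from megengine.module import qat as QAT
    from megengine.module import quantized as Quantized
    from megengine.quantization import ema_fakequant_qconfig

    qat_fc = QAT.Linear(16, 8)
    qat_fc.set_qconfig(ema_fakequant_qconfig)
    # ... run calibration data through qat_fc so the observers record qparams ...
    q_fc = Quantized.Linear.from_qat_module(qat_fc)  # weight cast to weight dtype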
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from ..module import Module
from ..qat import QATModule
class QuantizedModule(Module):
r"""
Base class of quantized Modules, which should be converted from QATModule
and do not support training.
"""
def __call__(self, *inputs, **kwargs):
if self.training:
raise ValueError("quantized module only support inference.")
return super().__call__(*inputs, **kwargs)
@classmethod
@abstractmethod
def from_qat_module(cls, qat_module: QATModule):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
......@@ -5,15 +5,14 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import quant_dequant as QAT
from .module import QuantizedModule
class QuantStub(Module):
class QuantStub(QuantizedModule):
r"""
A helper quantize operation on input and inference only.
quantized version of :class:`~.qat.quant_dequant.QuantStub`,
will convert input to quantized dtype.
"""
def __init__(self, dtype=None):
......@@ -21,35 +20,30 @@ class QuantStub(Module):
self.output_dtype = dtype
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype(self.output_dtype)
@classmethod
def from_qat_module(cls, qat_module: QAT.QuantStub):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.get_activation_dtype())
class DequantStub(Module):
class DequantStub(QuantizedModule):
r"""
A helper de-quantize operation and inference only.
quantized version of :class:`~.qat.quant_dequant.DequantStub`,
will restore quantized input to float32 dtype.
"""
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype("float32")
@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return QuantStub(float_module.act_observer.get_dtype())
@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return DequantStub()
@classmethod
def from_qat_module(cls, qat_module: QAT.DequantStub):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls()
......@@ -13,12 +13,3 @@ from .qconfig import (
ema_fakequant_qconfig,
min_max_fakequant_qconfig,
)
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_calibration,
quantize_qat,
)
......@@ -15,16 +15,12 @@ from .observer import (
class QConfig:
"""
r"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation`` and ``weight``.
And ``fake_quant`` parameter to indicate
See :meth:`~.QATModule.set_qconfig` for detail usage.
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detailed usage.
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
- how to collect scales and zero_point of weight.
how to collect scales and zero_point of weight.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. can be invoked multi times to get different
......
......@@ -6,68 +6,125 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from ..module import Module, QATModule, Sequential, quantized
from typing import Dict, Tuple
from .. import module as Float
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .qconfig import QConfig, ema_fakequant_qconfig
def _get_quantable_module_names():
def is_quantable(key: str):
value = getattr(Quantized, key)
return (
isinstance(value, type)
and issubclass(value, QuantizedModule)
and value != QuantizedModule
)
# source should have all quantable modules' names
quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
return quantable_module_names
def _get_convert_dict() -> Tuple[
Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
quantable_module_names = _get_quantable_module_names()
quantable_modules = [getattr(Float, key) for key in quantable_module_names]
qat_modules = [getattr(QAT, key) for key in quantable_module_names]
quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]
float2qat_dict = dict(zip(quantable_modules, qat_modules))
qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
return float2qat_dict, qat2quantized_dict
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
def quantize(module: Module, inplace=True):
r"""
Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`.
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.
:param module: root module to do convert recursively.
:param inplace: whether to convert submodules in-place.
"""
if not inplace:
module = deepcopy(module)
def is_qat_module(obj):
return isinstance(obj, QATModule)
qat_modules = tuple(_qat2quantized_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_qat_module
with_key=True, with_parent=True, predicate=is_qat
):
if isinstance(parent, Sequential):
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr, to stay compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = submodule.to_quantized()
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], submodule.to_quantized())
setattr(parent, key.split(".")[-1], new_mod)
return module
def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def quantize_qat(
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
):
r"""
Recursively convert `module` to `qat` mode through :meth:`~.Module.apply`
and set qconfig relatively.
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
through :meth:`~.Module.apply` and set qconfig relatively.
:param module: root module to do convert recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is :any:`~.qconfig.ema_fakequant_qconfig`.
:param inplace: whether to convert submodules in-place.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is ``ema_fakequant_qconfig``.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.QAT)
mod.set_qconfig(qconfig)
if not inplace:
module = deepcopy(module)
module.apply(fn)
quantable_modules = tuple(_float2qat_dict.keys())
def is_quantable(mod: Module):
return isinstance(mod, quantable_modules)
# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_quantable
):
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr, to stay compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
propagate_qconfig(module, qconfig)
return module
def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def propagate_qconfig(module: QATModule, qconfig: QConfig):
r"""
Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
and set qconfig relatively.
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
:param module: root module to do convert recursively.
:param module: root module to traverse recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is :any:`~.qconfig.ema_fakequant_qconfig`.
"""
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
mod.set_qconfig(qconfig)
module.apply(fn)
......
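
propagate_qconfig is also useful on its own, e.g. to swap in a different qconfig after conversion; a sketch, with net standing for any float module (min_max_fakequant_qconfig is exported alongside ema_fakequant_qconfig):

    from megengine.quantization.qconfig import min_max_fakequant_qconfig
    from megengine.quantization.quantize import propagate_qconfig, quantize_qat

    qat_net = quantize_qat(net)  # net: any float Module; default ema qconfig applied
    propagate_qconfig(qat_net, min_max_fakequant_qconfig)  # re-instantiate observers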
......@@ -5,8 +5,7 @@ import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose
......@@ -14,18 +13,17 @@ def test_convbn2d():
in_channels = 32
out_channels = 64
kernel_size = 3
module = ConvBn2d(in_channels, out_channels, kernel_size)
quantize_qat(module)
for groups, bias in product([1, 4], [True, False]):
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
module = ConvBn2d(
in_channels, out_channels, kernel_size, groups=groups, bias=bias
)
module.train()
qat_module = copy.deepcopy(module)
qat_module = quantize_qat(module, inplace=False)
disable_fake_quant(qat_module)
normal_outputs = module.forward(inputs)
qat_outputs = qat_module.forward_qat(inputs)
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
normal_outputs = module(inputs)
qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
a = module.bn.running_mean.numpy()
b = qat_module.bn.running_mean.numpy()
assertTensorClose(
module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
)
......@@ -33,7 +31,7 @@ def test_convbn2d():
module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
)
module.eval()
normal_outputs = module.forward(inputs)
normal_outputs = module(inputs)
qat_module.eval()
qat_outputs = qat_module.forward_qat(inputs)
qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names
def test_get_quantable_module_names():
# need to make sure names from Quantized and QAT are the same
def _get_qat_module_names():
def is_qat(key: str):
value = getattr(QAT, key)
return (
isinstance(value, type)
and issubclass(value, QAT.QATModule)
and value != QAT.QATModule
)
# source should have all quantable modules' names
quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
return quantable_module_names
qat_module_names = _get_qat_module_names()
quantized_module_names = _get_quantable_module_names()
assert set(qat_module_names) == set(quantized_module_names)
for key in qat_module_names:
value = getattr(Float, key)
assert (
isinstance(value, type)
and issubclass(value, Float.Module)
and value != Float.Module
)