Commit e6820b91 authored by Megvii Engine Team, committed by Xu Xinran

feat(mge/module): add conv and conv_relu quantization module

GitOrigin-RevId: 9cd668d97b4ccae8adfa801fd43856cd3fdca813
Parent a1f8ecc7
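A minimal usage sketch of the new fused modules, assuming the API exactly as it appears in this diff (module and function names are taken from the changed files; the Sequential wrapper mirrors the test at the bottom):

import numpy as np
from megengine import tensor
from megengine.module import ConvRelu2d, DequantStub, QuantStub, Sequential
from megengine.quantization.quantize import quantize_qat

# ConvRelu2d fuses the convolution and the ReLU into one module, so the
# whole block can be swapped for its QAT counterpart in a single step.
net = Sequential(QuantStub(), ConvRelu2d(3, 8, 3), DequantStub())

# Replace every top-level quantable float module with its QATModule version.
qat_net = quantize_qat(net, inplace=False)
out = qat_net(tensor(np.random.randn(1, 3, 32, 32).astype(np.float32)))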
......@@ -9,8 +9,8 @@
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d, LocalConv2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -13,8 +12,8 @@ import numpy as np
import megengine._internal as mgb
from .. import functional as F
from ..core import Parameter
from ..functional import conv2d, conv_transpose2d, local_conv2d
from ..utils.types import _pair, _pair_nonzero
from . import init
from .module import Module
......@@ -183,7 +182,7 @@ class Conv2d(_ConvNd):
return (1, self.out_channels, 1, 1)
def calc_conv(self, inp, weight, bias):
return conv2d(
return F.conv2d(
inp,
weight,
bias,
......@@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd):
return (1, self.out_channels, 1, 1)
def forward(self, inp):
return conv_transpose2d(
return F.conv_transpose2d(
inp,
self.weight,
self.bias,
......@@ -324,7 +323,7 @@ class LocalConv2d(Conv2d):
spatial dimensions. Only zero-padding is supported. Default: 0
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``.
``in_channels`` and ``out_channels`` must be divisible by ``groups``.
The shape of weight is ``(groups, output_height, output_width,
in_channels // groups, *kernel_size, out_channels // groups)``.
"""
......@@ -377,6 +376,17 @@ class LocalConv2d(Conv2d):
)
def forward(self, inp):
return local_conv2d(
return F.local_conv2d(
inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
)
class ConvRelu2d(Conv2d):
r"""
A fused :class:`~.Module` including Conv2d and ReLU. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
:func:`~.quantize.quantize_qat`.
"""
def forward(self, inp):
return F.relu(self.calc_conv(inp, self.weight, self.bias))
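Since the float ``ConvRelu2d.forward`` simply chains ``calc_conv`` and ``F.relu``, its output should match an unfused Conv2d followed by a ReLU. A quick equivalence sketch (the parameter sharing is borrowed from the diff's own ``from_float_module``):

import numpy as np
from megengine import tensor
import megengine.functional as F
from megengine.module import Conv2d, ConvRelu2d

fused = ConvRelu2d(4, 8, 3)
plain = Conv2d(4, 8, 3)
plain.weight = fused.weight  # share parameters, as from_float_module does
plain.bias = fused.bias

x = tensor(np.random.randn(2, 4, 16, 16).astype(np.float32))
assert np.allclose(fused(x).numpy(), F.relu(plain(x)).numpy())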
......@@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module):
class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
:func:`~.quantize.quantize_qat`.
"""
......@@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d):
class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.Module` including Conv2d, BatchNorm2d and ReLU. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
:func:`~.quantize.quantize_qat`.
"""
......
......@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from .. import conv as Float
from .module import QATModule
class Conv2d(Float.Conv2d, QATModule):
r"""
A :class:`~.QATModule` Conv2d with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def calc_conv_qat(self, inp):
w_qat = self.apply_quant_weight(self.weight)
conv = self.calc_conv(inp, w_qat, self.bias)
return conv
@classmethod
def from_float_module(cls, float_module: Float.Conv2d):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qat_module = cls(
float_module.in_channels,
float_module.out_channels,
float_module.kernel_size,
float_module.stride,
float_module.padding,
float_module.dilation,
float_module.groups,
float_module.bias is not None,
float_module.conv_mode.name,
float_module.compute_mode.name,
)
qat_module.weight = float_module.weight
qat_module.bias = float_module.bias
return qat_module
def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_qat(inp))
class ConvRelu2d(Conv2d):
r"""
A :class:`~.QATModule` including Conv2d and ReLU with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""
def forward(self, inp):
return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp)))
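For intuition: ``apply_quant_weight`` and ``apply_quant_activation`` pass the tensor through an observer plus fake quantization, so the float graph trains against quantization error while staying differentiable. A minimal numpy sketch of uniform fake quantization (illustrative only, not MegEngine's exact implementation):

import numpy as np

def fake_quant(x, scale, qmin=-128, qmax=127):
    # Round onto the integer grid, clip to the representable range,
    # then dequantize, so downstream float ops see the rounding error.
    q = np.clip(np.round(x / scale), qmin, qmax)
    return q * scale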
......@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .. import conv_bn as Float
from .module import QATModule
......@@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
bool(float_module.conv.bias),
float_module.conv.bias is not None,
float_module.conv.conv_mode.name,
float_module.conv.compute_mode.name,
)
......
......@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .conv import Conv2d, ConvRelu2d
from .conv_bn import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QuantizedModule
......
......@@ -7,16 +7,19 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union
import numpy as np
import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ..qat import conv_bn_relu as QAT
from ..qat import conv as QAT
from .module import QuantizedModule
class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
class Conv2d(Float.Conv2d, QuantizedModule):
r"""quantized version of :class:`~.qat.conv.Conv2d`."""
r"""Applies a 2D convolution over an quantized input tensor, inference only.
The parameter is same with :class: `~.Conv2d`
......@@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
)
@classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
def from_qat_module(cls, qat_module: QAT.Conv2d):
r"""
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qat_module.conv.in_channels,
qat_module.conv.out_channels,
qat_module.conv.kernel_size,
qat_module.conv.stride,
qat_module.conv.padding,
qat_module.conv.dilation,
qat_module.conv.groups,
qat_module.in_channels,
qat_module.out_channels,
qat_module.kernel_size,
qat_module.stride,
qat_module.padding,
qat_module.dilation,
qat_module.groups,
dtype=output_dtype,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
weight = qat_module.weight.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
if qat_module.bias is not None:
qconv.bias = Parameter(qat_module.bias.numpy())
else:
qconv.bias = Parameter(
np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
)
return qconv
class ConvBn2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
class ConvBnRelu2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
class ConvRelu2d(Conv2d):
r"""quantized version of :class:`~.qat.conv.ConvRelu2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import Parameter
from ..qat import conv_bn as QAT
from .conv import Conv2d
class _ConvBnActivation2d(Conv2d):
r"""Applies a 2D convolution over an quantized input tensor, inference only.
The parameter is same with :class: `~.Conv2d`
"""
@classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
r"""
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qat_module.conv.in_channels,
qat_module.conv.out_channels,
qat_module.conv.kernel_size,
qat_module.conv.stride,
qat_module.conv.padding,
qat_module.conv.dilation,
qat_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv
class ConvBn2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
class ConvBnRelu2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
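For reference, the conv+BN fusion behind ``fold_weight_bias`` is the standard folding of the batch-norm affine transform into the convolution. A numpy sketch of the math (eps value and weight layout assumed):

import numpy as np

def fold_weight_bias(w, b, gamma, beta, mean, var, eps=1e-5):
    # Scale each output channel by gamma / sqrt(var + eps) and shift the
    # bias so that conv(x, w_fold) + b_fold == bn(conv(x, w) + b).
    scale = gamma / np.sqrt(var + eps)       # shape: (out_channels,)
    w_fold = w * scale.reshape(-1, 1, 1, 1)  # weight layout (O, I, KH, KW)
    b_fold = beta + (b - mean) * scale
    return w_fold, b_fold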
......@@ -104,6 +104,10 @@ def quantize_qat(
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_quantable
):
# only convert the top-level quantable module; its quantable children are covered by the parent's conversion.
if is_quantable(parent):
continue
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
......
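The ``is_quantable(parent)`` check exists because fused modules nest other quantable modules: a ConvBn2d holds a Conv2d child, and converting the parent already accounts for the child. A short sketch of the resulting behavior (API as in this diff):

from megengine.module import ConvBn2d, Sequential
from megengine.quantization.quantize import quantize_qat

net = Sequential(ConvBn2d(3, 8, 3))
qat_net = quantize_qat(net, inplace=False)
# The ConvBn2d is converted as a whole to qat.ConvBn2d; its inner
# conv and bn children are skipped rather than converted twice.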
import copy
from itertools import product
import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.module import (
Conv2d,
ConvBn2d,
ConvRelu2d,
DequantStub,
Module,
QuantStub,
)
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose
def test_convbn2d():
def test_qat_convbn2d():
in_channels = 32
out_channels = 64
kernel_size = 3
......@@ -35,3 +41,45 @@ def test_convbn2d():
qat_module.eval()
qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
def test_qat_conv():
in_channels = 32
out_channels = 64
kernel_size = 3
class TestNet(Module):
def __init__(self, groups, bias):
super().__init__()
self.quant = QuantStub()
self.dequant = DequantStub()
self.conv = Conv2d(
in_channels, out_channels, kernel_size, groups=groups, bias=bias
)
self.conv_relu = ConvRelu2d(
out_channels, in_channels, kernel_size, groups=groups, bias=bias
)
def forward(self, inp):
out = self.quant(inp)
out = self.conv(out)
out = self.conv_relu(out)
out = self.dequant(out)
return out
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
for groups, bias in product([1, 4], [True, False]):
net = TestNet(groups, bias)
net.train()
qat_net = quantize_qat(net, inplace=False)
disable_fake_quant(qat_net)
normal_outputs = net(inputs)
qat_outputs = qat_net(inputs)
assertTensorClose(normal_outputs, qat_outputs)
net.eval()
normal_outputs = net(inputs)
qat_net.eval()
qat_outputs = qat_net(inputs)
assertTensorClose(normal_outputs, qat_outputs)