From 03625d2f377e23465949cfc5ba78b0c58722b9dd Mon Sep 17 00:00:00 2001 From: Yizhuang Zhou <62599194+zhouyizhuang-megvii@users.noreply.github.com> Date: Wed, 1 Jul 2020 14:56:20 +0800 Subject: [PATCH] fix(quant): fix code and add quantized weights (#38) --- hubconf.py | 2 + official/quantization/config.py | 2 +- official/quantization/inference.py | 9 ++-- official/quantization/models/resnet.py | 60 ++++++-------------------- 4 files changed, 21 insertions(+), 52 deletions(-) diff --git a/hubconf.py b/hubconf.py index 5742000..ab6e6c7 100644 --- a/hubconf.py +++ b/hubconf.py @@ -47,3 +47,5 @@ from official.vision.keypoints.models import ( ) from official.vision.keypoints.inference import KeypointEvaluator + +from official.quantization.models import quantized_resnet18 diff --git a/official/quantization/config.py b/official/quantization/config.py index ef3ea23..4199e83 100644 --- a/official/quantization/config.py +++ b/official/quantization/config.py @@ -48,7 +48,7 @@ def get_config(arch: str): class ShufflenetFinetuneConfig(ShufflenetConfig): BATCH_SIZE = 128 // 2 - LEARNING_RATE = 0.003125 // 2 + LEARNING_RATE = 0.003125 / 2 EPOCHS = 30 diff --git a/official/quantization/inference.py b/official/quantization/inference.py index ce0703e..8aafa56 100644 --- a/official/quantization/inference.py +++ b/official/quantization/inference.py @@ -17,6 +17,7 @@ import megengine.functional as F import megengine.jit as jit import megengine.quantization as Q import numpy as np +from megengine.quantization.quantize import quantize, quantize_qat import models @@ -45,7 +46,10 @@ def main(): model = models.__dict__[args.arch]() if args.mode != "normal": - Q.quantize_qat(model, Q.ema_fakequant_qconfig) + quantize_qat(model, Q.ema_fakequant_qconfig) + + if args.mode == "quantized": + quantize(model) if args.checkpoint: logger.info("Load pretrained weights from %s", args.checkpoint) @@ -53,9 +57,6 @@ def main(): ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt model.load_state_dict(ckpt, strict=False) - if args.mode == "quantized": - Q.quantize(model) - if args.image is None: path = "../assets/cat.jpg" else: diff --git a/official/quantization/models/resnet.py b/official/quantization/models/resnet.py index c29e0f3..3a482ab 100644 --- a/official/quantization/models/resnet.py +++ b/official/quantization/models/resnet.py @@ -46,6 +46,7 @@ import math import megengine.functional as F import megengine.hub as hub import megengine.module as M +from megengine.quantization.quantize import quantize_qat, quantize class BasicBlock(M.Module): @@ -292,58 +293,23 @@ def resnet18(**kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ """ - return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) - - -def resnet34(**kwargs): - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_ - """ - return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + m = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + m.fc.disable_quantize() + return m def resnet50(**kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ """ - return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - - -def resnet101(**kwargs): - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_ - """ - return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) - - -def resnet152(**kwargs): - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_ - """ - return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + m = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + m.fc.disable_quantize() + return m -def resnext50_32x4d(**kwargs): - r"""ResNeXt-50 32x4d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 4 - return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) - - -def resnext101_32x8d(**kwargs): - r"""ResNeXt-101 32x8d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_ - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) +@hub.pretrained("https://data.megengine.org.cn/models/weights/resnet18.quantized.pkl") +def quantized_resnet18(**kwargs): + model = resnet18(**kwargs) + quantize_qat(model) + quantize(model) + return model -- GitLab