未验证 提交 03625d2f 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

fix(quant): fix code and add quantized weights (#38)

上级 f64a4ccb
......@@ -47,3 +47,5 @@ from official.vision.keypoints.models import (
)
from official.vision.keypoints.inference import KeypointEvaluator
from official.quantization.models import quantized_resnet18
......@@ -48,7 +48,7 @@ def get_config(arch: str):
class ShufflenetFinetuneConfig(ShufflenetConfig):
BATCH_SIZE = 128 // 2
LEARNING_RATE = 0.003125 // 2
LEARNING_RATE = 0.003125 / 2
EPOCHS = 30
......
......@@ -17,6 +17,7 @@ import megengine.functional as F
import megengine.jit as jit
import megengine.quantization as Q
import numpy as np
from megengine.quantization.quantize import quantize, quantize_qat
import models
......@@ -45,7 +46,10 @@ def main():
model = models.__dict__[args.arch]()
if args.mode != "normal":
Q.quantize_qat(model, Q.ema_fakequant_qconfig)
quantize_qat(model, Q.ema_fakequant_qconfig)
if args.mode == "quantized":
quantize(model)
if args.checkpoint:
logger.info("Load pretrained weights from %s", args.checkpoint)
......@@ -53,9 +57,6 @@ def main():
ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
model.load_state_dict(ckpt, strict=False)
if args.mode == "quantized":
Q.quantize(model)
if args.image is None:
path = "../assets/cat.jpg"
else:
......
......@@ -46,6 +46,7 @@ import math
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.quantization.quantize import quantize_qat, quantize
class BasicBlock(M.Module):
......@@ -292,58 +293,23 @@ def resnet18(**kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
def resnet34(**kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
m = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
m.fc.disable_quantize()
return m
def resnet50(**kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnet101(**kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def resnet152(**kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
m = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
m.fc.disable_quantize()
return m
def resnext50_32x4d(**kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 4
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnext101_32x8d(**kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
@hub.pretrained("https://data.megengine.org.cn/models/weights/resnet18.quantized.pkl")
def quantized_resnet18(**kwargs):
model = resnet18(**kwargs)
quantize_qat(model)
quantize(model)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册