未验证 提交 b1e8fd8a 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

feat(quantization): Model Quantization (#17)

* add quantization codebase for resnet, shufflenetv1, and mobilenetv2
Co-authored-by: Kevin.W <xsw200831720@gmail.com>
Co-authored-by: wangshupeng <wangshupeng@megvii.com>
Co-authored-by: LiZhiyuan <848796515@qq.com>
Co-authored-by: wangjianfeng <wangjianfeng@megvii.com>
上级 f238d145
模型量化 Model Quantization
---
本目录包含了采用MegEngine实现的量化训练和部署的代码,包括常用的ResNet、ShuffleNet和MobileNet,其量化模型的ImageNet Top 1 准确率如下:
| Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) |
| --- | --- | --- | --- | --- |
| ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 |
| ShufflenetV1 (1.5x) | 71.954 | 17.3 | | 25.3 |
| MobilenetV2 | 72.820 | 13.1 | | 17.4 |
**: FPS is measured on Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz, single 224x224 image*
量化模型使用时,统一读取0-255的uint8图片,减去128的均值,转化为int8,输入网络。
## Quantization Aware Training (QAT)
```python
import megengine.quantization as Q
model = ...
# Quantization Aware Training
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
for _ in range(...):
train(model)
```
## Deploying Quantized Model
```python
import megengine.quantization as Q
import megengine.jit as jit
model = ...
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
# real quant
Q.quantize(model)
@jit.trace(symbolic=True)
def inference_func(x):
return model(x)
inference_func.dump(...)
```
# HOWTO use this codebase
## Step 1. Train a fp32 model
```
python3 train.py -a resnet18 -d /path/to/imagenet --mode normal
```
## Step 2. Finetune fp32 model with quantization aware training(QAT)
```
python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl --mode qat
```
## Step 3. Test QAT model on ImageNet Testset
```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode qat
```
Or test in quantized mode, which uses only the CPU for inference and takes longer:
```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized -n 1
```
## Step 4. Inference and dump
```
python3 inference.py -a resnet18 --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized --dump
```
will feed a cat image to the network and output the classification probabilities with quantized network.
Also, set `--dump` will dump the quantized network to `resnet18.quantized.megengine` binary file.
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
Configurations to train/finetune quantized classification models
"""
import megengine.data.transform as T
class ShufflenetConfig:
    """From-scratch ImageNet training hyper-parameters for ShuffleNet-style
    networks (also selected for MobileNet architectures, see get_config)."""

    BATCH_SIZE = 128
    LEARNING_RATE = 0.0625
    MOMENTUM = 0.9
    # Callable weight decay: accessed through a config *instance* it binds as a
    # method, so callers invoke it as cfg.WEIGHT_DECAY(name, param). Decay 4e-5
    # applies only to multi-dimensional "weight" tensors; other parameters
    # (biases, BN scales, etc.) get no decay.
    WEIGHT_DECAY = lambda self, n, p: \
        4e-5 if n.find("weight") >= 0 and len(p.shape) > 1 else 0
    EPOCHS = 240
    SCHEDULER = "Linear"  # per-step linear learning-rate decay
    # NOTE(review): "JITTOR" looks like a typo for "JITTER"; the name is kept
    # because other code reads cfg.COLOR_JITTOR.
    COLOR_JITTOR = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
class ResnetConfig:
    """From-scratch ImageNet training hyper-parameters for ResNet/ResNeXt."""

    BATCH_SIZE = 32
    LEARNING_RATE = 0.0125
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4  # plain number: uniform decay for every parameter
    EPOCHS = 90
    SCHEDULER = "Multistep"
    SCHEDULER_STEPS = [30, 60, 80]  # epochs at which LR is multiplied by gamma
    SCHEDULER_GAMMA = 0.1
    COLOR_JITTOR = T.PseudoTransform() # disable colorjittor
def get_config(arch: str):
    """Return the from-scratch training config matching architecture *arch*.

    Raises ValueError when no config is registered for the architecture.
    """
    # "resne" deliberately matches both "resnet" and "resnext".
    if "resne" in arch:
        return ResnetConfig()
    if "shufflenet" in arch or "mobilenet" in arch:
        return ShufflenetConfig()
    raise ValueError("config for {} not exists".format(arch))
class ShufflenetFinetuneConfig(ShufflenetConfig):
    """QAT finetuning overrides: smaller batch, lower LR, half the epochs."""

    BATCH_SIZE = 128 // 2
    LEARNING_RATE = 0.03125
    EPOCHS = 120
class ResnetFinetuneConfig(ResnetConfig):
    """QAT finetuning overrides: much lower LR, short schedule with one drop."""

    BATCH_SIZE = 32
    LEARNING_RATE = 0.000125
    EPOCHS = 12
    SCHEDULER = "Multistep"
    SCHEDULER_STEPS = [6,]  # single LR drop at epoch 6
    SCHEDULER_GAMMA = 0.1
def get_finetune_config(arch: str):
    """Return the QAT finetuning config matching architecture *arch*.

    Raises ValueError when no config is registered for the architecture.
    """
    # "resne" deliberately matches both "resnet" and "resnext".
    if "resne" in arch:
        return ResnetFinetuneConfig()
    if "shufflenet" in arch or "mobilenet" in arch:
        return ShufflenetFinetuneConfig()
    raise ValueError("config for {} not exists".format(arch))
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Finetune a pretrained fp32 model with int8 quantization aware training (QAT)"""
import argparse
import collections
import multiprocessing as mp
import numbers
import os
import bisect
import time
import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim
import megengine.quantization as Q
import config
import models
logger = mge.get_logger(__name__)
def main():
    """Parse command-line arguments and launch training on one or many GPUs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-d", "--data", default=None, type=str)
    parser.add_argument("-s", "--save", default="/data/models", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str,
                        help="pretrained model to finetune")
    parser.add_argument("-m", "--mode", default="qat", type=str,
                        choices=["normal", "qat", "quantized"],
                        help="Quantization Mode\n"
                             "normal: no quantization, using float32\n"
                             "qat: quantization aware training, simulate int8\n"
                             "quantized: convert mode to int8 quantized, inference only")
    parser.add_argument("-n", "--ngpus", default=None, type=int)
    parser.add_argument("-w", "--workers", default=4, type=int)
    parser.add_argument("--report-freq", default=50, type=int)
    args = parser.parse_args()

    if args.ngpus is None:
        world_size = mge.get_device_count("gpu")
    else:
        world_size = args.ngpus

    if world_size <= 1:
        # Single device: run the worker directly in this process.
        worker(0, 1, args)
        return

    # Distributed training: dispatch one sub-process per rank.
    mp.set_start_method("spawn")
    children = []
    for rank in range(world_size):
        proc = mp.Process(target=worker, args=(rank, world_size, args))
        proc.start()
        children.append(proc)
    for proc in children:
        proc.join()
def get_parameters(model, cfg):
    """Build SGD parameter group(s) honoring the config's weight-decay policy.

    cfg.WEIGHT_DECAY is either a plain number — one group with uniform decay —
    or a callable ``(name, param) -> decay`` used to bucket parameters by the
    decay value it assigns to each of them.
    """
    decay_policy = cfg.WEIGHT_DECAY
    if isinstance(decay_policy, numbers.Number):
        # Uniform decay: a single group covering all trainable parameters.
        return {"params": model.parameters(requires_grad=True),
                "weight_decay": decay_policy}

    buckets = collections.defaultdict(list)  # weight_decay -> List[param]
    for name, param in model.named_parameters(requires_grad=True):
        buckets[decay_policy(name, param)].append(param)
    return [
        {"params": params, "weight_decay": decay}
        for decay, params in buckets.items()
    ]  # List[{param, weight_decay}]
def worker(rank, world_size, args):
    """Run QAT finetuning on one process; rank 0 also logs and checkpoints."""
    # pylint: disable=too-many-statements
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # NOTE(review): 1280000 presumably approximates the ImageNet train-set
    # size (1,281,167) — confirm against the dataset actually used.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        # Insert fake-quant observers so training simulates int8 inference.
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        # strict=False: the fp32 checkpoint lacks the quant buffers added above.
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        raise ValueError("mode = quantized only used during inference")
        Q.quantize(model)  # NOTE(review): unreachable after the raise above

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    # Infinite sampler: the loop below draws a fixed number of steps, not epochs.
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True
    ))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(
            [
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                cfg.COLOR_JITTOR,
                T.Normalize(mean=128),  # uint8 input minus 128 mean (see README)
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        # Per-step schedule: "Linear" decays every step toward 0,
        # "Multistep" multiplies by gamma at each configured epoch boundary.
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")
    t = time.time()
    for step in range(0, total_steps):
        # Linear learning rate decay
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()

        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch, step, learning_rate,
                objs, top1, top5, total_time
            )
            # Meters are reset each report, so logs show per-window averages.
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {"step": step, "state_dict": model.state_dict()},
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            # Periodic validation on every rank (losses are all-reduced inside).
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Final checkpoint and validation after the last training step.
    mge.save(
        {"step": step, "state_dict": model.state_dict()},
        os.path.join(save_dir, "checkpoint-final.pkl")
    )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
def infer(model, data_queue, args):
    """Evaluate *model* (a traced eval function) over *data_queue*.

    Returns the running averages ``(loss, top1, top5)`` across the queue.
    """
    loss_meter = AverageMeter("Loss")
    acc1_meter = AverageMeter("Acc@1")
    acc5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    tick = time.time()
    for step, (image, label) in enumerate(data_queue):
        batch = image.shape[0]
        # The pipeline yields uint8 images; convert before feeding the graph.
        loss, acc1, acc5 = model(image.astype("float32"), label.astype("int32"))

        loss_meter.update(loss.numpy()[0], batch)
        acc1_meter.update(100 * acc1.numpy()[0], batch)
        acc5_meter.update(100 * acc5.numpy()[0], batch)
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.report_freq == 0 and dist.get_rank() == 0:
            logger.info("Step %d, %s %s %s %s",
                        step, loss_meter, acc1_meter, acc5_meter, time_meter)

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg
class AverageMeter:
    """Tracks the most recent value and running average of a metric series."""

    def __init__(self, name, fmt=":.3f"):
        self.name = name
        self.fmt = fmt  # format spec applied to val/avg in __str__
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
if __name__ == "__main__":
    main()  # script entry point: parse args and launch finetuning
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Run single-image inference with a (possibly int8-quantized) classification model and optionally dump it for deployment"""
import argparse
import json
import cv2
import megengine as mge
import megengine.data.transform as T
import megengine.functional as F
import megengine.jit as jit
import megengine.quantization as Q
import numpy as np
import models
logger = mge.get_logger(__name__)
def main():
    """Classify one image with the selected model/mode and optionally dump
    the traced quantized network for deployment."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument("-m", "--mode", default="quantized", type=str,
                        choices=["normal", "qat", "quantized"],
                        help="Quantization Mode\n"
                             "normal: no quantization, using float32\n"
                             "qat: quantization aware training, simulate int8\n"
                             "quantized: convert mode to int8 quantized, inference only")
    parser.add_argument("--dump", action="store_true",
                        help="Dump quantized model")
    args = parser.parse_args()

    if args.mode == "quantized":
        # int8 inference runs on CPU.
        mge.set_default_device("cpux")

    model = models.__dict__[args.arch]()
    if args.mode != "normal":
        # Rebuild the QAT graph so the checkpoint's quant parameters can load.
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)  # convert fake-quant modules to true int8

    # Default to the bundled cat image when none is given.
    if args.image is None:
        path = "../assets/cat.jpg"
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)

    transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),  # uint8 input minus 128 mean (see README)
            T.ToMode("CHW"),
        ]
    )

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch dimension of 1.
    processed_img = transform.apply(image)[np.newaxis, :]

    if args.mode == "normal":
        processed_img = processed_img.astype("float32")
    elif args.mode == "quantized":
        processed_img = processed_img.astype("int8")
    # NOTE(review): "qat" mode keeps the transform's output dtype unchanged —
    # confirm that is the dtype the QAT graph expects.

    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    if args.dump:
        output_file = ".".join([args.arch, args.mode, "megengine"])
        logger.info("Dump to {}".format(output_file))
        infer_func.dump(output_file, arg_names=["data"])
        mge.save(model.state_dict(), output_file.replace("megengine", "pkl"))

    with open("../assets/imagenet_class_info.json") as fp:
        imagenet_class_index = json.load(fp)

    # Print the top-5 predicted classes with their probabilities.
    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )
if __name__ == "__main__":
    main()  # script entry point: run single-image inference
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .resnet import *
from .shufflenet import *
from .mobilenet_v2 import *
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR