未验证 提交 b1e8fd8a 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

feat(quantization): Model Quantization (#17)

* add quantization codebase for resnet, shufflenetv1, and mobilenetv2
Co-authored-by: NKevin.W <xsw200831720@gmail.com>
Co-authored-by: Nwangshupeng <wangshupeng@megvii.com>
Co-authored-by: X-SPY's avatarLiZhiyuan <848796515@qq.com>
Co-authored-by: Nwangjianfeng <wangjianfeng@megvii.com>
上级 f238d145
模型量化 Model Quantization
---
本目录包含了采用MegEngine实现的量化训练和部署的代码,包括常用的ResNet、ShuffleNet和MobileNet,其量化模型的ImageNet Top 1 准确率如下:
| Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) |
| --- | --- | --- | --- | --- |
| ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 |
| ShufflenetV1 (1.5x) | 71.954 | 17.3 | | 25.3 |
| MobilenetV2 | 72.820 | 13.1 | | 17.4 |
\*: FPS is measured on an Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz with a single 224x224 image.
量化模型使用时,统一读取0-255的uint8图片,减去128的均值,转化为int8,输入网络。
## Quantization Aware Training (QAT)
```python
import megengine.quantization as Q
model = ...
# Quantization Aware Training
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
for _ in range(...):
    train(model)
```
## Deploying Quantized Model
```python
import megengine.quantization as Q
import megengine.jit as jit
model = ...
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
# real quant
Q.quantize(model)
@jit.trace(symbolic=True)
def inference_func(x):
    return model(x)
inference_func.dump(...)
```
# HOWTO use this codebase
## Step 1. Train a fp32 model
```
python3 train.py -a resnet18 -d /path/to/imagenet --mode normal
```
## Step 2. Finetune fp32 model with quantization aware training(QAT)
```
python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl --mode qat
```
## Step 3. Test QAT model on ImageNet Testset
```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode qat
```
Alternatively, test in quantized mode, which runs inference on CPU only and takes longer:
```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized -n 1
```
## Step 4. Inference and dump
```
python3 inference.py -a resnet18 --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized --dump
```
This feeds a cat image to the quantized network and prints the classification probabilities.
Additionally, passing `--dump` dumps the quantized network to the `resnet18.quantized.megengine` binary file.
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
Configurations to train/finetune quantized classification models
"""
import megengine.data.transform as T
class ShufflenetConfig:
    """Hyper-parameters that ``get_config`` selects for ShuffleNet and MobileNet archs."""

    BATCH_SIZE = 128
    LEARNING_RATE = 0.0625
    MOMENTUM = 0.9
    EPOCHS = 240
    SCHEDULER = "Linear"
    COLOR_JITTOR = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)

    def WEIGHT_DECAY(self, n, p):
        """Return the weight decay to apply to parameter ``p`` named ``n``.

        A parameter is decayed (4e-5) only when its name contains "weight"
        and its shape has more than one dimension; otherwise 0 is returned.
        """
        if "weight" in n and len(p.shape) > 1:
            return 4e-5
        return 0
class ResnetConfig:
    """Hyper-parameters that ``get_config`` selects for ResNet and ResNeXt archs."""

    # Optimizer settings
    BATCH_SIZE = 32
    LEARNING_RATE = 0.0125
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4

    # Schedule: multiply the learning rate by SCHEDULER_GAMMA at each listed epoch
    EPOCHS = 90
    SCHEDULER = "Multistep"
    SCHEDULER_GAMMA = 0.1
    SCHEDULER_STEPS = [30, 60, 80]

    # Augmentation: a pseudo transform is used in place of color jitter
    COLOR_JITTOR = T.PseudoTransform()
def get_config(arch: str):
    """Return a fresh training config instance for the architecture name ``arch``.

    Names containing "resne" (resnet and resnext) map to ``ResnetConfig``;
    names containing "shufflenet" or "mobilenet" map to ``ShufflenetConfig``.

    Raises:
        ValueError: if ``arch`` matches none of the known families.
    """
    if "resne" in arch:
        return ResnetConfig()
    if any(family in arch for family in ("shufflenet", "mobilenet")):
        return ShufflenetConfig()
    raise ValueError("config for {} not exists".format(arch))
class ShufflenetFinetuneConfig(ShufflenetConfig):
    """Finetuning overrides on top of ``ShufflenetConfig``."""

    EPOCHS = 120
    LEARNING_RATE = 0.03125
    BATCH_SIZE = 64  # i.e. 128 // 2
class ResnetFinetuneConfig(ResnetConfig):
    """Finetuning overrides on top of ``ResnetConfig``."""

    BATCH_SIZE = 32
    LEARNING_RATE = 0.000125

    # Short multistep schedule with a single decay at epoch 6
    EPOCHS = 12
    SCHEDULER = "Multistep"
    SCHEDULER_GAMMA = 0.1
    SCHEDULER_STEPS = [6]
def get_finetune_config(arch: str):
    """Return a fresh finetuning config instance for the architecture name ``arch``.

    Names containing "resne" (resnet and resnext) map to
    ``ResnetFinetuneConfig``; names containing "shufflenet" or "mobilenet"
    map to ``ShufflenetFinetuneConfig``.

    Raises:
        ValueError: if ``arch`` matches none of the known families.
    """
    if "resne" in arch:
        return ResnetFinetuneConfig()
    if any(family in arch for family in ("shufflenet", "mobilenet")):
        return ShufflenetFinetuneConfig()
    raise ValueError("config for {} not exists".format(arch))
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Finetune a pretrained fp32 with int8 quantization aware training(QAT)"""
import argparse
import collections
import multiprocessing as mp
import numbers
import os
import bisect
import time
import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim
import megengine.quantization as Q
import config
import models
# Module-level logger obtained from MegEngine, named after this module.
logger = mge.get_logger(__name__)
def main():
    """Parse the command line and run ``worker`` in one or several processes.

    When the number of workers (``--ngpus``, or the GPU count reported by
    MegEngine when it is omitted) is greater than one, one sub-process per
    rank is spawned and joined; otherwise ``worker`` runs in this process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", type=str, default="resnet18")
    parser.add_argument("-d", "--data", type=str, default=None)
    parser.add_argument("-s", "--save", type=str, default="/data/models")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None,
        help="pretrained model to finetune",
    )
    parser.add_argument(
        "-m", "--mode", type=str, default="qat",
        choices=["normal", "qat", "quantized"],
        help="Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only",
    )
    parser.add_argument("-n", "--ngpus", type=int, default=None)
    parser.add_argument("-w", "--workers", type=int, default=4)
    parser.add_argument("--report-freq", type=int, default=50)
    args = parser.parse_args()

    if args.ngpus is None:
        world_size = mge.get_device_count("gpu")
    else:
        world_size = args.ngpus

    if not world_size > 1:
        # Single-process run
        worker(0, 1, args)
        return

    # Distributed run: dispatch one sub-process per rank and wait for all
    mp.set_start_method("spawn")
    children = [
        mp.Process(target=worker, args=(rank, world_size, args))
        for rank in range(world_size)
    ]
    for child in children:
        child.start()
    for child in children:
        child.join()
def get_parameters(model, cfg):
    """Build the parameter specification handed to the optimizer.

    Args:
        model: object exposing ``parameters(requires_grad=True)`` and
            ``named_parameters(requires_grad=True)``.
        cfg: config whose ``WEIGHT_DECAY`` is either a number or a callable
            ``(name, param) -> weight_decay``.

    Returns:
        A single ``{"params", "weight_decay"}`` dict when ``WEIGHT_DECAY`` is a
        number; otherwise a list of such dicts, one per distinct weight decay
        value, in order of first appearance.
    """
    weight_decay = cfg.WEIGHT_DECAY
    if isinstance(weight_decay, numbers.Number):
        return {
            "params": model.parameters(requires_grad=True),
            "weight_decay": weight_decay,
        }

    # weight_decay value -> parameters sharing that value
    by_decay = {}
    for name, param in model.named_parameters(requires_grad=True):
        by_decay.setdefault(weight_decay(name, param), []).append(param)

    return [
        {"params": members, "weight_decay": decay}
        for decay, members in by_decay.items()
    ]
def worker(rank, world_size, args):
    """Run the finetuning loop for one process.

    Args:
        rank: index of this process; also passed as the device index when the
            distributed process group is initialised.
        world_size: total number of processes taking part in training.
        args: parsed command-line namespace (uses ``arch``, ``data``, ``save``,
            ``checkpoint``, ``mode``, ``workers`` and ``report_freq``).

    Raises:
        ValueError: if ``args.mode`` is "quantized" (inference only), or if
            the config names an unknown learning-rate scheduler.
    """
    # pylint: disable=too-many-statements
    # Fail fast: a truly quantized model cannot be trained, so reject the mode
    # before building the model or touching the filesystem.
    if args.mode == "quantized":
        raise ValueError("mode = quantized only used during inference")

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    # exist_ok makes a separate existence check unnecessary (and race-free
    # when several ranks get here at the same time)
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        # Accept both a bare state dict and a {"state_dict": ...} checkpoint
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True
    ))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(
            [
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                cfg.COLOR_JITTOR,
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the learning rate for ``step`` / ``epoch``."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            # Decays linearly towards 0 over total_steps
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            # One factor of gamma per milestone epoch already reached
            learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Update the learning rate according to the configured scheduler
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch, step, learning_rate,
                objs, top1, top5, total_time
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {"step": step, "state_dict": model.state_dict()},
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes the final checkpoint, matching the periodic save
    # above, so several processes never write the same file concurrently.
    if rank == 0:
        mge.save(
            {"step": step, "state_dict": model.state_dict()},
            os.path.join(save_dir, "checkpoint-final.pkl")
        )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
def infer(model, data_queue, args):
    """Evaluate ``model`` over every batch of ``data_queue``.

    Args:
        model: callable taking ``(image, label)`` and returning
            ``(loss, acc1, acc5)`` objects that expose ``.numpy()``.
        data_queue: iterable of ``(image, label)`` batches.
        args: namespace providing ``report_freq``.

    Returns:
        Tuple ``(average loss, average top-1 in percent, average top-5 in percent)``.
    """
    loss_meter = AverageMeter("Loss")
    acc1_meter = AverageMeter("Acc@1")
    acc5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    tick = time.time()
    for step, (image, label) in enumerate(data_queue):
        batch_size = image.shape[0]
        # The network expects float32 images and int32 labels
        image = image.astype("float32")
        label = label.astype("int32")

        loss, acc1, acc5 = model(image, label)

        loss_meter.update(loss.numpy()[0], batch_size)
        acc1_meter.update(100 * acc1.numpy()[0], batch_size)
        acc5_meter.update(100 * acc5.numpy()[0], batch_size)
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.report_freq == 0 and dist.get_rank() == 0:
            logger.info("Step %d, %s %s %s %s",
                        step, loss_meter, acc1_meter, acc5_meter, time_meter)

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg
class AverageMeter:
    """Track the latest value and the weighted running average of a metric.

    Args:
        name: label printed in front of the values.
        fmt: format spec (including the leading colon) applied to both the
            latest value and the average when the meter is converted to str.
    """

    def __init__(self, name, fmt=":.3f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch right after reset) must not
        # raise ZeroDivisionError; report an average of 0 in that case.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
# Run the finetuning entry point only when executed as a script.
if __name__ == "__main__":
    main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Run a float32, QAT or int8 quantized classification model on a single image and optionally dump it"""
import argparse
import json
import cv2
import megengine as mge
import megengine.data.transform as T
import megengine.functional as F
import megengine.jit as jit
import megengine.quantization as Q
import numpy as np
import models
# Module-level logger obtained from MegEngine, named after this module.
logger = mge.get_logger(__name__)
def main():
    """Classify one image with the selected model and optionally dump the model.

    Builds the architecture named by ``--arch``, prepares it for the chosen
    ``--mode`` (normal / qat / quantized), optionally loads a checkpoint,
    runs the network on ``--image`` (or the bundled cat picture) and prints
    the top-5 classes. With ``--dump`` the traced function and the state
    dict are also written to disk.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-c", "--checkpoint", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument("-m", "--mode", default="quantized", type=str,
                        choices=["normal", "qat", "quantized"],
                        help="Quantization Mode\n"
                             "normal: no quantization, using float32\n"
                             "qat: quantization aware training, simulate int8\n"
                             "quantized: convert mode to int8 quantized, inference only")
    parser.add_argument("--dump", action="store_true",
                        help="Dump quantized model")
    args = parser.parse_args()

    if args.mode == "quantized":
        mge.set_default_device("cpux")

    model = models.__dict__[args.arch]()

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        # Accept both a bare state dict and a {"state_dict": ...} checkpoint
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    path = "../assets/cat.jpg" if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports a missing or undecodable file by returning None;
        # stop here with a clear message instead of failing inside the transform.
        raise FileNotFoundError("cannot read image from {}".format(path))

    transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]
    )

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis to the single transformed image
    processed_img = transform.apply(image)[np.newaxis, :]
    if args.mode == "normal":
        processed_img = processed_img.astype("float32")
    elif args.mode == "quantized":
        processed_img = processed_img.astype("int8")
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    if args.dump:
        output_file = ".".join([args.arch, args.mode, "megengine"])
        logger.info("Dump to {}".format(output_file))
        infer_func.dump(output_file, arg_names=["data"])
        mge.save(model.state_dict(), output_file.replace("megengine", "pkl"))

    with open("../assets/imagenet_class_info.json") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )
# Run the inference entry point only when executed as a script.
if __name__ == "__main__":
    main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .resnet import *
from .shufflenet import *
from .mobilenet_v2 import *
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
import megengine.module as M
# Names exported by a star-import of this module.
__all__ = ['MobileNetV2', 'mobilenet_v2']
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class InvertedResidual(M.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise, 1x1 linear projection.

    The input is added to the output only when ``stride`` is 1 and the
    input and output channel counts are equal.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        # pw: the expansion conv is skipped when there is nothing to expand
        expansion = (
            [M.ConvBnRelu2d(inp, hidden_dim, kernel_size=1, bias=False)]
            if expand_ratio != 1
            else []
        )
        # dw
        depthwise = M.ConvBnRelu2d(
            hidden_dim, hidden_dim, kernel_size=3, padding=1,
            stride=stride, groups=hidden_dim, bias=False,
        )
        # pw-linear
        projection = M.ConvBn2d(hidden_dim, oup, kernel_size=1, bias=False)

        self.conv = M.Sequential(*expansion, depthwise, projection)
        self.add = M.Elemwise("ADD")

    def forward(self, x):
        if not self.use_res_connect:
            return self.conv(x)
        return self.add(x, self.conv(x))
class MobileNetV2(M.Module):
    """MobileNet V2 classifier with quant/dequant stubs around the feature extractor."""

    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        Build the network.

        Args:
            num_classes (int): Number of output classes
            width_mult (float): Width multiplier applied to the channel count of each layer
            inverted_residual_setting: Network structure as a list of
                ``[t, c, n, s]`` rows (expand ratio, channels, repeats, stride)
            round_nearest (int): Channel counts are rounded to a multiple of this number;
                set to 1 to turn off rounding
        """
        super().__init__()
        block = InvertedResidual

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # stem width and head width (the head is never narrowed below 1280)
        input_channel = _make_divisible(32 * width_mult, round_nearest)
        self.last_channel = _make_divisible(1280 * max(1.0, width_mult), round_nearest)

        # first layer
        features = [M.ConvBnRelu2d(3, input_channel, kernel_size=3, padding=1, stride=2, bias=False)]

        # inverted residual blocks: only the first block of a stage uses stride s
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                block_stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, block_stride, expand_ratio=t))
                input_channel = output_channel

        # last 1x1 conv
        features.append(M.ConvBnRelu2d(input_channel, self.last_channel, kernel_size=1, bias=False))
        self.features = M.Sequential(*features)

        # classifier
        self.classifier = M.Sequential(
            M.Dropout(0.2),
            M.Linear(self.last_channel, num_classes),
        )

        self.quant = M.QuantStub()
        self.dequant = M.DequantStub()

        # weight initialization
        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    M.init.zeros_(m.bias)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.normal_(m.weight, 0, 0.01)
                M.init.zeros_(m.bias)

    def forward(self, x):
        feats = self.features(self.quant(x))
        pooled = F.flatten(F.avg_pool2d(feats, 7), 1)
        # the classifier is applied after the dequant stub
        return self.classifier(self.dequant(pooled))
def mobilenet_v2(**kwargs):
    """
    Build a MobileNetV2 as described in
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    All keyword arguments are forwarded to the ``MobileNetV2`` constructor.
    """
    return MobileNetV2(**kwargs)
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""ResNet optimized for quantization, identical after modification."""
import math
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
class BasicBlock(M.Module):
    """Two 3x3 conv residual block built from fused conv-bn(-relu) modules.

    Only ``groups=1``, ``base_width=64``, ``dilation=1`` and
    ``norm=M.BatchNorm2d`` are accepted.
    """

    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        assert norm is M.BatchNorm2d, 'Quant mode only support BatchNorm2d currently.'
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.conv_bn_relu1 = M.ConvBnRelu2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)

        # A 1x1 projection is needed on the shortcut whenever the shape changes
        shape_changes = in_channels != channels or stride != 1
        if shape_changes:
            self.downsample = M.ConvBn2d(in_channels, channels, 1, stride, bias=False)
        else:
            self.downsample = M.Identity()
        self.add = M.Elemwise("ADD")

    def forward(self, x):
        out = self.conv_bn2(self.conv_bn_relu1(x))
        shortcut = self.downsample(x)
        return F.relu(self.add(out, shortcut))
class Bottleneck(M.Module):
    """1x1 -> 3x3 -> 1x1 residual block built from fused conv-bn(-relu) modules.

    The output has ``channels * expansion`` channels. The ``norm`` argument
    is accepted but not used by this block.
    """

    expansion = 4

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        width = int(channels * (base_width / 64.0)) * groups
        out_channels = channels * self.expansion

        self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, width, 1, 1, bias=False)
        self.conv_bn_relu2 = M.ConvBnRelu2d(
            width, width, 3, stride,
            padding=dilation, groups=groups, dilation=dilation, bias=False,
        )
        self.conv_bn3 = M.ConvBn2d(width, out_channels, 1, 1, bias=False)

        # A 1x1 projection is needed on the shortcut whenever the shape changes
        shape_changes = in_channels != out_channels or stride != 1
        if shape_changes:
            self.downsample = M.ConvBn2d(in_channels, out_channels, 1, stride, bias=False)
        else:
            self.downsample = M.Identity()
        self.add = M.Elemwise("ADD")

    def forward(self, x):
        out = self.conv_bn_relu1(x)
        out = self.conv_bn_relu2(out)
        out = self.conv_bn3(out)
        shortcut = self.downsample(x)
        return F.relu(self.add(out, shortcut))
class ResNet(M.Module):
    """ResNet / ResNeXt classifier built from fused conv-bn(-relu) modules.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: size of the final linear layer's output.
        zero_init_residual: zero the weight of the last batch norm of every
            residual branch so each block starts as an identity mapping.
        groups: convolution groups forwarded to every block.
        width_per_group: base width forwarded to every block.
        replace_stride_with_dilation: ``None`` or three flags telling whether
            the stride of stages 2-4 is replaced by dilation.
        norm: normalisation class forwarded to every block.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.quant = M.QuantStub()
        self.dequant = M.DequantStub()
        self.conv_bn_relu1 = M.ConvBnRelu2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            norm=norm,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            norm=norm,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            norm=norm,
        )
        self.fc = M.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # The blocks in this file only expose fused `conv_bn3` /
                # `conv_bn2` modules (there is no `bn3` / `bn2` attribute),
                # so reach the batch norm through the fused module's `bn`.
                if isinstance(m, Bottleneck):
                    M.init.zeros_(m.conv_bn3.bn.weight)
                elif isinstance(m, BasicBlock):
                    M.init.zeros_(m.conv_bn2.bn.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        """Build one stage of ``blocks`` residual blocks and update ``in_channels``."""
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride of this stage for dilation
            self.dilation *= stride
            stride = 1

        layers = []
        # Only the first block of the stage applies the stride
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        """Return a dict with the outputs of the stem and the four stages."""
        outputs = {}
        x = self.conv_bn_relu1(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x

        return outputs

    def forward(self, x):
        # FIXME whenever finding elegant solution
        x = self.quant(x)
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        # the final linear layer is applied after the dequant stub
        x = self.dequant(x)
        x = self.fc(x)

        return x
def resnet18(**kwargs):
    r"""Build a ResNet-18 as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)
def resnet34(**kwargs):
    r"""Build a ResNet-34 as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)
def resnet50(**kwargs):
    r"""Build a ResNet-50 as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(Bottleneck, blocks_per_stage, **kwargs)
def resnet101(**kwargs):
    r"""Build a ResNet-101 as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [3, 4, 23, 3]
    return ResNet(Bottleneck, blocks_per_stage, **kwargs)
def resnet152(**kwargs):
    r"""Build a ResNet-152 as described in
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [3, 8, 36, 3]
    return ResNet(Bottleneck, blocks_per_stage, **kwargs)
def resnext50_32x4d(**kwargs):
    r"""Build a ResNeXt-50 32x4d as described in
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Keyword arguments are forwarded to ``ResNet``; ``groups`` and
    ``width_per_group`` are always overwritten with 32 and 4.
    """
    kwargs.update(groups=32, width_per_group=4)
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnext101_32x8d(**kwargs):
    r"""Build a ResNeXt-101 32x8d as described in
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``groups`` is always set to 32 and ``width_per_group`` to 8, replacing
    any values the caller supplied for those two names. All remaining
    keyword arguments are forwarded unchanged to ``ResNet``.
    """
    overrides = {"groups": 32, "width_per_group": 8}
    return ResNet(Bottleneck, [3, 4, 23, 3], **{**kwargs, **overrides})
# -*- coding: utf-8 -*-
# MIT License
#
# Copyright (c) 2019 Megvii Technology
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.module import (
BatchNorm2d,
Conv2d,
ConvBn2d,
ConvBnRelu2d,
AvgPool2d,
MaxPool2d,
DequantStub,
Linear,
Module,
QuantStub,
Sequential,
MaxPool2d,
Sequential,
Elemwise,
)
from megengine.quantization import *
class ShuffleV1Block(Module):
    """ShuffleNet V1 unit assembled from fused conv-bn modules.

    The main branch is a pointwise ``ConvBnRelu2d``, a depthwise ``ConvBn2d``,
    an optional channel shuffle (only when ``group > 1``) and a final
    pointwise ``ConvBn2d``. Its output is merged with the shortcut through
    the ``FUSE_ADD_RELU`` elemwise module. With ``stride == 2`` the shortcut
    is a ``ConvBn2d`` projection of the input, with ``stride == 1`` it is the
    input itself.
    """

    def __init__(self, inp, oup, *, group, first_group, mid_channels, ksize, stride):
        """Create the sub-modules of the unit.

        Args:
            inp: number of input channels.
            oup: number of output channels.
            group: ``groups`` value for the pointwise convolutions.
            first_group: when true, the first pointwise convolution uses
                ``groups=1`` instead of ``group``.
            mid_channels: channel count between the pointwise convolutions.
            ksize: kernel size of the depthwise convolution.
            stride: stride of the depthwise convolution; must be 1 or 2.
        """
        super().__init__()
        self.stride = stride
        assert stride in (1, 2)
        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp
        self.group = group

        pw_groups = 1 if first_group else group
        self.branch_main_1 = Sequential(
            # pw
            ConvBnRelu2d(inp, mid_channels, 1, 1, 0, groups=pw_groups, bias=False),
            # dw
            ConvBn2d(
                mid_channels, mid_channels, ksize, stride, pad,
                groups=mid_channels, bias=False,
            ),
        )
        self.branch_main_2 = Sequential(
            # pw-linear
            ConvBn2d(mid_channels, oup, 1, 1, 0, groups=group, bias=False),
        )
        self.add = Elemwise('FUSE_ADD_RELU')

        # The projection shortcut only exists for stride-2 units.
        if stride == 2:
            self.branch_proj = ConvBn2d(inp, oup, 1, 2, 0, bias=False)

    def forward(self, old_x):
        """Apply the unit to ``old_x`` and return the merged result."""
        shortcut = old_x
        out = self.branch_main_1(old_x)
        if self.group > 1:
            out = self.channel_shuffle(out)
        out = self.branch_main_2(out)

        if self.stride == 1:
            return self.add(out, shortcut)
        if self.stride == 2:
            return self.add(self.branch_proj(shortcut), out)

    def channel_shuffle(self, x):
        """Permute the channels of the 4-D tensor ``x`` across ``self.group``.

        The channel axis is split into ``(channels // group, group)``, the
        two new axes are swapped, and the tensor is reshaped back to its
        original shape. The channel count is expected to be divisible by
        ``self.group``; this is not checked here.
        """
        n, c, h, w = x.shape
        per_group = c // self.group
        split = x.reshape(n, per_group, self.group, h, w)
        swapped = split.dimshuffle(0, 2, 1, 3, 4)
        return swapped.reshape(n, c, h, w)
class ShuffleNetV1(Module):
    """ShuffleNet V1 classifier with quantization stubs.

    The network is ``QuantStub`` -> first conv -> max-pool -> a sequence of
    ``ShuffleV1Block`` units -> average pooling (window 7) -> flatten ->
    ``DequantStub`` -> bias-free ``Linear`` classifier.
    """

    # group -> model_size -> channel table. Index 0 is an unused placeholder,
    # index 1 is the first-conv width, indices 2.. are the stage widths.
    _STAGE_OUT_CHANNELS = {
        3: {
            '0.5x': (-1, 12, 120, 240, 480),
            '1.0x': (-1, 24, 240, 480, 960),
            '1.5x': (-1, 24, 360, 720, 1440),
            '2.0x': (-1, 48, 480, 960, 1920),
        },
        8: {
            '0.5x': (-1, 16, 192, 384, 768),
            '1.0x': (-1, 24, 384, 768, 1536),
            '1.5x': (-1, 24, 576, 1152, 2304),
            '2.0x': (-1, 48, 768, 1536, 3072),
        },
    }

    def __init__(self, num_classes=1000, model_size='2.0x', group=None):
        """Build the network.

        Args:
            num_classes: output size of the classifier.
            model_size: one of ``'0.5x'``, ``'1.0x'``, ``'1.5x'``, ``'2.0x'``.
            group: ``groups`` value for the pointwise convolutions; must be
                3 or 8.

        Raises:
            AssertionError: if ``group`` is ``None``.
            NotImplementedError: if the ``group`` / ``model_size`` pair has
                no channel table.
        """
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)

        assert group is not None

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        try:
            channels = self._STAGE_OUT_CHANNELS[group][model_size]
        except KeyError:
            # An unsupported group used to surface later as an AttributeError
            # on ``stage_out_channels``; report it like a bad model_size.
            raise NotImplementedError(
                "unsupported ShuffleNetV1 configuration: group={!r}, "
                "model_size={!r}".format(group, model_size)
            ) from None
        # Copy so that instances never share one mutable list.
        self.stage_out_channels = list(channels)

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = Sequential(
            ConvBnRelu2d(3, input_channel, 3, 2, 1, bias=False)
        )
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)

        blocks = []
        for idxstage, numrepeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                # The first unit of each stage uses stride 2; only the very
                # first unit of the network is built with first_group=True.
                blocks.append(ShuffleV1Block(
                    input_channel, output_channel,
                    group=group,
                    first_group=(idxstage == 0 and i == 0),
                    mid_channels=output_channel // 4,
                    ksize=3,
                    stride=2 if i == 0 else 1,
                ))
                input_channel = output_channel
        self.features = Sequential(*blocks)

        self.quant = QuantStub()
        self.dequant = DequantStub()
        self.classifier = Sequential(
            Linear(self.stage_out_channels[-1], num_classes, bias=False)
        )
        self._initialize_weights()

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.quant(x)
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.dequant(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear sub-modules in place."""
        for name, m in self.named_modules():
            if isinstance(m, M.Conv2d):
                # Modules whose name contains "first" use a fixed std; the
                # others scale the std by the second weight dimension.
                if "first" in name:
                    M.init.normal_(m.weight, 0, 0.01)
                else:
                    M.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)
            elif isinstance(m, (M.BatchNorm2d, M.BatchNorm1d)):
                M.init.fill_(m.weight, 1)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0.0001)
                M.init.fill_(m.running_mean, 0)
            elif isinstance(m, M.Linear):
                M.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)
def shufflenet_v1_x0_5_g3(num_classes=1000):
    """Build a ``ShuffleNetV1`` with ``model_size="0.5x"`` and ``group=3``."""
    return ShuffleNetV1(num_classes=num_classes, model_size="0.5x", group=3)
def shufflenet_v1_x1_0_g3(num_classes=1000):
    """Build a ``ShuffleNetV1`` with ``model_size="1.0x"`` and ``group=3``."""
    return ShuffleNetV1(num_classes=num_classes, model_size="1.0x", group=3)
def shufflenet_v1_x1_5_g3(num_classes=1000):
    """Build a ``ShuffleNetV1`` with ``model_size="1.5x"`` and ``group=3``."""
    return ShuffleNetV1(num_classes=num_classes, model_size="1.5x", group=3)
def shufflenet_v1_x2_0_g3(num_classes=1000):
    """Build a ``ShuffleNetV1`` with ``model_size="2.0x"`` and ``group=3``."""
    return ShuffleNetV1(num_classes=num_classes, model_size="2.0x", group=3)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Test an int8 quantized model on ImageNet.
Note:
* QAT simulates int8 with fp32, GPU only.
* Quantized mode uses real int8, CPU only, and is a bit slow.
* Results may differ slightly between qat and quantized mode.
"""
import argparse
import multiprocessing as mp
import time
import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.quantization as Q
import models
logger = mge.get_logger(__name__)
def main():
    """Parse the command line and run ``worker`` in one or several processes.

    The process count is ``--ngpus`` when given, otherwise the GPU count
    reported by MegEngine. In ``quantized`` mode a single process is forced,
    the default device is set to ``cpux`` and progress is reported every step.
    """
    mode_help = (
        "Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only"
    )
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-d", "--data", default=None, type=str)
    parser.add_argument("-s", "--save", default="/data/models", type=str)
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str,
        help="pretrained model to finetune",
    )
    parser.add_argument(
        "-m", "--mode", default="qat", type=str,
        choices=["normal", "qat", "quantized"],
        help=mode_help,
    )
    parser.add_argument("-n", "--ngpus", default=None, type=int)
    parser.add_argument("-w", "--workers", default=4, type=int)
    parser.add_argument("--report-freq", default=50, type=int)
    args = parser.parse_args()

    if args.ngpus is None:
        world_size = mge.get_device_count("gpu")
    else:
        world_size = args.ngpus

    if args.mode == "quantized":
        world_size = 1
        args.report_freq = 1  # test is slow on cpu
        mge.set_default_device("cpux")
        logger.warning("quantized mode use cpu only")

    if world_size <= 1:
        worker(0, 1, args)
        return

    # Distributed run: dispatch one sub-process per rank and wait for all.
    mp.set_start_method("spawn")
    children = []
    for rank in range(world_size):
        child = mp.Process(target=worker, args=(rank, world_size, args))
        child.start()
        children.append(child)
    for child in children:
        child.join()
def worker(rank, world_size, args):
    """Evaluate ``args.arch`` on the ImageNet validation split.

    Args:
        rank: index of this process; also used as ``dev`` when the
            distributed process group is initialised.
        world_size: total number of evaluation processes.
        args: parsed command-line namespace; ``arch``, ``mode``,
            ``checkpoint``, ``data``, ``workers`` and (through ``infer``)
            ``report_freq`` are read.
    """
    # pylint: disable=too-many-statements
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model = models.__dict__[args.arch]()

    # Every mode other than "normal" needs the QAT structure in place
    # before the checkpoint is loaded.
    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        loaded = mge.load(args.checkpoint)
        # The weights may be wrapped under a "state_dict" key.
        if "state_dict" in loaded:
            loaded = loaded["state_dict"]
        model.load_state_dict(loaded, strict=False)

    if args.mode == "quantized":
        Q.quantize(model)

    # Define valid graph
    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build valid datasets
    logger.info("preparing dataset..")
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=valid_transform,
        num_workers=args.workers,
    )

    results = infer(valid_func, valid_queue, args)
    valid_acc, valid_acc5 = results[1], results[2]
    logger.info("TEST %f, %f", valid_acc, valid_acc5)
def infer(model, data_queue, args):
    """Run ``model`` over ``data_queue`` and average its metrics.

    Args:
        model: callable taking ``(image, label)`` and returning
            ``(loss, acc1, acc5)``, each exposing ``.numpy()``.
        data_queue: iterable of ``(image, label)`` batches.
        args: namespace whose ``report_freq`` sets the logging interval.

    Returns:
        Tuple ``(average loss, average top-1 %, average top-5 %)``, weighted
        by batch size.
    """
    loss_meter = AverageMeter("Loss")
    top1_meter = AverageMeter("Acc@1")
    top5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    tick = time.time()
    for step, (image, label) in enumerate(data_queue):
        batch = image.shape[0]
        image = image.astype("float32")  # convert np.uint8 to float32
        label = label.astype("int32")

        loss, acc1, acc5 = model(image, label)

        loss_meter.update(loss.numpy()[0], batch)
        top1_meter.update(100 * acc1.numpy()[0], batch)
        top5_meter.update(100 * acc5.numpy()[0], batch)
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.report_freq != 0:
            continue
        # Only rank 0 reports progress.
        if dist.get_rank() == 0:
            logger.info("Step %d, %s %s %s %s",
                        step, loss_meter, top1_meter, top5_meter, time_meter)

    return loss_meter.avg, top1_meter.avg, top5_meter.avg
class AverageMeter:
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self, name, fmt=":.3f"):
        """Create an empty meter.

        Args:
            name: label printed by ``__str__``.
            fmt: format spec (including the leading ``:``) applied to both
                the latest value and the average.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the sum, the weight count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average.

        While the accumulated weight is zero (for example ``n=0`` on a fresh
        meter) the average is reported as 0 instead of dividing by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` formatted with ``self.fmt``."""
        template = "".join(
            ["{name} {val", self.fmt, "} ({avg", self.fmt, "})"]
        )
        return template.format(**vars(self))
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Train a model in fp32"""
import argparse
import collections
import multiprocessing as mp
import numbers
import os
import bisect
import time
import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim
import megengine.quantization as Q
import config
import models
logger = mge.get_logger(__name__)
def main():
    """Parse the command line and run ``worker`` in one or several processes.

    The process count is ``--ngpus`` when given, otherwise the GPU count
    reported by MegEngine.
    """
    mode_help = (
        "Quantization Mode\n"
        "normal: no quantization, using float32\n"
        "qat: quantization aware training, simulate int8\n"
        "quantized: convert mode to int8 quantized, inference only"
    )
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="resnet18", type=str)
    parser.add_argument("-d", "--data", default=None, type=str)
    parser.add_argument("-s", "--save", default="/data/models", type=str)
    parser.add_argument(
        "-m", "--mode", default="normal", type=str,
        choices=["normal", "qat", "quantized"],
        help=mode_help,
    )
    parser.add_argument("-n", "--ngpus", default=None, type=int)
    parser.add_argument("-w", "--workers", default=4, type=int)
    parser.add_argument("--report-freq", default=50, type=int)
    args = parser.parse_args()

    if args.ngpus is None:
        world_size = mge.get_device_count("gpu")
    else:
        world_size = args.ngpus

    if world_size <= 1:
        worker(0, 1, args)
        return

    # Distributed training: dispatch one sub-process per rank and wait.
    mp.set_start_method("spawn")
    children = []
    for rank in range(world_size):
        child = mp.Process(target=worker, args=(rank, world_size, args))
        child.start()
        children.append(child)
    for child in children:
        child.join()
def get_parameters(model, cfg):
    """Build the optimizer parameter group(s) for ``model``.

    Args:
        model: object providing ``parameters(requires_grad=True)`` and
            ``named_parameters(requires_grad=True)``.
        cfg: configuration whose ``WEIGHT_DECAY`` is either a number or a
            callable ``(param_name, param) -> weight_decay``.

    Returns:
        A single ``{"params", "weight_decay"}`` dict when ``WEIGHT_DECAY`` is
        a number; otherwise a list of such dicts, one per distinct
        weight-decay value, in order of first appearance.
    """
    decay = cfg.WEIGHT_DECAY
    if not isinstance(decay, numbers.Number):
        # Bucket the parameters by the weight decay chosen for each of them.
        buckets = {}  # weight_decay -> List[param]
        for name, param in model.named_parameters(requires_grad=True):
            buckets.setdefault(decay(name, param), []).append(param)
        return [
            {"params": members, "weight_decay": value}
            for value, members in buckets.items()
        ]

    return {
        "params": model.parameters(requires_grad=True),
        "weight_decay": decay,
    }
def worker(rank, world_size, args):
    """Train ``args.arch`` in this process and periodically evaluate it.

    Args:
        rank: index of this process; also used as ``dev`` when the
            distributed process group is initialised.
        world_size: total number of training processes.
        args: parsed command-line namespace; ``mode``, ``save``, ``arch``,
            ``data``, ``workers`` and ``report_freq`` are read.

    Raises:
        ValueError: if ``args.mode`` is ``"quantized"`` (inference only), or
            if ``cfg.SCHEDULER`` is neither ``"Linear"`` nor ``"Multistep"``.
    """
    # pylint: disable=too-many-statements
    if args.mode == "quantized":
        # Reject the inference-only mode before any process group, directory
        # or model is created.
        raise ValueError("mode = quantized only used during inference")

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # NOTE(review): 1280000 is presumably the rounded number of ImageNet
    # training images -- confirm before using another dataset.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True
    ))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(
            [
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                cfg.COLOR_JITTOR,
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set the scheduled learning rate on every param group and return it."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Learning rate decay according to cfg.SCHEDULER
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch, step, learning_rate,
                objs, top1, top5, total_time
            )
            # The meters cover one reporting window only.
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {"step": step, "state_dict": model.state_dict()},
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        # Evaluation runs on every rank, not only rank 0.
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # NOTE(review): the final checkpoint is written by every rank to the
    # same path; confirm whether this should be restricted to rank 0.
    mge.save(
        {"step": step, "state_dict": model.state_dict()},
        os.path.join(save_dir, "checkpoint-final.pkl")
    )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
def infer(model, data_queue, args):
    """Run ``model`` over ``data_queue`` and average its metrics.

    Args:
        model: callable taking ``(image, label)`` and returning
            ``(loss, acc1, acc5)``, each exposing ``.numpy()``.
        data_queue: iterable of ``(image, label)`` batches.
        args: namespace whose ``report_freq`` sets the logging interval.

    Returns:
        Tuple ``(average loss, average top-1 %, average top-5 %)``, weighted
        by batch size.
    """
    loss_meter = AverageMeter("Loss")
    top1_meter = AverageMeter("Acc@1")
    top5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    tick = time.time()
    for step, (image, label) in enumerate(data_queue):
        batch = image.shape[0]
        image = image.astype("float32")  # convert np.uint8 to float32
        label = label.astype("int32")

        loss, acc1, acc5 = model(image, label)

        loss_meter.update(loss.numpy()[0], batch)
        top1_meter.update(100 * acc1.numpy()[0], batch)
        top5_meter.update(100 * acc5.numpy()[0], batch)
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.report_freq != 0:
            continue
        # Only rank 0 reports progress.
        if dist.get_rank() == 0:
            logger.info("Step %d, %s %s %s %s",
                        step, loss_meter, top1_meter, top5_meter, time_meter)

    return loss_meter.avg, top1_meter.avg, top5_meter.avg
class AverageMeter:
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self, name, fmt=":.3f"):
        """Create an empty meter.

        Args:
            name: label printed by ``__str__``.
            fmt: format spec (including the leading ``:``) applied to both
                the latest value and the average.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the sum, the weight count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average.

        While the accumulated weight is zero (for example ``n=0`` on a fresh
        meter) the average is reported as 0 instead of dividing by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` formatted with ``self.fmt``."""
        template = "".join(
            ["{name} {val", self.fmt, "} ({avg", self.fmt, "})"]
        )
        return template.format(**vars(self))
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册