"""ShuffleNetV2 in PyTorch.

[1] Ningning Ma, Xiangyu Zhang, Hai-Tao Zheng, Jian Sun
    ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
    https://arxiv.org/abs/1807.11164
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def channel_split(x, split):
    """Split a tensor into two equal pieces along the channel dimension.

    Args:
        x: input tensor; dimension 1 is treated as the channel dimension.
        split: (int) channel size of each piece.

    Returns:
        A tuple ``(first, second)`` of tensors with ``split`` channels each.

    Raises:
        ValueError: if ``x`` does not have exactly ``2 * split`` channels.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if x.size(1) != split * 2:
        raise ValueError(
            f"expected {split * 2} channels to split into two halves of "
            f"{split}, got {x.size(1)}"
        )
    return torch.split(x, split, dim=1)


def channel_shuffle(x, groups):
    """Interleave the channels of ``groups`` consecutive channel groups.

    Args:
        x: 4-D input tensor unpacked as (batch, channels, height, width).
        groups: (int) number of branches whose channels are interleaved.

    Returns:
        A tensor with the same shape as ``x`` whose channels are reordered so
        that channel ``i`` of every group comes before channel ``i + 1`` of
        any group.

    Raises:
        ValueError: if the channel count is not divisible by ``groups``.
    """
    batch_size, channels, height, width = x.size()
    if channels % groups:
        raise ValueError(
            f"channels ({channels}) must be divisible by groups ({groups})"
        )
    channels_per_group = channels // groups

    # (B, G, C/G, H, W) -> swap the group and per-group axes -> flatten back.
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    # The channel count is known, so it is spelled out instead of using -1.
    return x.view(batch_size, channels, height, width)


def _pointwise_layers(channels):
    """Return [1x1 conv, BN, ReLU] that keeps ``channels`` unchanged."""
    return [
        nn.Conv2d(channels, channels, 1),
        nn.BatchNorm2d(channels),
        nn.ReLU(inplace=True),
    ]


def _depthwise_pointwise_layers(channels, out_channels, stride):
    """Return [3x3 depthwise conv, BN, 1x1 conv to ``out_channels``, BN, ReLU]."""
    return [
        nn.Conv2d(channels, channels, 3, stride=stride, padding=1, groups=channels),
        nn.BatchNorm2d(channels),
        nn.Conv2d(channels, out_channels, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


class ShuffleUnit(nn.Module):
    """ShuffleNetV2 building block.

    Two variants are built depending on the arguments:

    * ``stride != 1`` or ``in_channels != out_channels``: both branches see
      the full input and each emits ``out_channels // 2`` channels.
    * otherwise: the input is split in half along channels; one half passes
      through unchanged and the other goes through the residual branch.

    In both cases the branch outputs are concatenated and channel-shuffled
    with two groups.

    Args:
        in_channels: (int) number of input channels.
        out_channels: (int) number of output channels; each branch produces
            ``out_channels // 2``, so an even value is expected.
        stride: (int) stride of the 3x3 depthwise convolutions.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Sub-module order is kept stable so parameter names in the
        # state_dict (residual.0, residual.1, ...) do not change.
        if stride != 1 or in_channels != out_channels:
            branch_channels = out_channels // 2
            self.residual = nn.Sequential(
                *_pointwise_layers(in_channels),
                *_depthwise_pointwise_layers(in_channels, branch_channels, stride),
            )
            self.shortcut = nn.Sequential(
                *_depthwise_pointwise_layers(in_channels, branch_channels, stride),
            )
        else:
            # Identity shortcut; only half of the channels are transformed.
            self.shortcut = nn.Sequential()

            half_channels = in_channels // 2
            self.residual = nn.Sequential(
                *_pointwise_layers(half_channels),
                *_depthwise_pointwise_layers(half_channels, half_channels, stride),
            )

    def forward(self, x):
        """Apply both branches, concatenate along channels and shuffle."""
        if self.stride == 1 and self.out_channels == self.in_channels:
            shortcut, residual = channel_split(x, self.in_channels // 2)
        else:
            shortcut = x
            residual = x

        shortcut = self.shortcut(shortcut)
        residual = self.residual(residual)
        x = torch.cat([shortcut, residual], dim=1)
        return channel_shuffle(x, 2)


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier/regressor head on top of three shuffle stages.

    Args:
        ratio: width multiplier; one of 0.5, 1, 1.5 or 2.
        class_num: (int) size of the final linear layer's output.
        dropout_factor: (float) probability handed to ``nn.Dropout`` in front
            of the final linear layer.

    Raises:
        ValueError: if ``ratio`` is not one of the supported multipliers.
    """

    # Output channels of stage2, stage3, stage4 and conv5 per width ratio.
    _STAGE_OUT_CHANNELS = {
        0.5: [48, 96, 192, 1024],
        1: [116, 232, 464, 1024],
        1.5: [176, 352, 704, 1024],
        2: [244, 488, 976, 2048],
    }

    def __init__(self, ratio=1., class_num=100, dropout_factor = 1.0):
        super().__init__()
        if ratio not in self._STAGE_OUT_CHANNELS:
            # The exception used to be constructed without ``raise``, which
            # surfaced later as an unrelated unbound-variable error.
            raise ValueError(f'unsupported ratio number: {ratio!r}')
        out_channels = self._STAGE_OUT_CHANNELS[ratio]

        self.pre = nn.Sequential(
            nn.Conv2d(3, 24, 3, padding=1),
            nn.BatchNorm2d(24)
        )

        self.stage2 = self._make_stage(24, out_channels[0], 3)
        self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7)
        self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
        self.conv5 = nn.Sequential(
            nn.Conv2d(out_channels[2], out_channels[3], 1),
            nn.BatchNorm2d(out_channels[3]),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(out_channels[3], class_num)

        # NOTE(review): nn.Dropout takes the probability of zeroing an
        # element, so the default of 1.0 zeroes every pooled feature while
        # the module is in training mode. The default is kept for interface
        # compatibility; callers should pass an explicit value.
        self.dropout = nn.Dropout(dropout_factor)

    def forward(self, x):
        """Run the backbone, global-average-pool, apply dropout and the FC layer."""
        x = self.pre(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return self.fc(x)

    def _make_stage(self, in_channels, out_channels, repeat):
        """Build one stage: a stride-2 unit followed by ``repeat`` stride-1 units."""
        layers = [ShuffleUnit(in_channels, out_channels, 2)]
        # range() instead of a manual countdown: a negative ``repeat`` can no
        # longer loop forever.
        for _ in range(repeat):
            layers.append(ShuffleUnit(out_channels, out_channels, 1))
        return nn.Sequential(*layers)


def shufflenetv2():
    """Return a ShuffleNetV2 built with its default arguments."""
    return ShuffleNetV2()


# ---------------------------------------------------------------------------
# Companion changes to train.py carried by the same patch (reference only):
#   * import added:
#         from models.shufflenetv2 import ShuffleNetV2
#   * new branch in trainer():
#         elif ops.model == "shufflenetv2":
#             model_ = ShuffleNetV2(ratio=1., class_num=ops.num_classes,
#                                   dropout_factor=ops.dropout)
#   * --model default:      'squeezenet1_1' -> 'shufflenetv2'
#                           (help text now also lists shufflenetv2)
#   * --batch_size default: 128 -> 16
# ---------------------------------------------------------------------------