Commit 8f66e956 authored by timgaripov

initial commit

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# Archives
*.gz
# IDEA
.idea*
BSD 2-Clause License
Copyright (c) 2018, Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Stochastic Weight Averaging (SWA)
This repository contains a PyTorch implementation of the Stochastic Weight Averaging (SWA) training method for DNNs from the paper
[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
# Introduction
SWA is a simple DNN training method that can be used as a drop-in replacement for SGD with improved generalization and essentially no overhead. The key idea of SWA is to average multiple samples produced by SGD with a modified learning rate schedule. We use a constant or cyclical learning rate schedule that forces SGD to explore the set of points in weight space corresponding to high-performing networks. We observe that SWA converges more quickly than SGD, and to wider optima that provide higher test accuracy.
In this repo we implement only the constant learning rate schedule, which we found to be the most practical on the CIFAR datasets; a minimal sketch of the averaging procedure is given below the figures.
<p align="center">
<img src="https://user-images.githubusercontent.com/14368801/37633888-89fdc05a-2bca-11e8-88aa-dd3661a44c3f.png" width=250>
<img src="https://user-images.githubusercontent.com/14368801/37633885-89d809a0-2bca-11e8-8d57-3bd78734cea3.png" width=250>
<img src="https://user-images.githubusercontent.com/14368801/37633887-89e93784-2bca-11e8-9d71-a385ea72ff7c.png" width=250>
</p>
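At a high level, SWA maintains a second copy of the weights and updates it as a running average of the SGD iterates collected at the end of each cycle; `train.py` and `utils.py` implement this for the constant schedule. A minimal sketch of the idea, where `model`, `optimizer`, `train_one_epoch`, `num_epochs`, and `swa_start` are hypothetical placeholders, looks like this:

```python
import copy

swa_model = copy.deepcopy(model)  # running average of the SGD iterates
swa_n = 0                         # number of models averaged so far

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)      # ordinary SGD epoch (placeholder)
    if epoch + 1 >= swa_start:             # start averaging after swa_start epochs
        alpha = 1.0 / (swa_n + 1)
        for p_swa, p in zip(swa_model.parameters(), model.parameters()):
            # running mean of the weights: p_swa <- (1 - alpha) * p_swa + alpha * p
            p_swa.data.mul_(1.0 - alpha).add_(p.data * alpha)
        swa_n += 1

# BatchNorm running statistics must be re-estimated for the averaged weights
# (see utils.bn_update in this repo) before evaluating swa_model.
```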
Please cite our work if you find this approach useful in your research:
```latex
@article{izmailov2018averaging,
title={Averaging Weights Leads to Wider Optima and Better Generalization},
author={Izmailov, Pavel and Podoprikhin, Dmitrii and Garipov, Timur and Vetrov, Dmitry and Wilson, Andrew Gordon},
journal={arXiv preprint arXiv:1803.05407},
year={2018}
}
```
# Dependencies
* [PyTorch](http://pytorch.org/)
* [torchvision](https://github.com/pytorch/vision/)
* [tabulate](https://pypi.python.org/pypi/tabulate/)
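These can typically be installed with pip, for example (a rough sketch; the exact PyTorch install command depends on your platform and CUDA version, see [pytorch.org](http://pytorch.org/)):
```bash
pip install torch torchvision tabulate
```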
# Usage
The code in this repository implements both SWA and conventional SGD training, with examples on the CIFAR-10 and CIFAR-100 datasets.
To run SWA use the following command:
```bash
python3 train.py --dir=<DIR> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--epochs=<EPOCHS> \
--lr_init=<LR_INIT> \
--wd=<WD> \
--swa \
--swa_start=<SWA_START> \
--swa_lr=<SWA_LR>
```
Parameters:
* ```DIR``` &mdash; path to training directory where checkpoints will be stored
* ```DATASET``` &mdash; dataset name [CIFAR10/CIFAR100] (default: CIFAR10)
* ```PATH``` &mdash; path to the data directory
* ```MODEL``` &mdash; DNN model name:
- VGG16/VGG16BN/VGG19/VGG19BN
- PreResNet110
- WideResNet28x10
* ```EPOCHS``` &mdash; number of training epochs (default: 200)
* ```LR_INIT``` &mdash; initial learning rate (default: 0.1)
* ```WD``` &mdash; weight decay (default: 1e-4)
* ```SWA_START``` &mdash; the epoch number at which SWA starts averaging models (default: 161)
* ```SWA_LR``` &mdash; SWA learning rate (default: 0.05)
To run conventional SGD training use the following command:
```bash
python3 train.py --dir=<DIR> \
--dataset=<DATASET> \
--data_path=<PATH> \
--model=<MODEL> \
--epochs=<EPOCHS> \
--lr_init=<LR_INIT> \
--wd=<WD>
```
## Examples
To reproduce the results from the paper, run the following commands (we use the same hyperparameters for both CIFAR-10 and CIFAR-100):
```bash
#VGG16
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=VGG16 --epochs=200 --lr_init=0.05 --wd=5e-4 # SGD
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=VGG16 --epochs=300 --lr_init=0.05 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.01 # SWA 1.5 Budgets
#PreResNet110
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=PreResNet110 --epochs=150 --lr_init=0.1 --wd=3e-4 # SGD
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=PreResNet110 --epochs=225 --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=0.05 # SWA 1.5 Budgets
#WideResNet28x10
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=WideResNet28x10 --epochs=200 --lr_init=0.1 --wd=5e-4 # SGD
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=WideResNet28x10 --epochs=300 --lr_init=0.1 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.05 # SWA 1.5 Budgets
```
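Training can also be resumed from a saved checkpoint via the `--resume` flag defined in `train.py`. For example (a sketch; substitute your own directory and checkpoint, which `train.py` saves as `checkpoint-<EPOCH>.pt` every `--save_freq` epochs):
```bash
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=PreResNet110 --epochs=225 --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=0.05 --resume=<DIR>/checkpoint-125.pt
```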
# Results
## CIFAR-100
Test accuracy (%) of SGD and SWA on CIFAR-100 for different training budgets. For each model the _Budget_ is defined as the number of epochs required to train the model with the conventional SGD procedure.
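For example, the Budget of PreResNet110 is 150 epochs, so the SWA 1.5 Budgets column corresponds to 1.5 × 150 = 225 training epochs, matching the `--epochs=225` command above.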
| DNN (Budget) | SGD | SWA 1 Budget | SWA 1.25 Budgets | SWA 1.5 Budgets |
| ------------------------- |:------------:|:------------:|:----------------:|:---------------:|
| VGG16 (200) | 72.55 ± 0.10 | 73.91 ± 0.12 | 74.17 ± 0.15 | 74.27 ± 0.25 |
| PreResNet110 (150) | 78.49 ± 0.36 | 79.77 ± 0.17 | 80.18 ± 0.23 | 80.35 ± 0.16 |
| WideResNet28x10 (200) | 80.82 ± 0.23 | 81.46 ± 0.23 | 81.91 ± 0.27 | 82.15 ± 0.27 |
Below we show the convergence plot for SWA and SGD with PreResNet110 on CIFAR-100 and the corresponding learning rates. The dashed line illustrates the accuracy of individual models averaged by SWA.
<p align="center">
<img src="https://user-images.githubusercontent.com/14368801/37633527-226bb2d6-2bc9-11e8-9be6-097c0dfe64ab.png" width=500>
</p>
## CIFAR-10
Test accuracy (%) of SGD and SWA on CIFAR-10 for different training budgets.
| DNN (Budget) | SGD | SWA 1 Budget | SWA 1.25 Budgets | SWA 1.5 Budgets |
| ------------------------- |:------------:|:------------:|:----------------:|:---------------:|
| VGG16 (200) | 93.25 ± 0.16 | 93.59 ± 0.16 | 93.70 ± 0.22 | 93.64 ± 0.18 |
| PreResNet110 (150) | 95.28 ± 0.10 | 95.56 ± 0.11 | 95.77 ± 0.04 | 95.83 ± 0.03 |
| WideResNet28x10 (200) | 96.18 ± 0.11 | 96.45 ± 0.11 | 96.64 ± 0.08 | 96.79 ± 0.05 |
# References
The model implementations provided here were adapted from:
* VGG: [github.com/pytorch/vision/](https://github.com/pytorch/vision/)
* PreResNet: [github.com/bearpaw/pytorch-classification](https://github.com/bearpaw/pytorch-classification)
* WideResNet: [github.com/meliketoy/wide-resnet.pytorch](https://github.com/meliketoy/wide-resnet.pytorch)
from .preresnet import *
from .vgg import *
from .wide_resnet import *
"""
PreResNet model definition
ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
"""
import torch.nn as nn
import torchvision.transforms as transforms
import math
__all__ = ['PreResNet110']
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
self.relu = nn.ReLU(inplace=True)
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
class PreResNet(nn.Module):
def __init__(self, num_classes=10, depth=110):
super(PreResNet, self).__init__()
assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
n = (depth - 2) // 6
block = Bottleneck if depth >= 44 else BasicBlock
self.inplanes = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
bias=False)
self.layer1 = self._make_layer(block, 16, n)
self.layer2 = self._make_layer(block, 32, n, stride=2)
self.layer3 = self._make_layer(block, 64, n, stride=2)
self.bn = nn.BatchNorm2d(64 * block.expansion)
self.relu = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(8)
self.fc = nn.Linear(64 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
)
layers = list()
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.layer1(x) # 32x32
x = self.layer2(x) # 16x16
x = self.layer3(x) # 8x8
x = self.bn(x)
x = self.relu(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class PreResNet110:
base = PreResNet
args = list()
kwargs = {'depth': 110}
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
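# Note: config classes such as PreResNet110 are consumed by train.py roughly as
# follows (a sketch of the existing usage in this repo, not additional functionality):
#
#   model_cfg = getattr(models, args.model)  # e.g. models.PreResNet110
#   model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
#   train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)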
"""
VGG model definition
ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
"""
import math
import torch.nn as nn
import torchvision.transforms as transforms
__all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN']
def make_layers(cfg, batch_norm=False):
layers = list()
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfg = {
16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
512, 512, 512, 512, 'M'],
}
class VGG(nn.Module):
def __init__(self, num_classes=10, depth=16, batch_norm=False):
super(VGG, self).__init__()
self.features = make_layers(cfg[depth], batch_norm)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(512, 512),
nn.ReLU(True),
nn.Linear(512, num_classes),
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.bias.data.zero_()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
class Base:
base = VGG
args = list()
kwargs = dict()
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
class VGG16(Base):
pass
class VGG16BN(Base):
kwargs = {'batch_norm': True}
class VGG19(Base):
kwargs = {'depth': 19}
class VGG19BN(Base):
kwargs = {'depth': 19, 'batch_norm': True}
"""
WideResNet model definition
ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
"""
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import math
__all__ = ['WideResNet28x10']
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
def conv_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.xavier_uniform(m.weight, gain=math.sqrt(2))
init.constant(m.bias, 0)
elif classname.find('BatchNorm') != -1:
init.constant(m.weight, 1)
init.constant(m.bias, 0)
class WideBasic(nn.Module):
def __init__(self, in_planes, planes, dropout_rate, stride=1):
super(WideBasic, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
)
def forward(self, x):
out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv2(F.relu(self.bn2(out)))
out += self.shortcut(x)
return out
class WideResNet(nn.Module):
def __init__(self, num_classes=10, depth=28, widen_factor=10, dropout_rate=0.):
super(WideResNet, self).__init__()
self.in_planes = 16
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
n = (depth - 4) / 6
k = widen_factor
nstages = [16, 16 * k, 32 * k, 64 * k]
self.conv1 = conv3x3(3, nstages[0])
self.layer1 = self._wide_layer(WideBasic, nstages[1], n, dropout_rate, stride=1)
self.layer2 = self._wide_layer(WideBasic, nstages[2], n, dropout_rate, stride=2)
self.layer3 = self._wide_layer(WideBasic, nstages[3], n, dropout_rate, stride=2)
self.bn1 = nn.BatchNorm2d(nstages[3], momentum=0.9)
self.linear = nn.Linear(nstages[3], num_classes)
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
strides = [stride] + [1] * int(num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, dropout_rate, stride))
self.in_planes = planes
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
class WideResNet28x10:
base = WideResNet
args = list()
kwargs = {'depth': 28, 'widen_factor': 10}
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
import argparse
import os
import sys
import time
import torch
import torch.nn.functional as F
import torchvision
import models
import utils
import tabulate
parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
help='model name (default: None)')
parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
help='checkpoint to resume training from (default: None)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
help='SWA model collection frequency/cycle length in epochs (default: 1)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()
print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
f.write(' '.join(sys.argv))
f.write('\n')
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)
print('Loading dataset %s from %s' % (args.dataset, args.data_path))
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
loaders = {
'train': torch.utils.data.DataLoader(
train_set,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True
),
'test': torch.utils.data.DataLoader(
test_set,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=True
)
}
num_classes = max(train_set.train_labels) + 1
print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.cuda()
if args.swa:
print('SWA training')
swa_model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
swa_model.cuda()
swa_n = 0
else:
print('SGD training')
def schedule(epoch):
    # Piecewise-linear learning rate schedule: hold lr_init for the first half
    # of the run, decay linearly between 50% and 90% of the run, then hold the
    # final rate. With --swa the run length is swa_start and the final rate is
    # swa_lr; otherwise the run length is epochs and the final rate is 0.01 * lr_init.
    t = epoch / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return args.lr_init * factor
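# Illustration of the schedule above with the default SWA settings
# (lr_init=0.1, swa_lr=0.05, swa_start=161): epochs 0-80 use lr = 0.1, the rate
# then decays linearly (e.g. lr ~ 0.075 around epoch 113), and from about
# epoch 145 onwards it stays at the constant SWA rate of 0.05.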
criterion = F.cross_entropy
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr_init,
momentum=args.momentum,
weight_decay=args.wd
)
start_epoch = 0
if args.resume is not None:
print('Resume training from %s' % args.resume)
checkpoint = torch.load(args.resume)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
if args.swa:
swa_state_dict = checkpoint['swa_state_dict']
if swa_state_dict is not None:
swa_model.load_state_dict(swa_state_dict)
swa_n_ckpt = checkpoint['swa_n']
if swa_n_ckpt is not None:
swa_n = swa_n_ckpt
columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
if args.swa:
columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + columns[-1:]
swa_res = {'loss': None, 'accuracy': None}
utils.save_checkpoint(
args.dir,
start_epoch,
state_dict=model.state_dict(),
swa_state_dict=swa_model.state_dict() if args.swa else None,
swa_n=swa_n if args.swa else None,
optimizer=optimizer.state_dict()
)
for epoch in range(start_epoch, args.epochs):
time_ep = time.time()
lr = schedule(epoch)
utils.adjust_learning_rate(optimizer, lr)
train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
test_res = utils.eval(loaders['test'], model, criterion)
else:
test_res = {'loss': None, 'accuracy': None}
if args.swa and (epoch + 1) >= args.swa_start and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
swa_n += 1
if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
utils.bn_update(loaders['train'], swa_model)
swa_res = utils.eval(loaders['test'], swa_model, criterion)
else:
swa_res = {'loss': None, 'accuracy': None}
if (epoch + 1) % args.save_freq == 0:
utils.save_checkpoint(
args.dir,
epoch + 1,
state_dict=model.state_dict(),
swa_state_dict=swa_model.state_dict() if args.swa else None,
swa_n=swa_n if args.swa else None,
optimizer=optimizer.state_dict()
)
time_ep = time.time() - time_ep
values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'], test_res['loss'], test_res['accuracy'], time_ep]
if args.swa:
values = values[:-1] + [swa_res['loss'], swa_res['accuracy']] + values[-1:]
table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
if epoch % 40 == 0:
table = table.split('\n')
table = '\n'.join([table[1]] + table)
else:
table = table.split('\n')[2]
print(table)
if args.epochs % args.save_freq != 0:
utils.save_checkpoint(
args.dir,
args.epochs,
state_dict=model.state_dict(),
swa_state_dict=swa_model.state_dict() if args.swa else None,
swa_n=swa_n if args.swa else None,
optimizer=optimizer.state_dict()
)
import os
import torch
def adjust_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
def save_checkpoint(dir, epoch, **kwargs):
state = {
'epoch': epoch,
}
state.update(kwargs)
filepath = os.path.join(dir, 'checkpoint-%d.pt' % epoch)
torch.save(state, filepath)
def train_epoch(loader, model, criterion, optimizer):
loss_sum = 0.0
correct = 0.0
model.train()
for i, (input, target) in enumerate(loader):
input = input.cuda(async=True)
target = target.cuda(async=True)
input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target)
output = model(input_var)
loss = criterion(output, target_var)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.data[0] * input.size(0)
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target_var.data.view_as(pred)).sum()
return {
'loss': loss_sum / len(loader.dataset),
'accuracy': correct / len(loader.dataset) * 100.0,
}
def eval(loader, model, criterion):
loss_sum = 0.0
correct = 0.0
model.eval()
for i, (input, target) in enumerate(loader):
input = input.cuda(async=True)
target = target.cuda(async=True)
input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target)
output = model(input_var)
loss = criterion(output, target_var)
loss_sum += loss.data[0] * input.size(0)
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target_var.data.view_as(pred)).sum()
return {
'loss': loss_sum / len(loader.dataset),
'accuracy': correct / len(loader.dataset) * 100.0,
}
def moving_average(net1, net2, alpha=1):
    # In-place update of net1 towards net2: param1 <- (1 - alpha) * param1 + alpha * param2.
    # Calling this with alpha = 1 / (swa_n + 1) after each cycle keeps net1 equal to the
    # running mean of all models collected so far.
    for param1, param2 in zip(net1.parameters(), net2.parameters()):
        param1.data *= (1.0 - alpha)
        param1.data += param2.data * alpha
def _check_bn(module, flag):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
flag[0] = True
def check_bn(model):
flag = [False]
model.apply(lambda module: _check_bn(module, flag))
return flag[0]
def reset_bn(module):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)
def _get_momenta(module, momenta):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
momenta[module] = module.momentum
def _set_momenta(module, momenta):
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
module.momentum = momenta[module]
def bn_update(loader, model):
"""
BatchNorm buffers update (if any).
Performs 1 epochs to estimate buffers average using train dataset.
:param loader: train dataset loader for buffers average estimation.
:param model: model being update
:return: None
"""
if not check_bn(model):
return
model.train()
momenta = {}
model.apply(reset_bn)
model.apply(lambda module: _get_momenta(module, momenta))
n = 0
for input, _ in loader:
input = input.cuda(async=True)
input_var = torch.autograd.Variable(input)
b = input_var.data.size(0)
momentum = b / (n + b)
for module in momenta.keys():
module.momentum = momentum
model(input_var)
n += b
model.apply(lambda module: _set_momenta(module, momenta))