提交 240a61fd 编写于 作者: C chenyuntc

tiny modification for optimizer

上级 c96fa0a2
......@@ -18,7 +18,7 @@ class Config:
roi_sigma = 1.
# param for optimizer
weight_decay = 0.0001 # NOTE:it's modified
weight_decay = 0.0005 # 0.0005 in origin paper but 0.0001 in tf-faster-rcnn
lr_decay = 0.1 # 1e-3 -> 1e-4
# lr = 1e-3
lr1 = 1e-3 # extractor
......@@ -40,7 +40,7 @@ class Config:
# change lr
milestone = [0, 1, 5, 10]
use_adam = False
# debug
debug_file = '/tmp/debugf'
......
......@@ -82,7 +82,6 @@ class Dataset():
img, bbox, label, scale = self.tsf((ori_img, bbox, label))
# TODO: check whose stride is negative to fix this instead copy all
# some of the strides of a given numpy array are negative.
# This is currently not supported, but will be added in future releases.
return img.copy(), bbox.copy(), label.copy(), scale, ori_img
def __len__(self):
......@@ -90,34 +89,15 @@ class Dataset():
class TestDataset():
def __init__(self, opt):
def __init__(self, opt,split='test',use_difficult=True):
self.opt = opt
self.db = testset = VOCBboxDataset(opt.voc_data_dir, split='test', use_difficult=True)
self.db = testset = VOCBboxDataset(opt.voc_data_dir, split=split, use_difficult=use_difficult)
def __getitem__(self, idx):
ori_img, bbox, label, difficult = self.db.get_example(idx)
img = preprocess(ori_img)
return (img), ori_img.shape[1:], bbox, label, difficult
# TODO: check whose stride is negative to fix this instead copy all
# some of the strides of a given numpy array are negative.
# This is currently not supported, but will be added in future releases.
def __len__(self):
return len(self.db)
class TestDataset2():
def __init__(self, opt):
self.opt = opt
self.db = testset = VOCBboxDataset(opt.voc_data_dir, split='trainval', use_difficult=True)
def __getitem__(self, idx):
ori_img, bbox, label, difficult = self.db.get_example(idx)
img = preprocess(ori_img)
return (img), ori_img.shape[1:], bbox, label, difficult
# TODO: check whose stride is negative to fix this instead copy all
# some of the strides of a given numpy array are negative.
# This is currently not supported, but will be added in future releases.
def __len__(self):
return len(self.db)
......@@ -66,16 +66,14 @@ class FasterRCNN(nn.Module):
Region Proposal Networks. NIPS 2015.
Args:
extractor (callable Chain): A callable that takes a BCHW image
extractor (nn.Module): A module that takes a BCHW image
array and returns feature maps.
rpn (callable Chain): A callable that has the same interface as
rpn (nn.Module): A module that has the same interface as
:class:`~chainercv.links.model.faster_rcnn.RegionProposalNetwork`.
Please refer to the documentation found there.
head (callable Chain): A callable that takes
head (nn.Module): A callable that takes
a BCHW array, RoIs and batch indices for RoIs. This returns class
dependent localization paramters and class scores.
mean (numpy.ndarray): A value to be subtracted from an image
in :meth:`prepare`.
loc_normalize_mean (tuple of four floats): Mean values of
localization estimates.
loc_normalize_std (tupler of four floats): Standard deviation
......@@ -399,3 +397,16 @@ class FasterRCNN(nn.Module):
for param_group in self.optimizer.param_groups:
param_group['lr'] *= decay
return self.optimizer
def get_optimizer_adam(self):
lr = opt.lr1 *0.1
self.lr1 = lr
params = []
for key, value in dict(self.named_parameters()).items():
if value.requires_grad:
if 'bias' in key:
params += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
else:
params += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}]
self.optimizer = t.optim.Adam(params)
return self.optimizer
......@@ -99,7 +99,9 @@ def train(**kwargs):
trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm')
# roi confusion matrix
trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float())
if best_map>0.6:
opt.test_num=10000
best_map = 0
eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num)
if eval_result['map'] > best_map:
......
......@@ -55,6 +55,8 @@ class FasterRCNNTrainer(nn.Module):
self.loc_normalize_std = faster_rcnn.loc_normalize_std
self.optimizer = self.faster_rcnn.get_optimizer()
if opt.use_adam:
self.optimizer = self.faster_rcnn.get_optimizer_adam()
# visdom wrapper
self.vis = Visualizer(env=opt.env)
......@@ -198,7 +200,7 @@ class FasterRCNNTrainer(nn.Module):
save_dict['optimizer'] = self.optimizer.state_dict()
if save_path is None:
timestr = time.strftime('%m%d-%H%M')
timestr = time.strftime('%m%d%H%M')
save_path = 'checkpoints/fasterrcnn_%s' % timestr
for k_,v_ in kwargs.items():
save_path += '_%s' %v_
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册