未验证 提交 cd446dee 编写于 作者: B Bubbliiiing 提交者: GitHub

Add files via upload

上级 3e4ce1ac
......@@ -393,7 +393,27 @@ class YOLOLoss(nn.Module):
anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0
return noobj_mask
def weights_init(net, init_type='normal', init_gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and classname.find('Conv') != -1:
if init_type == 'normal':
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
print('initialize network with %s type' % init_type)
net.apply(init_func)
class LossHistory():
def __init__(self, log_dir):
import datetime
......
......@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from nets.yolo3 import YoloBody
from nets.yolo_training import YOLOLoss, LossHistory
from nets.yolo_training import YOLOLoss, LossHistory, weights_init
from utils.dataloader import YoloDataset, yolo_dataset_collate
......@@ -151,6 +151,7 @@ if __name__ == "__main__":
# 训练前一定要修改Config里面的classes参数
#------------------------------------------------------#
model = YoloBody(anchors, num_classes)
weights_init(model)
#------------------------------------------------------#
# 权值文件请看README,百度网盘下载
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册