From cd446dee7348715335f076334dca49d0eb8374f5 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Tue, 11 May 2021 17:24:39 +0800 Subject: [PATCH] Add files via upload --- nets/yolo_training.py | 22 +++++++++++++++++++++- train.py | 3 ++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 5b8fc95..4947212 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -393,7 +393,27 @@ class YOLOLoss(nn.Module): anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3]) noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0 return noobj_mask - + +def weights_init(net, init_type='normal', init_gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and classname.find('Conv') != -1: + if init_type == 'normal': + torch.nn.init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + elif classname.find('BatchNorm2d') != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + torch.nn.init.constant_(m.bias.data, 0.0) + print('initialize network with %s type' % init_type) + net.apply(init_func) + class LossHistory(): def __init__(self, log_dir): import datetime diff --git a/train.py b/train.py index 0d5b3b7..8beda93 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm from nets.yolo3 import YoloBody -from nets.yolo_training import YOLOLoss, LossHistory +from nets.yolo_training import YOLOLoss, LossHistory, weights_init from utils.dataloader import YoloDataset, yolo_dataset_collate @@ -151,6 +151,7 @@ if __name__ == "__main__": # 训练前一定要修改Config里面的classes参数 #------------------------------------------------------# model = YoloBody(anchors, num_classes) + weights_init(model) #------------------------------------------------------# # 权值文件请看README,百度网盘下载 -- GitLab