From e5aaf662d194f1a15574ab75f4e1b31cbc3d19f5 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Mon, 27 Sep 2021 10:57:49 +0800 Subject: [PATCH] Update train.py --- train.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index cc09307..6441d1d 100644 --- a/train.py +++ b/train.py @@ -106,16 +106,17 @@ if __name__ == "__main__": #------------------------------------------------------# model = YoloBody(anchors_mask, num_classes) weights_init(model) - #------------------------------------------------------# - # 权值文件请看README,百度网盘下载 - #------------------------------------------------------# - print('Load weights {}.'.format(model_path)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model_dict = model.state_dict() - pretrained_dict = torch.load(model_path, map_location = device) - pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) + if model_path != '': + #------------------------------------------------------# + # 权值文件请看README,百度网盘下载 + #------------------------------------------------------# + print('Load weights {}.'.format(model_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model_dict = model.state_dict() + pretrained_dict = torch.load(model_path, map_location = device) + pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) model_train = model.train() if Cuda: -- GitLab