未验证 提交 e5aaf662 编写于 作者: B Bubbliiiing 提交者: GitHub

Update train.py

上级 da37b95e
...@@ -106,16 +106,17 @@ if __name__ == "__main__": ...@@ -106,16 +106,17 @@ if __name__ == "__main__":
#------------------------------------------------------# #------------------------------------------------------#
model = YoloBody(anchors_mask, num_classes) model = YoloBody(anchors_mask, num_classes)
weights_init(model) weights_init(model)
#------------------------------------------------------# if model_path != '':
# 权值文件请看README,百度网盘下载 #------------------------------------------------------#
#------------------------------------------------------# # 权值文件请看README,百度网盘下载
print('Load weights {}.'.format(model_path)) #------------------------------------------------------#
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print('Load weights {}.'.format(model_path))
model_dict = model.state_dict() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load(model_path, map_location = device) model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} pretrained_dict = torch.load(model_path, map_location = device)
model_dict.update(pretrained_dict) pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model.load_state_dict(model_dict) model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model_train = model.train() model_train = model.train()
if Cuda: if Cuda:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册