diff --git a/train.py b/train.py index cc09307cf136948fcac65c7ef4d6308ef4a94ab3..6441d1db12534b980266eb1a90f31020aff86766 100644 --- a/train.py +++ b/train.py @@ -106,16 +106,17 @@ if __name__ == "__main__": #------------------------------------------------------# model = YoloBody(anchors_mask, num_classes) weights_init(model) - #------------------------------------------------------# - # 权值文件请看README,百度网盘下载 - #------------------------------------------------------# - print('Load weights {}.'.format(model_path)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model_dict = model.state_dict() - pretrained_dict = torch.load(model_path, map_location = device) - pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) + if model_path != '': + #------------------------------------------------------# + # 权值文件请看README,百度网盘下载 + #------------------------------------------------------# + print('Load weights {}.'.format(model_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model_dict = model.state_dict() + pretrained_dict = torch.load(model_path, map_location = device) + pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) model_train = model.train() if Cuda: