提交 cbff80f3 编写于 作者: W Waleed Abdulla

Raise clear error if last training weights are not foundIf using the...

Raise clear error if last training weights are not foundIf using the --weights=last (or --model=last) to resume trainingbut the weights are not found now it raises a clear error message.
上级 a688a66b
......@@ -2062,8 +2062,7 @@ class MaskRCNN():
"""Finds the last checkpoint file of the last trained model in the
model directory.
Returns:
log_dir: The directory where events and weights are saved
checkpoint_path: the path to the last checkpoint file
The path of the last checkpoint file
"""
# Get directory names. Each directory corresponds to a model
dir_names = next(os.walk(self.model_dir))[1]
......@@ -2071,7 +2070,10 @@ class MaskRCNN():
dir_names = filter(lambda f: f.startswith(key), dir_names)
dir_names = sorted(dir_names)
if not dir_names:
return None, None
import errno
raise FileNotFoundError(
errno.ENOENT,
"Could not find model directory under {}".format(self.model_dir))
# Pick last directory
dir_name = os.path.join(self.model_dir, dir_names[-1])
# Find the last checkpoint
......@@ -2079,9 +2081,11 @@ class MaskRCNN():
checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints)
checkpoints = sorted(checkpoints)
if not checkpoints:
return dir_name, None
import errno
raise FileNotFoundError(
errno.ENOENT, "Could not find weight files in {}".format(dir_name))
checkpoint = os.path.join(dir_name, checkpoints[-1])
return dir_name, checkpoint
return checkpoint
def load_weights(self, filepath, by_name=False, exclude=None):
"""Modified version of the correspoding Keras function with
......
......@@ -336,7 +336,7 @@ if __name__ == '__main__':
utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last":
# Find last trained weights
weights_path = model.find_last()[1]
weights_path = model.find_last()
elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights
weights_path = model.get_imagenet_weights()
......
......@@ -265,7 +265,7 @@
"# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n",
"\n",
"# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n",
"weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
......@@ -462,7 +462,7 @@ if __name__ == '__main__':
model_path = COCO_MODEL_PATH
elif args.model.lower() == "last":
# Find last trained weights
model_path = model.find_last()[1]
model_path = model.find_last()
elif args.model.lower() == "imagenet":
# Start from ImageNet trained weights
model_path = model.get_imagenet_weights()
......
......@@ -270,7 +270,7 @@
"elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n",
"# weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
......@@ -150,7 +150,7 @@
"elif config.NAME == \"coco\":\n",
" weights_path = COCO_MODEL_PATH\n",
"# Or, uncomment to load the last model you trained\n",
"# weights_path = model.find_last()[1]\n",
"# weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
......@@ -258,7 +258,7 @@
"# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n",
"\n",
"# Or, load the last model you trained\n",
"weights_path = model.find_last()[1]\n",
"weights_path = model.find_last()\n",
"\n",
"# Load weights\n",
"print(\"Loading weights \", weights_path)\n",
......@@ -464,7 +464,7 @@ if __name__ == '__main__':
utils.download_trained_weights(weights_path)
elif args.weights.lower() == "last":
# Find last trained weights
weights_path = model.find_last()[1]
weights_path = model.find_last()
elif args.weights.lower() == "imagenet":
# Start from ImageNet trained weights
weights_path = model.get_imagenet_weights()
......
......@@ -458,7 +458,7 @@
" \"mrcnn_bbox\", \"mrcnn_mask\"])\n",
"elif init_with == \"last\":\n",
" # Load the last model you trained and continue training\n",
" model.load_weights(model.find_last()[1], by_name=True)"
" model.load_weights(model.find_last(), by_name=True)"
]
},
{
......@@ -875,10 +875,9 @@
"# Get path to saved weights\n",
"# Either set a specific path or find last trained weights\n",
"# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n",
"model_path = model.find_last()[1]\n",
"model_path = model.find_last()\n",
"\n",
"# Load trained weights (fill in path to trained weights here)\n",
"assert model_path != \"\", \"Provide path to trained weights\"\n",
"# Load trained weights\n",
"print(\"Loading weights from \", model_path)\n",
"model.load_weights(model_path, by_name=True)"
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册