From 0e8b317f22c89ae138f341a471ee03473e51facc Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 15 Apr 2020 02:36:47 +0000 Subject: [PATCH] auto download pretrain --- examples/yolov3/README.md | 11 ++++------- examples/yolov3/main.py | 6 +++++- examples/yolov3/pretrain_weights/download.sh | 5 ----- hapi/download.py | 12 +++++++++++- 4 files changed, 20 insertions(+), 14 deletions(-) delete mode 100644 examples/yolov3/pretrain_weights/download.sh diff --git a/examples/yolov3/README.md b/examples/yolov3/README.md index a8784ac..9a0d2cc 100644 --- a/examples/yolov3/README.md +++ b/examples/yolov3/README.md @@ -99,20 +99,17 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 | ... ``` -### 预训练权重下载 - -YOLOv3模型训练需下载骨干网络DarkNet53的预训练权重,可通过如下方式下载。 - ```bash sh pretrain_weights/download.sh ``` ### 模型训练 -数据和预训练权重下载完成后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`yolo_checkpoint`目录下。 +数据准备完成后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`yolo_checkpoint`目录下。 YOLOv3模型训练总batch_size为64训练,以下以使用4卡Tesla P40每卡batch_size为16训练介绍训练方式。对于静态图和动态图,多卡训练中`--batch_size`为每卡上的batch_size,即总batch_size为`--batch_size`乘以卡数。 +YOLOv3模型训练须加载骨干网络[DarkNet53]()的预训练权重,可在训练时通过`--pretrain_weights`指定,若指定为URL,将自动下载权重至`~/.cache/paddle/weights`目录并加载。 `main.py`脚本参数可通过如下命令查询 @@ -125,7 +122,7 @@ python main.py --help 使用如下方式进行多卡训练: ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data= --batch_size=16 --pretrain_weights=./pretrain_weights/darknet53_pretrained +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data= --batch_size=16 --pretrain_weights=https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams ``` #### 动态图训练 @@ -135,7 +132,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data= 使用如下方式进行多卡训练: ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py -m paddle.distributed.launch --data= --batch_size=16 -d --pretrain_weights=./pretrain_weights/darknet53_pretrained +CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py -m paddle.distributed.launch --data= --batch_size=16 -d --pretrain_weights=https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams ``` diff --git a/examples/yolov3/main.py b/examples/yolov3/main.py index 6e05955..7203329 100644 --- a/examples/yolov3/main.py +++ b/examples/yolov3/main.py @@ -27,6 +27,7 @@ from paddle.io import DataLoader from hapi.model import Model, Input, set_device from hapi.distributed import DistributedBatchSampler +from hapi.download import is_url, get_weights_path from hapi.datasets import COCODataset from hapi.vision.transforms import * from hapi.vision.models import yolov3_darknet53, YoloLoss @@ -125,7 +126,10 @@ def main(): pretrained=pretrained) if FLAGS.pretrain_weights and not FLAGS.eval_only: - model.load(FLAGS.pretrain_weights, skip_mismatch=True, reset_optimizer=True) + pretrain_weights = FLAGS.pretrain_weights + if is_url(pretrain_weights): + pretrain_weights = get_weights_path(pretrain_weights) + model.load(pretrain_weights, skip_mismatch=True, reset_optimizer=True) optim = make_optimizer(len(batch_sampler), parameter_list=model.parameters()) diff --git a/examples/yolov3/pretrain_weights/download.sh b/examples/yolov3/pretrain_weights/download.sh deleted file mode 100644 index 6982f44..0000000 --- a/examples/yolov3/pretrain_weights/download.sh +++ /dev/null @@ -1,5 +0,0 @@ -DIR="$( cd "$(dirname "$0")" ; pwd -P )" -cd "$DIR" - -echo "Downloading https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams" -wget https://paddlemodels.bj.bcebos.com/hapi/darknet53_pretrained.pdparams diff --git a/hapi/download.py b/hapi/download.py index 10d3fba..e9a89ba 100644 --- a/hapi/download.py +++ b/hapi/download.py @@ -29,13 +29,22 @@ from paddle.fluid.dygraph.parallel import ParallelEnv import logging logger = logging.getLogger(__name__) -__all__ = ['get_weights_path'] +__all__ = ['get_weights_path', 'is_url'] WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights") DOWNLOAD_RETRY_LIMIT = 3 +def is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. + """ + return path.startswith('http://') or path.startswith('https://') + + def get_weights_path(url, md5sum=None): """Get weights path from WEIGHT_HOME, if not exists, download it from url. @@ -62,6 +71,7 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): WEIGHTS_HOME or DATASET_HOME md5sum (str): md5 sum of download package """ + assert is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir fullpath = map_path(url, root_dir) -- GitLab