diff --git a/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh index 92359a438633287c7aaffdce3abe94afcdefa9fb..8d621ec2c608657eaa556bef13486ac1663a14fd 100644 --- a/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh @@ -14,9 +14,9 @@ # limitations under the License. # ============================================================================ -if [ $# != 2 ] +if [ $# != 2 ] && [ $# != 3 ] then - echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]" + echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]" exit 1 fi @@ -32,6 +32,19 @@ then exit 1 fi + +dataset_type='cifar10' +if [ $# == 3 ] +then + if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ] + then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" + exit 1 + fi + dataset_type=$3 +fi + + export DEVICE_NUM=8 export RANK_SIZE=8 export RANK_TABLE_FILE=$1 @@ -45,8 +58,8 @@ do cp *.py ./train_parallel$i cp -r src ./train_parallel$i cd ./train_parallel$i || exit - echo "start training for rank $RANK_ID, device $DEVICE_ID" + echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" env > env.log - python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log & + python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log & cd .. -done \ No newline at end of file +done diff --git a/model_zoo/official/cv/vgg16/src/vgg.py b/model_zoo/official/cv/vgg16/src/vgg.py index 3e87acf1aff34fa96f7b77591c835b3ebc778e2f..19ca7e0dcfd5aaa870b8caefac35452dfe0dfedc 100644 --- a/model_zoo/official/cv/vgg16/src/vgg.py +++ b/model_zoo/official/cv/vgg16/src/vgg.py @@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"): >>> vgg16(num_classes=1000, args=args) """ + if args is None: + from .config import cifar_cfg + args = cifar_cfg net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) return net