提交 41e6ceaa 编写于 作者: C CaoJian

vgg16 support imagenet dataset on Ascend

上级 75045e3e
......@@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
exit 1
fi
......@@ -32,6 +32,19 @@ then
exit 1
fi
dataset_type='cifar10'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
dataset_type=$3
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$1
......@@ -45,8 +58,8 @@ do
cp *.py ./train_parallel$i
cp -r src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
env > env.log
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log &
cd ..
done
\ No newline at end of file
done
......@@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"):
>>> vgg16(num_classes=1000, args=args)
"""
if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册