提交 84b08346 编写于 作者: C chengxianbin

yolov3 network directory rectification

上级 5d7b9d95
......@@ -19,10 +19,10 @@ import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
from util import metrics
from src.yolov3 import yolov3_resnet18, YoloWithEval
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from src.config import ConfigYOLOV3ResNet18
from src.utils import metrics
def yolo_eval(dataset_path, ckpt_path):
"""Yolov3 evaluation."""
......
......@@ -45,6 +45,9 @@ echo "After running the scipt, the network runs in the background. The log will
export MINDSPORE_HCCL_CONFIG_PATH=$6
export RANK_SIZE=$1
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
......@@ -56,6 +59,7 @@ do
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cp -r ./src ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
......@@ -63,7 +67,7 @@ do
if [ $# == 6 ]
then
taskset -c $cmdopt python ../train.py \
taskset -c $cmdopt python train.py \
--distribute=1 \
--lr=0.005 \
--device_num=$RANK_SIZE \
......@@ -76,7 +80,7 @@ do
if [ $# == 8 ]
then
taskset -c $cmdopt python ../train.py \
taskset -c $cmdopt python train.py \
--distribute=1 \
--lr=0.005 \
--device_num=$RANK_SIZE \
......
......@@ -20,4 +20,7 @@ echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
echo "=============================================================================================================="
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
......@@ -27,6 +27,9 @@ then
exit 1
fi
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
if [ $# == 5 ]
then
python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
......
......@@ -25,7 +25,7 @@ class ConfigYOLOV3ResNet18:
"""
img_shape = [352, 640]
feature_shape = [32, 3, 352, 640]
num_classes = 80
num_classes = 2
nms_max_num = 50
backbone_input_shape = [64, 64, 128, 256]
......
......@@ -23,7 +23,7 @@ from PIL import Image
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.transforms.vision.c_transforms as C
from config import ConfigYOLOV3ResNet18
from src.config import ConfigYOLOV3ResNet18
iter_cnt = 0
_NUM_BOXES = 50
......
......@@ -15,7 +15,7 @@
"""metrics utils"""
import numpy as np
from config import ConfigYOLOV3ResNet18
from src.config import ConfigYOLOV3ResNet18
def calc_iou(bbox_pred, bbox_ground):
......
......@@ -33,9 +33,9 @@ from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from src.config import ConfigYOLOV3ResNet18
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册