Commit 6a851ee2 authored by mindspore-ci-bot, committed by Gitee

!5771 Add deeplabv3 to modelzoo

Merge pull request !5771 from jiangzhenguang/deeplabv3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation."""
import argparse
from mindspore import context
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config

parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
print(args_opt)

if __name__ == "__main__":
    args_opt.crop_size = config.crop_size
    args_opt.base_size = config.crop_size
    eval_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="eval")
    net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
    param_dict = load_checkpoint(args_opt.checkpoint_url)
    load_param_into_net(net, param_dict)
    mIou = MiouPrecision(config.seg_num_classes)
    metrics = {'mIou': mIou}
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    model = Model(net, loss, metrics=metrics)
    model.eval(eval_dataset)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval deeplabv3."""
import os
import argparse
import numpy as np
import cv2
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=int(os.getenv('DEVICE_ID')))


def parse_args():
    parser = argparse.ArgumentParser('mindspore deeplabv3 eval')

    # val data
    parser.add_argument('--data_root', type=str, default='', help='root path of val data')
    parser.add_argument('--data_lst', type=str, default='', help='list of val data')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--crop_size', type=int, default=513, help='crop size')
    parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
    parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
    parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
    parser.add_argument('--flip', action='store_true', help='perform left-right flip')
    parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
    parser.add_argument('--num_classes', type=int, default=21, help='number of classes')

    # model
    parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
    parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
    parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')

    args, _ = parser.parse_known_args()
    return args

def cal_hist(a, b, n):
k = (a >= 0) & (a < n)
return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n)
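# For reference, cal_hist packs (ground-truth, prediction) pairs into a confusion
# matrix via np.bincount; a hypothetical 3-class example (not part of the script):
#   gt   = np.array([0, 1, 2, 255])   # 255 fails (a < n) and is dropped
#   pred = np.array([0, 2, 2, 0])
#   cal_hist(gt, pred, 3) == [[1, 0, 0],
#                             [0, 0, 1],
#                             [0, 0, 1]]   # hist[i, j]: pixels of class i predicted as j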
def resize_long(img, long_size=513):
h, w, _ = img.shape
if h > w:
new_h = long_size
new_w = int(1.0 * long_size * w / h)
else:
new_w = long_size
new_h = int(1.0 * long_size * h / w)
imo = cv2.resize(img, (new_w, new_h))
return imo
class BuildEvalNetwork(nn.Cell):
def __init__(self, network):
super(BuildEvalNetwork, self).__init__()
self.network = network
self.softmax = nn.Softmax(axis=1)
def construct(self, input_data):
output = self.network(input_data)
output = self.softmax(output)
return output
def pre_process(args, img_, crop_size=513):
# resize
img_ = resize_long(img_, crop_size)
resize_h, resize_w, _ = img_.shape
# mean, std
image_mean = np.array(args.image_mean)
image_std = np.array(args.image_std)
img_ = (img_ - image_mean) / image_std
# pad to crop_size
pad_h = crop_size - img_.shape[0]
pad_w = crop_size - img_.shape[1]
if pad_h > 0 or pad_w > 0:
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
# hwc to chw
img_ = img_.transpose((2, 0, 1))
return img_, resize_h, resize_w
def eval_batch(args, eval_net, img_lst, crop_size=513, flip=True):
result_lst = []
batch_size = len(img_lst)
batch_img = np.zeros((args.batch_size, 3, crop_size, crop_size), dtype=np.float32)
resize_hw = []
for l in range(batch_size):
img_ = img_lst[l]
img_, resize_h, resize_w = pre_process(args, img_, crop_size)
batch_img[l] = img_
resize_hw.append([resize_h, resize_w])
batch_img = np.ascontiguousarray(batch_img)
net_out = eval_net(Tensor(batch_img, mstype.float32))
net_out = net_out.asnumpy()
if flip:
batch_img = batch_img[:, :, :, ::-1]
net_out_flip = eval_net(Tensor(batch_img, mstype.float32))
net_out += net_out_flip.asnumpy()[:, :, :, ::-1]
for bs in range(batch_size):
probs_ = net_out[bs][:, :resize_hw[bs][0], :resize_hw[bs][1]].transpose((1, 2, 0))
ori_h, ori_w = img_lst[bs].shape[0], img_lst[bs].shape[1]
probs_ = cv2.resize(probs_, (ori_w, ori_h))
result_lst.append(probs_)
return result_lst
def eval_batch_scales(args, eval_net, img_lst, scales,
base_crop_size=513, flip=True):
sizes_ = [int((base_crop_size - 1) * sc) + 1 for sc in scales]
probs_lst = eval_batch(args, eval_net, img_lst, crop_size=sizes_[0], flip=flip)
print(sizes_)
for crop_size_ in sizes_[1:]:
probs_lst_tmp = eval_batch(args, eval_net, img_lst, crop_size=crop_size_, flip=flip)
for pl, _ in enumerate(probs_lst):
probs_lst[pl] += probs_lst_tmp[pl]
result_msk = []
for i in probs_lst:
result_msk.append(i.argmax(axis=2))
return result_msk
def net_eval():
args = parse_args()
# data list
with open(args.data_lst) as f:
img_lst = f.readlines()
# network
if args.model == 'deeplab_v3_s16':
network = net_factory.nets_map[args.model]('eval', args.num_classes, 16, args.freeze_bn)
elif args.model == 'deeplab_v3_s8':
network = net_factory.nets_map[args.model]('eval', args.num_classes, 8, args.freeze_bn)
else:
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
eval_net = BuildEvalNetwork(network)
# load model
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)
# evaluate
hist = np.zeros((args.num_classes, args.num_classes))
batch_img_lst = []
batch_msk_lst = []
bi = 0
image_num = 0
for i, line in enumerate(img_lst):
img_path, msk_path = line.strip().split(' ')
img_path = os.path.join(args.data_root, img_path)
msk_path = os.path.join(args.data_root, msk_path)
img_ = cv2.imread(img_path)
msk_ = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
batch_img_lst.append(img_)
batch_msk_lst.append(msk_)
bi += 1
if bi == args.batch_size:
batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
base_crop_size=args.crop_size, flip=args.flip)
for mi in range(args.batch_size):
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
bi = 0
batch_img_lst = []
batch_msk_lst = []
print('processed {} images'.format(i+1))
image_num = i
if bi > 0:
batch_res = eval_batch_scales(args, eval_net, batch_img_lst, scales=args.scales,
base_crop_size=args.crop_size, flip=args.flip)
for mi in range(bi):
hist += cal_hist(batch_msk_lst[mi].flatten(), batch_res[mi].flatten(), args.num_classes)
print('processed {} images'.format(image_num + 1))
print(hist)
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print('per-class IoU', iu)
print('mean IoU', np.nanmean(iu))
if __name__ == '__main__':
net_eval()
mindspore
numpy
Pillow
opencv-python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File operation module."""
import os


def _is_obs(url):
    return url.startswith("obs://") or url.startswith("s3://")


def read(url, binary=False):
    if _is_obs(url):
        # TODO read cloud file.
        return None
    with open(url, "rb" if binary else "r") as f:
        return f.read()


def walk(url):
    if _is_obs(url):
        # TODO read cloud file.
        return None
    return os.walk(url)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=7
python /PATH/TO/MODEL_ZOO_CODE/data/build_seg_data.py --data_root=/PATH/TO/DATA_ROOT \
                                                      --data_lst=/PATH/TO/DATA_lst.txt \
                                                      --dst_path=/PATH/TO/MINDRECORD_NAME.mindrecord \
                                                      --num_shards=8 \
                                                      --shuffle=True
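For reference, the file passed as --data_lst is read (by build_seg_data.py below and by eval.py) as one space-separated image/label pair per line, with both paths taken relative to --data_root. A hypothetical VOC-style entry:

JPEGImages/2007_000032.jpg SegmentationClassGray/2007_000032.png
JPEGImages/2007_000039.jpg SegmentationClassGray/2007_000039.png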
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH"
echo "for example: bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH [PRETRAINED_CKPT_PATH](option)"
echo "It is better to use absolute path."
echo "=============================================================================================================="
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export DEVICE_NUM=8
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0;i<DEVICE_NUM;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf train_parallel$i
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $i, device $DEVICE_ID"
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
env > env.log
taskset -c $cmdopt python ../train.py \
--distribute="true" \
--device_id=$DEVICE_ID \
--checkpoint_url=$PATH_CHECKPOINT \
--data_url=$DATA_DIR > log.txt 2>&1 &
cd ../
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.08 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=410 \
--keep_checkpoint_max=200 >log 2>&1 &
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=800 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.02 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=820 \
--keep_checkpoint_max=200 >log 2>&1 &
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -c unlimited
train_path=/PATH/TO/EXPERIMENTS_DIR
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
export RANK_SIZE=8
export RANK_START_ID=0
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/ckpt
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
mkdir ${train_path}/device${DEVICE_ID}
cd ${train_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
--data_file=/PATH/TO/MINDRECORD_NAME \
--train_epochs=300 \
--batch_size=16 \
--crop_size=513 \
--base_lr=0.008 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s8 \
--loss_scale=2048 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--is_distributed \
--save_steps=110 \
--keep_checkpoint_max=200 >log 2>&1 &
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH"
echo "for example: bash run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH"
echo "=============================================================================================================="
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python eval.py \
    --device_id=$DEVICE_ID \
    --checkpoint_url=$PATH_CHECKPOINT \
    --data_url=$DATA_DIR > eval.log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL

if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=32 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s16 \
                                  --scales=1.0 \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init backbone."""
from .resnet_deeplab import Subsample, DepthwiseConv2dNative, SpaceToBatch, BatchToSpace, ResNetV1, \
    RootBlockBeta, resnet50_dl

__all__ = [
    "Subsample", "DepthwiseConv2dNative", "SpaceToBatch", "BatchToSpace", "ResNetV1", "RootBlockBeta", "resnet50_dl"
]
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL

if [ -d ${eval_path} ]; then
    rm -rf ${eval_path}
fi
mkdir -p ${eval_path}

python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
                                  --data_lst=/PATH/TO/DATA_lst.txt \
                                  --batch_size=16 \
                                  --crop_size=513 \
                                  --ignore_label=255 \
                                  --num_classes=21 \
                                  --model=deeplab_v3_s8 \
                                  --scales=1.0 \
                                  --freeze_bn \
                                  --ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s8 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=3
export SLOG_PRINT_TO_STDOUT=0
train_code_path=/PATH/TO/MODEL_ZOO_CODE
eval_path=/PATH/TO/EVAL
if [ -d ${eval_path} ]; then
rm -rf ${eval_path}
fi
mkdir -p ${eval_path}
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
--data_lst=/PATH/TO/DATA_lst.txt \
--batch_size=16 \
--crop_size=513 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s8 \
--scales=0.5 \
--scales=0.75 \
--scales=1.0 \
--scales=1.25 \
--scales=1.75 \
--flip \
--freeze_bn \
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID DATA_PATH"
echo "for example: bash run_standalone_train.sh DEVICE_ID DATA_PATH [PRETRAINED_CKPT_PATH](optional)"
echo "=============================================================================================================="
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
    PATH_CHECKPOINT=$3
fi

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python train.py \
    --distribute="false" \
    --device_id=$DEVICE_ID \
    --checkpoint_url=$PATH_CHECKPOINT \
    --data_url=$DATA_DIR > log.txt 2>&1 &
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=5
export SLOG_PRINT_TO_STDOUT=0
train_path=/PATH/TO/EXPERIMENTS_DIR
train_code_path=/PATH/TO/MODEL_ZOO_CODE

if [ -d ${train_path} ]; then
    rm -rf ${train_path}
fi
mkdir -p ${train_path}
mkdir ${train_path}/device${DEVICE_ID}
mkdir ${train_path}/ckpt
cd ${train_path}/device${DEVICE_ID} || exit

python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
                                   --train_dir=${train_path}/ckpt \
                                   --train_epochs=200 \
                                   --batch_size=32 \
                                   --crop_size=513 \
                                   --base_lr=0.015 \
                                   --lr_type=cos \
                                   --min_scale=0.5 \
                                   --max_scale=2.0 \
                                   --ignore_label=255 \
                                   --num_classes=21 \
                                   --model=deeplab_v3_s16 \
                                   --ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
                                   --save_steps=1500 \
                                   --keep_checkpoint_max=200 >log 2>&1 &
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Init DeepLabv3."""
from .deeplabv3 import ASPP, DeepLabV3, deeplabv3_resnet50
from .backbone import *
__all__ = [
"ASPP", "DeepLabV3", "deeplabv3_resnet50"
]
__all__.extend(backbone.__all__)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
from mindspore.mindrecord import FileWriter
seg_schema = {"file_name": {"type": "string"}, "label": {"type": "bytes"}, "data": {"type": "bytes"}}
def parse_args():
parser = argparse.ArgumentParser('mindrecord')
parser.add_argument('--data_root', type=str, default='', help='root path of data')
parser.add_argument('--data_lst', type=str, default='', help='list of data')
parser.add_argument('--dst_path', type=str, default='', help='save path of mindrecords')
parser.add_argument('--num_shards', type=int, default=8, help='number of shards')
    # argparse's type=bool treats any non-empty string (even "False") as True, so parse explicitly
    parser.add_argument('--shuffle', type=lambda s: str(s).lower() == 'true', default=True, help='shuffle or not')
parser_args, _ = parser.parse_known_args()
return parser_args
if __name__ == '__main__':
args = parse_args()
datas = []
with open(args.data_lst) as f:
lines = f.readlines()
if args.shuffle:
np.random.shuffle(lines)
dst_dir = '/'.join(args.dst_path.split('/')[:-1])
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
print('number of samples:', len(lines))
writer = FileWriter(file_name=args.dst_path, shard_num=args.num_shards)
writer.add_schema(seg_schema, "seg_schema")
cnt = 0
for l in lines:
img_path, label_path = l.strip().split(' ')
sample_ = {"file_name": img_path.split('/')[-1]}
with open(os.path.join(args.data_root, img_path), 'rb') as f:
sample_['data'] = f.read()
with open(os.path.join(args.data_root, label_path), 'rb') as f:
sample_['label'] = f.read()
datas.append(sample_)
cnt += 1
if cnt % 1000 == 0:
writer.write_raw_data(datas)
print('number of samples written:', cnt)
datas = []
if datas:
writer.write_raw_data(datas)
writer.commit()
print('number of samples written:', cnt)
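A minimal read-back check for the records written above; a sketch only, assuming the first shard file produced with num_shards=8 (the path is a placeholder, and MindDataset is the same reader data_generator.py uses below):

import mindspore.dataset as de

check_ds = de.MindDataset(dataset_file='/PATH/TO/MINDRECORD_NAME.mindrecord0',
                          columns_list=["data", "label"], shuffle=False)
print('records readable:', check_ds.get_dataset_size())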
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import cv2
import numpy as np
import mindspore.dataset as de
class SegDataset:
def __init__(self,
image_mean,
image_std,
data_file='',
batch_size=32,
crop_size=512,
max_scale=2.0,
min_scale=0.5,
ignore_label=255,
num_classes=21,
num_readers=2,
num_parallel_calls=4,
shard_id=None,
shard_num=None):
self.data_file = data_file
self.batch_size = batch_size
self.crop_size = crop_size
self.image_mean = np.array(image_mean, dtype=np.float32)
self.image_std = np.array(image_std, dtype=np.float32)
self.max_scale = max_scale
self.min_scale = min_scale
self.ignore_label = ignore_label
self.num_classes = num_classes
self.num_readers = num_readers
self.num_parallel_calls = num_parallel_calls
self.shard_id = shard_id
self.shard_num = shard_num
assert max_scale > min_scale
def preprocess_(self, image, label):
# bgr image
image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
sc = np.random.uniform(self.min_scale, self.max_scale)
new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
image_out = (image_out - self.image_mean) / self.image_std
h_, w_ = max(new_h, self.crop_size), max(new_w, self.crop_size)
pad_h, pad_w = h_ - new_h, w_ - new_w
if pad_h > 0 or pad_w > 0:
image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
offset_h = np.random.randint(0, h_ - self.crop_size + 1)
offset_w = np.random.randint(0, w_ - self.crop_size + 1)
image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
if np.random.uniform(0.0, 1.0) > 0.5:
image_out = image_out[:, ::-1, :]
label_out = label_out[:, ::-1]
image_out = image_out.transpose((2, 0, 1))
image_out = image_out.copy()
label_out = label_out.copy()
return image_out, label_out
def get_dataset(self, repeat=1):
data_set = de.MindDataset(dataset_file=self.data_file, columns_list=["data", "label"],
shuffle=True, num_parallel_workers=self.num_readers,
num_shards=self.shard_num, shard_id=self.shard_id)
transforms_list = self.preprocess_
data_set = data_set.map(input_columns=["data", "label"], output_columns=["data", "label"],
operations=transforms_list, num_parallel_workers=self.num_parallel_calls)
data_set = data_set.shuffle(buffer_size=self.batch_size * 10)
data_set = data_set.batch(self.batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat)
return data_set
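A single-card usage sketch for SegDataset (the mean/std values are copied from eval.py above; the data file path is a placeholder):

dataset = SegDataset(image_mean=[103.53, 116.28, 123.675],
                     image_std=[57.375, 57.120, 58.395],
                     data_file='/PATH/TO/MINDRECORD_NAME',
                     batch_size=32,
                     crop_size=513,
                     shard_id=0,
                     shard_num=1)
train_ds = dataset.get_dataset(repeat=1)
print('steps per epoch:', train_ds.get_dataset_size())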
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Process Dataset."""
import abc
import os
import time
from .utils.adapter import get_raw_samples, read_image
class BaseDataset:
"""
Create dataset.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage):
self.data_url = data_url
self.usage = usage
self.cur_index = 0
self.samples = []
_s_time = time.time()
self._load_samples()
_e_time = time.time()
print(f"load samples success~, time cost = {_e_time - _s_time}")
def __getitem__(self, item):
sample = self.samples[item]
return self._next_data(sample)
def __len__(self):
return len(self.samples)
@staticmethod
def _next_data(sample):
image_path = sample[0]
mask_image_path = sample[1]
image = read_image(image_path)
mask_image = read_image(mask_image_path)
return [image, mask_image]
@abc.abstractmethod
def _load_samples(self):
pass
class HwVocRawDataset(BaseDataset):
"""
Create dataset with raw data.
Args:
data_url (str): The path of data.
usage (str): Whether to use train or eval (default='train').
Returns:
Dataset.
"""
def __init__(self, data_url, usage="train"):
super().__init__(data_url, usage)
def _load_samples(self):
try:
self.samples = get_raw_samples(os.path.join(self.data_url, self.usage))
except Exception as e:
print("load HwVocRawDataset failed!!!")
raise e
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class OhemLoss(nn.Cell):
    """Ohem loss cell."""
    def __init__(self, num, ignore_label):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.squeeze = P.Squeeze()
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = 1.0

    def construct(self, logits, labels):
        if not self.training:
            return 0
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = F.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        one_hot_labels = self.one_hot(labels)
        losses = self.cross_entropy(logits, one_hot_labels)[0]
        weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
        weighted_losses = self.mul(losses, weights)
        loss = self.reduce_sum(weighted_losses, (0,))
        zeros = self.fill(mstype.float32, self.shape(weights), 0.0)
        ones = self.fill(mstype.float32, self.shape(weights), 1.0)
        present = self.select(self.equal(weights, zeros), zeros, ones)
        present = self.reduce_sum(present, (0,))
        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        min_control = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), min_control, present)
        loss = loss / present
        return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P


class SoftmaxCrossEntropyLoss(nn.Cell):
    def __init__(self, num_cls=21, ignore_label=255):
        super(SoftmaxCrossEntropyLoss, self).__init__()
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.not_equal = P.NotEqual()
        self.num_cls = num_cls
        self.ignore_label = ignore_label
        self.mul = P.Mul()
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))
        logits_ = self.transpose(logits, (0, 2, 3, 1))
        logits_ = self.reshape(logits_, (-1, self.num_cls))
        weights = self.not_equal(labels_int, self.ignore_label)
        weights = self.cast(weights, mstype.float32)
        one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
        loss = self.ce(logits_, one_hot_labels)
        loss = self.mul(weights, loss)
        loss = self.div(self.sum(loss), self.sum(weights))
        return loss
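The ignore_label handling above reduces to a weighted mean over the valid pixels; the same arithmetic in plain numpy, with illustrative values only:

import numpy as np

per_pixel_loss = np.array([0.2, 1.5, 0.7, 0.9], dtype=np.float32)
labels = np.array([0, 255, 3, 7])                        # 255 is the ignore_label
weights = (labels != 255).astype(np.float32)             # zero weight for ignored pixels
loss = (per_pixel_loss * weights).sum() / weights.sum()  # mean over the 3 valid pixels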
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
import numpy as np
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr
class DataTransform:
"""Transform dataset for DeepLabV3."""
def __init__(self, args, usage):
self.args = args
self.usage = usage
def __call__(self, image, label):
if self.usage == "train":
return self._train(image, label)
if self.usage == "eval":
return self._eval(image, label)
return None
def _train(self, image, label):
"""
Process training data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image = Image.fromarray(image)
label = Image.fromarray(label)
rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
image, label = rsc_tr(image, label)
rhf_tr = tr.RandomHorizontalFlip()
image, label = rhf_tr(image, label)
image = np.array(image).astype(np.float32)
label = np.array(label).astype(np.float32)
return image, label
def _eval(self, image, label):
"""
Process eval data.
Args:
image (list): Image data.
label (list): Dataset label.
"""
image = Image.fromarray(image)
label = Image.fromarray(label)
fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
image, label = fsc_tr(image, label)
image = np.array(image).astype(np.float32)
label = np.array(label).astype(np.float32)
return image, label
def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True):
"""
Create Dataset for DeepLabV3.
Args:
args (dict): Train parameters.
data_url (str): Dataset path.
epoch_num (int): Epoch of dataset (default=1).
batch_size (int): Batch size of dataset (default=1).
usage (str): Whether is use to train or eval (default='train').
Returns:
Dataset.
"""
# create iter dataset
dataset = HwVocRawDataset(data_url, usage=usage)
dataset_len = len(dataset)
# wrapped with GeneratorDataset
dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
dataset.set_dataset_size(dataset_len)
dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))
channelswap_op = C.HWC2CHW()
dataset = dataset.map(input_columns="image", operations=channelswap_op)
# 1464 samples / batch_size 8 = 183 batches
# epoch_num is num of steps
# 3658 steps / 183 = 20 epochs
if usage == "train" and shuffle:
dataset = dataset.shuffle(1464)
dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
dataset = dataset.repeat(count=epoch_num)
dataset.map_model = 4
return dataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mIou."""
import numpy as np
from mindspore.nn.metrics.metric import Metric
def confuse_matrix(target, pred, n):
k = (target >= 0) & (target < n)
return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
def iou(hist):
denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
res = np.sum(res) / np.count_nonzero(denominator)
return res
class MiouPrecision(Metric):
"""Calculate miou precision."""
def __init__(self, num_class=21):
super(MiouPrecision, self).__init__()
if not isinstance(num_class, int):
raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
if num_class < 1:
raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
self._num_class = num_class
self._mIoU = []
self.clear()
def clear(self):
self._hist = np.zeros((self._num_class, self._num_class))
self._mIoU = []
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
predict_in = self._convert_data(inputs[0])
label_in = self._convert_data(inputs[1])
pred = predict_in
label = label_in
        if len(label.flatten()) != len(pred.flatten()):
            raise ValueError('Prediction and label size mismatch: len(gt) = {:d}, len(pred) = {:d}'
                             .format(len(label.flatten()), len(pred.flatten())))
self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
mIoUs = iou(self._hist)
self._mIoU.append(mIoUs)
def eval(self):
"""
Computes the mIoU categorical accuracy.
"""
mIoU = np.nanmean(self._mIoU)
print('mIoU = {}'.format(mIoU))
return mIoU
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, weight_init='xavier_uniform')
def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, pad_mode='pad', padding=padding,
dilation=dilation, weight_init='xavier_uniform')
class Resnet(nn.Cell):
def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
super(Resnet, self).__init__()
self.inplanes = 64
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
weight_init='xavier_uniform')
self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
if output_stride == 16:
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
use_batch_statistics=use_batch_statistics)
self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
use_batch_statistics=use_batch_statistics)
elif output_stride == 8:
self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
use_batch_statistics=use_batch_statistics)
self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
use_batch_statistics=use_batch_statistics)
def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
        downsample = None  # identity shortcut unless shape or stride changes
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.SequentialCell([
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
            ])
if grids is None:
grids = [1] * blocks
layers = [
block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
use_batch_statistics=use_batch_statistics)
]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(self.inplanes, planes, dilation=base_dilation * grids[i],
use_batch_statistics=use_batch_statistics))
return nn.SequentialCell(layers)
def construct(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.maxpool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out
class Bottleneck(nn.Cell):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = nn.BatchNorm2d(planes * self.expansion, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.downsample = downsample
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class ASPP(nn.Cell):
def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
use_batch_statistics=True):
super(ASPP, self).__init__()
self.phase = phase
out_channels = 256
self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics)
self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
self.aspp_pooling = ASPPPooling(in_channels, out_channels)
self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
weight_init='xavier_uniform')
self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True)
self.concat = P.Concat(axis=1)
self.drop = nn.Dropout(0.3)
def construct(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.aspp_pooling(x)
x = self.concat((x1, x2))
x = self.concat((x, x3))
x = self.concat((x, x4))
x = self.concat((x, x5))
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.phase == 'train':
x = self.drop(x)
x = self.conv2(x)
return x
class ASPPPooling(nn.Cell):
def __init__(self, in_channels, out_channels, use_batch_statistics=True):
super(ASPPPooling, self).__init__()
self.conv = nn.SequentialCell([
nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
nn.ReLU()
])
self.shape = P.Shape()
def construct(self, x):
size = self.shape(x)
out = nn.AvgPool2d(size[2])(x)
out = self.conv(out)
out = P.ResizeNearestNeighbor((size[2], size[3]), True)(out)
return out
class ASPPConv(nn.Cell):
def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
super(ASPPConv, self).__init__()
if atrous_rate == 1:
conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform')
else:
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
dilation=atrous_rate, weight_init='xavier_uniform')
bn = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
relu = nn.ReLU()
self.aspp_conv = nn.SequentialCell([conv, bn, relu])
def construct(self, x):
out = self.aspp_conv(x)
return out
class DeepLabV3(nn.Cell):
def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
super(DeepLabV3, self).__init__()
use_batch_statistics = not freeze_bn
self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
use_batch_statistics=use_batch_statistics)
self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes,
use_batch_statistics=use_batch_statistics)
self.shape = P.Shape()
def construct(self, x):
size = self.shape(x)
out = self.resnet(x)
out = self.aspp(out)
out = P.ResizeBilinear((size[2], size[3]), True)(out)
return out
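A forward-pass smoke test for the network above; a sketch only, assuming the modelzoo layout src/nets/deeplab_v3/deeplab_v3.py and a 513x513 input matching the training crop:

import numpy as np
from mindspore import Tensor
from src.nets.deeplab_v3.deeplab_v3 import DeepLabV3

net = DeepLabV3(phase='eval', num_classes=21, output_stride=16, freeze_bn=True)
out = net(Tensor(np.zeros((1, 3, 513, 513), dtype=np.float32)))
print(out.shape)  # expected (1, 21, 513, 513) after the final bilinear resize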
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and evaluation.py
"""
from easydict import EasyDict as ed

config = ed({
    "learning_rate": 0.0014,
    "weight_decay": 0.00005,
    "momentum": 0.97,
    "crop_size": 513,
    "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    "atrous_rates": None,
    "image_pyramid": None,
    "output_stride": 16,
    "fine_tune_batch_norm": False,
    "ignore_label": 255,
    "decoder_output_stride": None,
    "seg_num_classes": 21,
    "epoch_size": 6,
    "batch_size": 2,
    "enable_save_ckpt": True,
    "save_checkpoint_steps": 10000,
    "save_checkpoint_num": 1
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from src.nets.deeplab_v3 import deeplab_v3

nets_map = {'deeplab_v3_s8': deeplab_v3.DeepLabV3,
            'deeplab_v3_s16': deeplab_v3.DeepLabV3}
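eval.py and train.py select the network through this map; both the s8 and s16 keys point at the same class, and the caller supplies the differing output_stride (compare the if/elif branch in net_eval above). A usage sketch:

from src.nets import net_factory

# equivalent to the 'deeplab_v3_s16' branch in eval.py
network = net_factory.nets_map['deeplab_v3_s16']('eval', 21, 16, False)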
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import sys
import json


def get_multicards_json(server_id):
    """Generate an 8-card HCCL rank table (rank_table_8p.json) from /etc/hccn.conf."""
    with open('/etc/hccn.conf', 'r') as hccn_fp:
        hccn_configs = hccn_fp.readlines()
    # parse device_id -> device_ip from lines of the form "address_<id>=<ip>"
    device_ips = {}
    for hccn_item in hccn_configs:
        hccn_item = hccn_item.strip()
        if hccn_item.startswith('address_'):
            device_id, device_ip = hccn_item.split('=')
            device_id = device_id.split('_')[1]
            device_ips[device_id] = device_ip
            print('device_id:{}, device_ip:{}'.format(device_id, device_ip))

    hccn_table = {'board_id': '0x0000', 'chip_info': '910', 'deploy_mode': 'lab',
                  'group_count': '1', 'group_list': []}
    instance_list = []
    usable_dev = ''
    for instance_id in range(8):
        instance = {'devices': []}
        device_id = str(instance_id)
        device_ip = device_ips[device_id]
        usable_dev += str(device_id)
        instance['devices'].append({
            'device_id': device_id,
            'device_ip': device_ip,
        })
        instance['rank_id'] = str(instance_id)
        instance['server_id'] = server_id
        instance_list.append(instance)

    hccn_table['group_list'].append({
        'device_num': '8',
        'server_num': '1',
        'group_name': '',
        'instance_count': '8',
        'instance_list': instance_list,
    })
    hccn_table['para_plane_nic_location'] = 'device'
    hccn_table['para_plane_nic_name'] = []
    for instance_id in range(8):
        hccn_table['para_plane_nic_name'].append('eth{}'.format(instance_id))
    hccn_table['para_plane_nic_num'] = '8'
    hccn_table['status'] = 'completed'

    table_fn = os.path.join(os.getcwd(), 'rank_table_8p.json')
    print(table_fn)
    with open(table_fn, 'w') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)


host_server_id = sys.argv[1]
get_multicards_json(host_server_id)
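# Expected invocation (sketch; SERVER_ID is a placeholder for the host id):
#   python <this script> SERVER_ID
# It writes rank_table_8p.json into the current working directory, with one
# entry per device address found in /etc/hccn.conf.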
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter dataset."""
import fnmatch
import io
import os
import numpy as np
from PIL import Image
from ..utils import file_io
def get_raw_samples(data_url):
    """
    Get dataset from raw data.
    Args:
        data_url (str): Dataset path.
    Returns:
        list, pairs of [image path, segmentation label path].
    """
    def _list_files(dir_path, pattern):
        full_files = []
        _, _, files = next(file_io.walk(dir_path))
        for f in files:
            if fnmatch.fnmatch(f.lower(), pattern.lower()):
                full_files.append(os.path.join(dir_path, f))
        return full_files

    img_files = _list_files(os.path.join(data_url, "Images"), "*.jpg")
    seg_files = _list_files(os.path.join(data_url, "SegmentationClassRaw"), "*.png")
    files = []
    for img_file in img_files:
        _, file_name = os.path.split(img_file)
        name, _ = os.path.splitext(file_name)
        seg_file = os.path.join(data_url, "SegmentationClassRaw", ".".join([name, "png"]))
        if seg_file in seg_files:
            files.append([img_file, seg_file])
    return files


def read_image(img_path):
    """
    Read image from file.
    Args:
        img_path (str): image path.
    Returns:
        numpy.ndarray, the decoded image.
    """
    img = file_io.read(img_path.strip(), binary=True)
    data = io.BytesIO(img)
    img = Image.open(data)
    return np.array(img)
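# --- Usage sketch (the data root below is a hypothetical placeholder; it
# --- assumes the Images/*.jpg and SegmentationClassRaw/*.png layout that the
# --- functions above require):
if __name__ == '__main__':
    samples = get_raw_samples('/path/to/voc_root')
    if samples:
        img = read_image(samples[0][0])   # HWC uint8 image
        seg = read_image(samples[0][1])   # HW label map
        print(len(samples), img.shape, seg.shape)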
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Random process dataset."""
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
class Normalize:
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        img = ((img - self.mean) / self.std).astype(np.float32)
        return img, mask


class RandomHorizontalFlip:
    """Randomly decide whether to flip the image horizontally."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask


class RandomRotate:
    """
    Randomly rotate the image.
    Args:
        degree (float): maximum rotation degree; the angle is drawn
            uniformly from [-degree, degree].
    """
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)
        return img, mask


class RandomGaussianBlur:
    """Randomly decide whether to filter the image with a Gaussian blur."""
    def __call__(self, img, mask):
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        return img, mask


class RandomScaleCrop:
    """Randomly scale the short edge, pad if needed, then random-crop."""
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, img, mask):
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad up to crop_size; the mask is padded with the fill (ignore) value
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop of crop_size x crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return img, mask


class FixScaleCrop:
    """Scale the short edge to crop_size, then center-crop."""
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, img, mask):
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return img, mask


class FixedResize:
    """Resize both image and mask to a fixed square size."""
    def __init__(self, size):
        self.size = (size, size)

    def __call__(self, img, mask):
        assert img.size == mask.size
        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)
        return img, mask
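# --- Composition sketch (the chaining order follows common practice and is an
# --- assumption; mean/std match the values used elsewhere in this model, and
# --- fill=255 marks the ignore label):
def train_transforms(img, mask, base_size=513, crop_size=513):
    for op in (RandomHorizontalFlip(),
               RandomScaleCrop(base_size=base_size, crop_size=crop_size, fill=255),
               RandomGaussianBlur(),
               Normalize(mean=(103.53, 116.28, 123.675), std=(57.375, 57.120, 58.395))):
        img, mask = op(img, mask)
    return img, mask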
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
def cosine_lr(base_lr, decay_steps, total_steps):
    """Cosine decay from base_lr toward 0 over decay_steps."""
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield base_lr * 0.5 * (1 + np.cos(np.pi * step_ / decay_steps))


def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
    """Polynomial decay from base_lr to end_lr over decay_steps."""
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr


def exponential_lr(base_lr, decay_steps, decay_rate, total_steps, staircase=False):
    """Exponential decay by decay_rate every decay_steps (stepwise if staircase)."""
    for i in range(total_steps):
        if staircase:
            power_ = i // decay_steps
        else:
            power_ = float(i) / decay_steps
        yield base_lr * (decay_rate ** power_)
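# --- Sanity-check sketch: materialize a few steps of each schedule; the step
# --- counts here are illustrative only, base_lr matches train.py's default.
if __name__ == '__main__':
    steps = 5
    print(list(cosine_lr(0.015, steps, steps)))
    print(list(poly_lr(0.015, steps, steps, end_lr=0.0, power=0.9)))
    print(list(exponential_lr(0.015, 2, 0.1, steps, staircase=True)))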
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train.""" """train deeplabv3."""
import argparse
from mindspore import context import os
from mindspore.communication.management import init import argparse
from mindspore.nn.optim.momentum import Momentum from mindspore import context
from mindspore import Model from mindspore.train.model import ParallelMode, Model
from mindspore.context import ParallelMode import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.communication.management import init, get_rank, get_group_size
from src.md_dataset import create_dataset from mindspore.train.callback import LossMonitor, TimeMonitor
from src.losses import OhemLoss from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.deeplabv3 import deeplabv3_resnet50 from src.data import data_generator
from src.config import config from src.loss import loss
from src.nets import net_factory
set_seed(1) from src.utils import learning_rates
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
parser = argparse.ArgumentParser(description="Deeplabv3 training") device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") class BuildTrainNetwork(nn.Cell):
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
args_opt = parser.parse_args() self.network = network
print(args_opt) self.criterion = criterion
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
class LossCallBack(Callback): def construct(self, input_data, label):
""" output = self.network(input_data)
Monitor the loss in training. net_loss = self.criterion(output, label)
Note: return net_loss
if per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1. def parse_args():
""" parser = argparse.ArgumentParser('mindspore deeplabv3 training')
def __init__(self, per_print_times=1): parser.add_argument('--train_dir', type=str, default='', help='where training log and ckpts saved')
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0: # dataset
raise ValueError("print_step must be int and >= 0") parser.add_argument('--data_file', type=str, default='', help='path and name of one mindrecord file')
self._per_print_times = per_print_times parser.add_argument('--batch_size', type=int, default=32, help='batch size')
def step_end(self, run_context): parser.add_argument('--crop_size', type=int, default=513, help='crop size')
cb_params = run_context.original_args() parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
str(cb_params.net_outputs))) parser.add_argument('--min_scale', type=float, default=0.5, help='minimum scale of data argumentation')
def model_fine_tune(flags, train_net, fix_weight_layer): parser.add_argument('--max_scale', type=float, default=2.0, help='maximum scale of data argumentation')
checkpoint_path = flags.checkpoint_url parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
if checkpoint_path is None: parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
return
param_dict = load_checkpoint(checkpoint_path) # optimizer
load_param_into_net(train_net, param_dict) parser.add_argument('--train_epochs', type=int, default=300, help='epoch')
for para in train_net.trainable_params(): parser.add_argument('--lr_type', type=str, default='cos', help='type of learning rate')
if fix_weight_layer in para.name: parser.add_argument('--base_lr', type=float, default=0.015, help='base learning rate')
para.requires_grad = False parser.add_argument('--lr_decay_step', type=int, default=40000, help='learning rate decay step')
if __name__ == "__main__": parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate')
if args_opt.distribute == "true": parser.add_argument('--loss_scale', type=float, default=3072.0, help='loss scale')
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
init() # model
args_opt.base_size = config.crop_size parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
args_opt.crop_size = config.crop_size parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train") parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
dataset_size = train_dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size) # train
callback = [time_cb, LossCallBack()] parser.add_argument('--is_distributed', action='store_true', help='distributed training')
if config.enable_save_ckpt: parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
keep_checkpoint_max=config.save_checkpoint_num) parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving')
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) parser.add_argument('--keep_checkpoint_max', type=int, default=int, help='max checkpoint for saving')
callback.append(ckpoint_cb)
net = deeplabv3_resnet50(config.seg_num_classes, [config.batch_size, 3, args_opt.crop_size, args_opt.crop_size], args, _ = parser.parse_known_args()
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, return args
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
net.set_train() def train():
model_fine_tune(args_opt, net, 'layer') args = parse_args()
loss = OhemLoss(config.seg_num_classes, config.ignore_label)
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) # init multicards training
model = Model(net, loss, opt) if args.is_distributed:
model.train(config.epoch_size, train_dataset, callback) init()
        args.rank = get_rank()
        args.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
                                          device_num=args.group_size)

    # dataset
    dataset = data_generator.SegDataset(image_mean=args.image_mean,
                                        image_std=args.image_std,
                                        data_file=args.data_file,
                                        batch_size=args.batch_size,
                                        crop_size=args.crop_size,
                                        max_scale=args.max_scale,
                                        min_scale=args.min_scale,
                                        ignore_label=args.ignore_label,
                                        num_classes=args.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))

    # loss
    loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)

    # load pretrained model
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        load_param_into_net(train_net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    if args.lr_type == 'cos':
        lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps)
    elif args.lr_type == 'poly':
        lr_iter = learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
    elif args.lr_type == 'exp':
        lr_iter = learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate,
                                                total_train_steps, staircase=True)
    else:
        raise ValueError('unknown learning rate type')
    opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9, weight_decay=0.0001,
                      loss_scale=args.loss_scale)

    # loss scale
    manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
    model = Model(train_net, optimizer=opt, amp_level="O3", loss_scale_manager=manager_loss_scale)

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]
    if args.rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps,
                                     keep_checkpoint_max=args.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(args.train_epochs, dataset, callbacks=cbs)


if __name__ == '__main__':
    train()
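# --- Example launch (sketch; the file name and paths are placeholders, the
# --- flags come from parse_args above):
#   python train.py --data_file=/path/to/train.mindrecord --train_dir=./ckpt \
#       --model=deeplab_v3_s16 --batch_size=32 --lr_type=cos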
@@ -88,7 +88,7 @@ rank_start=$((DEVICE_NUM * SERVER_ID))
 for((i=0; i<${DEVICE_NUM}; i++))
 do
-    export DEVICE_ID=$i
+    export DEVICE_ID=${i}
     export RANK_ID=$((rank_start + i))
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
...