提交 9874f502 编写于 作者: T tink2123

support tensorrt

上级 aa48cda3
......@@ -42,6 +42,7 @@ class TextRecognizer(object):
self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length
self.use_zero_copy_run = args.use_zero_copy_run
self.benchmark = args.enable_benchmark
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
......@@ -62,8 +63,8 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
#if self.character_type == "ch":
#imgW = int((32 * max_wh_ratio))
if self.character_type == "ch" and not self.benchmark:
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
......@@ -313,13 +314,17 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
rec_res, predict_time = text_recognizer(img_list)
"""
try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
"""
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" %
......
......@@ -154,11 +154,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr(
image,
boxes,
txts,
scores,
drop_score=drop_score)
image, boxes, txts, scores, drop_score=drop_score)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
......@@ -171,20 +167,20 @@ def main(args):
test_num = 10
test_time = 0.0
for i in range(0, test_num + 10):
#inputs = np.random.rand(640, 640, 3).astype(np.float32)
#print(image_file_list)
image_file = image_file_list[0]
inputs = cv2.imread(image_file)
inputs = cv2.resize(inputs, (int(640), int(640)))
start_time = time.time()
dt_boxes,rec_res = text_sys(inputs)
dt_boxes, rec_res = text_sys(inputs)
if i >= 10:
test_time += time.time() - start_time
time.sleep(0.01)
fp_message = "FP16" if args.use_fp16 else "FP32"
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
print("model\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, args.max_batch_size, 1000 *
test_time / test_num))
print("Benchmark\t{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
trt_msg, fp_message, args.max_batch_size, 1000 * test_time /
test_num))
if __name__ == "__main__":
main(utility.parse_args())
......@@ -37,8 +37,8 @@ def parse_args():
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--enable_benchmark", type=str2bool, default=True)
parser.add_argument("--max_batch_size", type=int, default=1)
parser.add_argument("--enable_benchmark", type=str2bool, default=False)
#params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册