未验证 提交 bd586f4a 编写于 作者: B Bin Lu 提交者: GitHub

Fix batch predict of cls and rec (#1089)

* fixbug_bs=1 of predict_cls\rec
上级 01ea6f99
......@@ -27,7 +27,6 @@ from utils.get_image_list import get_image_list
from python.preprocess import create_operators
from python.postprocess import build_postprocess
class ClsPredictor(Predictor):
def __init__(self, config):
super().__init__(config["Global"])
......@@ -59,6 +58,8 @@ class ClsPredictor(Predictor):
input_tensor.copy_from_cpu(image)
self.paddle_predictor.run()
batch_output = output_tensor.copy_to_cpu()
if self.postprocess is not None:
batch_output = self.postprocess(batch_output)
return batch_output
......@@ -66,14 +67,38 @@ def main(config):
cls_predictor = ClsPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
output = cls_predictor.predict(img)
output = cls_predictor.postprocess(output, [image_file])
print(output)
return
batch_imgs = []
batch_names = []
cnt = 0
for idx, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
logger.warning(
"Image file failed to read and has been skipped. The path: {}".
format(img_path))
else:
img = img[:, :, ::-1]
batch_imgs.append(img)
img_name = os.path.basename(img_path)
batch_names.append(img_name)
cnt += 1
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
if len(batch_imgs) == 0:
continue
batch_results = cls_predictor.predict(batch_imgs)
for number, result_dict in enumerate(batch_results):
filename = batch_names[number]
clas_ids = result_dict["class_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
label_names = result_dict["label_names"]
print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
format(filename, clas_ids, scores_str, label_names))
batch_imgs = []
batch_names = []
return
if __name__ == "__main__":
args = config.parse_args()
......
......@@ -54,12 +54,14 @@ class RecPredictor(Predictor):
input_tensor.copy_from_cpu(image)
self.paddle_predictor.run()
batch_output = output_tensor.copy_to_cpu()
if feature_normalize:
feas_norm = np.sqrt(
np.sum(np.square(batch_output), axis=1, keepdims=True))
batch_output = np.divide(batch_output, feas_norm)
if self.postprocess is not None:
batch_output = self.postprocess(batch_output)
return batch_output
......@@ -67,14 +69,33 @@ def main(config):
rec_predictor = RecPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
batch_input = []
img = cv2.imread(image_file)[:, :, ::-1]
output = rec_predictor.predict(img)
if rec_predictor.postprocess is not None:
output = rec_predictor.postprocess(output)
print(output)
batch_imgs = []
batch_names = []
cnt = 0
for idx, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
logger.warning(
"Image file failed to read and has been skipped. The path: {}".
format(img_path))
else:
img = img[:, :, ::-1]
batch_imgs.append(img)
img_name = os.path.basename(img_path)
batch_names.append(img_name)
cnt += 1
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
if len(batch_imgs) == 0:
continue
batch_results = rec_predictor.predict(batch_imgs)
for number, result_dict in enumerate(batch_results):
filename = batch_names[number]
print("{}:\t{}".format(filename, result_dict))
batch_imgs = []
batch_names = []
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册