Unverified Commit a6ceff1c authored by Steffy-zxf, committed by GitHub

Update ocr (#906)

Parent 90ef8d10
## Overview
The chinese_ocr_db_crnn_mobile Module recognizes Chinese text in images. Starting from the text boxes detected by the [chinese_text_detection_db_mobile Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_mobile&en_category=TextRecognition), it recognizes the Chinese characters inside each box. The recognition algorithm is CRNN (Convolutional Recurrent Neural Network), a combination of DCNN and RNN designed for recognizing sequence-like objects in images. Paired with the CTC loss, it learns directly from word- or line-level annotations, with no need for detailed character-level labels. This Module is an ultra-lightweight Chinese OCR model that supports direct prediction.
The chinese_ocr_db_crnn_mobile Module recognizes Chinese text in images. Starting from the text boxes detected by the [chinese_text_detection_db_mobile Module](https://www.paddlepaddle.org.cn/hubdetail?name=chinese_text_detection_db_mobile&en_category=TextRecognition), it first classifies the orientation angle of each detected box and then recognizes the Chinese characters inside. The recognition algorithm is CRNN (Convolutional Recurrent Neural Network), a combination of DCNN and RNN designed for recognizing sequence-like objects in images. Paired with the CTC loss, it learns directly from word- or line-level annotations, with no need for detailed character-level labels. This Module is an ultra-lightweight Chinese OCR model that supports direct prediction.
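For orientation, a minimal usage sketch of this pipeline, assuming the module's documented `recognize_text` API and a placeholder image path:

```python
import cv2
import paddlehub as hub

# Load the three-stage OCR module (detection -> angle classification ->
# CRNN recognition) and run it on one image; '/PATH/TO/IMAGE' is a placeholder.
ocr = hub.Module(name="chinese_ocr_db_crnn_mobile")
results = ocr.recognize_text(images=[cv2.imread('/PATH/TO/IMAGE')])
for result in results:
    for item in result['data']:
        print(item['text'], item['confidence'])
```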
<p align="center">
......@@ -142,3 +142,11 @@ pyclipper
* 1.0.1
  Fixed a failure when calling the model through the online service
* 1.0.2
  Added MKL-DNN acceleration for CPU inference
* 1.1.0
  Recognizes text in images with an ultra-lightweight three-stage pipeline (text box detection - angle classification - text recognition).
......@@ -22,17 +22,23 @@ class CharacterOps(object):
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n")
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
if add_space:
self.character_str += " "
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
......@@ -46,6 +52,8 @@ class CharacterOps(object):
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
......@@ -90,7 +98,7 @@ class CharacterOps(object):
if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
char_list.append(self.character[int(text_index[idx])])
text = ''.join(char_list)
return text
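A self-contained sketch of the duplicate-collapsing step this decode performs; the character table and index sequence below are made up, and the CTC blank-removal part of the real decode is omitted:

```python
import numpy as np

# Collapse consecutive duplicate indices before mapping them to characters,
# as in the is_remove_duplicate branch above.
character = list("0123456789abcdefghijklmnopqrstuvwxyz")
text_index = np.array([12, 12, 24, 24, 24, 11])
char_list = []
for idx in range(len(text_index)):
    if idx > 0 and text_index[idx - 1] == text_index[idx]:
        continue  # skip the repeated index
    char_list.append(character[int(text_index[idx])])
print(''.join(char_list))  # -> cob
```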
......@@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops,
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops,
preds,
labels,
max_text_len,
is_debug=False):
acc_num = 0
img_num = 0
char_num = char_ops.get_char_num()
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
            if labels[j + i * max_text_len] != int(char_num - 1):  # char_num - 1 is the padding index
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[
j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(
char_num - 1):
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
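A hedged usage sketch for `cal_predicts_accuracy_srn`, assuming the function above is in scope; `_StubCharOps` and the arrays are mock data, with index `char_num - 1` playing the padding role:

```python
import numpy as np

class _StubCharOps(object):
    # Stand-in for CharacterOps; 38 is a made-up vocabulary size, so index 37
    # acts as the padding/end symbol compared against char_num - 1 above.
    def get_char_num(self):
        return 38

max_text_len = 5
labels = np.array([[3], [7], [37], [37], [37]])  # two real chars, then padding
preds = labels.copy()                            # a perfect prediction
acc, acc_num, img_num = cal_predicts_accuracy_srn(
    _StubCharOps(), preds, labels, max_text_len)
print(acc, acc_num, img_num)  # -> 1.0 1 1
```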
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
......
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import copy
......@@ -25,9 +21,10 @@ from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo(
name="chinese_ocr_db_crnn_mobile",
version="1.0.4",
version="1.1.0",
summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/text_recognition")
......@@ -41,23 +38,31 @@ class ChineseOCRDBCRNN(hub.Module):
char_ops_params = {
'character_type': 'ch',
'character_dict_path': self.character_dict_path,
'loss_type': 'ctc'
'loss_type': 'ctc',
'max_text_length': 25,
'use_space_char': True
}
self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320]
self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory,
'inference_model')
self.enable_mkldnn = enable_mkldnn
self._set_config()
def _set_config(self):
self.rec_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'character_rec')
self.cls_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'angle_cls')
self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
self.rec_pretrained_model_path)
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path)
def _set_config(self, pretrained_model_path):
"""
predictor config setting
predictor config path
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path)
try:
......@@ -72,21 +77,25 @@ class ChineseOCRDBCRNN(hub.Module):
else:
config.disable_gpu()
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
output_names = predictor.get_output_names()
output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
self.output_tensors.append(output_tensor)
output_tensor = predictor.get_output_tensor(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors
@property
def text_detector_module(self):
......@@ -97,7 +106,7 @@ class ChineseOCRDBCRNN(hub.Module):
self._text_detector_module = hub.Module(
name='chinese_text_detection_db_mobile',
enable_mkldnn=self.enable_mkldnn,
version='1.0.2')
version='1.0.3')
return self._text_detector_module
def read_images(self, paths=[]):
......@@ -113,6 +122,7 @@ class ChineseOCRDBCRNN(hub.Module):
return images
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
......@@ -121,23 +131,51 @@ class ChineseOCRDBCRNN(hub.Module):
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
[img_crop_width, img_crop_height], [0, img_crop_height]])
'''
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
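The same warp restated as a standalone sketch with synthetic data; the four points are assumed clockwise from the top-left corner, matching the detector's output convention:

```python
import cv2
import numpy as np

img = np.full((200, 300, 3), 255, dtype=np.uint8)  # dummy page
points = np.float32([[50, 60], [210, 40], [218, 80], [58, 100]])
# crop size from the longer of each opposite edge pair
w = int(max(np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
h = int(max(np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
M = cv2.getPerspectiveTransform(points, pts_std)
crop = cv2.warpPerspective(img, M, (w, h),
                           borderMode=cv2.BORDER_REPLICATE,
                           flags=cv2.INTER_CUBIC)
if h * 1.0 / w >= 1.5:  # tall boxes are treated as vertical text
    crop = np.rot90(crop)
```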
def resize_norm_img(self, img, max_wh_ratio):
def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
imgW = int(32 * max_wh_ratio)
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
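A small numeric illustration of why a batch shares one padded width: `max_wh_ratio` comes from the widest crop in the batch, so every crop is resized to height 32 and padded to the same `imgW` (the ratios below are made up):

```python
import math

ratios = [3.1, 5.8, 9.4]       # hypothetical width/height ratios in one batch
imgW = int(32 * max(ratios))   # shared padded width: 300
for r in ratios:
    resized_w = min(int(math.ceil(32 * r)), imgW)
    print("ratio %.1f -> content width %d, padded to %d" % (r, resized_w, imgW))
```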
def resize_norm_img_cls(self, img):
cls_image_shape = [3, 48, 192]
imgC, imgH, imgW = cls_image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
......@@ -147,7 +185,11 @@ class ChineseOCRDBCRNN(hub.Module):
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
if cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
......@@ -198,6 +240,7 @@ class ChineseOCRDBCRNN(hub.Module):
detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)
boxes = [
np.array(item['data']).astype(np.float32)
for item in detection_results
......@@ -206,7 +249,7 @@ class ChineseOCRDBCRNN(hub.Module):
for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy()
result = {'save_path': ''}
if img_boxes is None:
if img_boxes.size == 0:
result['data'] = []
else:
img_crop_list = []
......@@ -216,8 +259,9 @@ class ChineseOCRDBCRNN(hub.Module):
img_crop = self.get_rotate_crop_image(
original_image, tmp_box)
img_crop_list.append(img_crop)
img_crop_list, angle_list = self._classify_text(img_crop_list)
rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it
rec_res_final = []
for index, res in enumerate(rec_results):
......@@ -276,32 +320,86 @@ class ChineseOCRDBCRNN(hub.Module):
cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path
def _recognize_text(self, image_list):
img_num = len(image_list)
def _classify_text(self, image_list):
img_list = copy.deepcopy(image_list)
img_num = len(img_list)
        # Calculate the aspect ratio of all text crops
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = 30
rec_res = []
predict_time = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = image_list[ino].shape[0:2]
wh_ratio = w / h
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio)
norm_img = self.resize_norm_img_cls(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0]
self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.zero_copy_run()
prob_out = self.cls_output_tensors[0].copy_to_cpu()
label_out = self.cls_output_tensors[1].copy_to_cpu()
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
label_list = ['0', '180']
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], 1)  # 1 == cv2.ROTATE_180
return img_list, cls_res
def _recognize_text(self, img_list):
img_num = len(img_list)
        # Calculate the aspect ratio of all text crops
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = 30
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img_rec(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.zero_copy_run()
rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
rec_idx_lod = self.rec_output_tensors[0].lod()[0]
predict_batch = self.rec_output_tensors[1].copy_to_cpu()
predict_lod = self.rec_output_tensors[1].lod()[0]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
......@@ -316,9 +414,10 @@ class ChineseOCRDBCRNN(hub.Module):
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
return rec_res
return rec_res
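The LoD bookkeeping above is the only subtle part: the recognizer returns every sample's decoded indices concatenated along axis 0, and `lod()[0]` holds the cumulative start offsets per sample. A sketch with made-up values:

```python
import numpy as np

rec_idx_batch = np.array([3, 7, 7, 12, 0, 4, 9, 9, 1])  # fake decoded indices
rec_idx_lod = [0, 4, 7, 9]  # sample k occupies rows lod[k]:lod[k + 1]

for rno in range(len(rec_idx_lod) - 1):
    beg, end = rec_idx_lod[rno], rec_idx_lod[rno + 1]
    print("sample %d indices:" % rno, rec_idx_batch[beg:end])
```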
def save_inference_model(self,
dirname,
......@@ -326,9 +425,12 @@ class ChineseOCRDBCRNN(hub.Module):
params_filename=None,
combined=True):
detector_dir = os.path.join(dirname, 'text_detector')
classifier_dir = os.path.join(dirname, 'angle_classifier')
recognizer_dir = os.path.join(dirname, 'text_recognizer')
self._save_detector_model(detector_dir, model_filename, params_filename,
combined)
self._save_classifier_model(classifier_dir, model_filename,
params_filename, combined)
self._save_recognizer_model(recognizer_dir, model_filename,
params_filename, combined)
logger.info("The inference model has been saved in the path {}".format(
......@@ -353,10 +455,40 @@ class ChineseOCRDBCRNN(hub.Module):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
params_file_path = os.path.join(self.rec_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.rec_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
def _save_classifier_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
params_file_path = os.path.join(self.cls_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.pretrained_model_path,
dirname=self.cls_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
......@@ -430,7 +562,7 @@ class ChineseOCRDBCRNN(hub.Module):
if __name__ == '__main__':
ocr = ChineseOCRDBCRNN()
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
]
......
......@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
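A sketch of what the swap achieves, with two synthetic boxes on one text line (corner format `[[x, y], ...]` starting top-left): y-sorting alone can leave the right-hand box first, and the pass above restores left-to-right reading order:

```python
import numpy as np

b_right = np.array([[200, 12], [260, 12], [260, 40], [200, 40]])
b_left = np.array([[20, 10], [90, 10], [90, 38], [20, 38]])
_boxes = [b_right, b_left]  # same line: top-y differs by only 2 px
for i in range(len(_boxes) - 1):
    if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
            (_boxes[i + 1][0][0] < _boxes[i][0][0]):
        _boxes[i], _boxes[i + 1] = _boxes[i + 1], _boxes[i]
# _boxes is now [b_left, b_right]
```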
......
......@@ -137,3 +137,7 @@ pyclipper
* 1.0.0
  Initial release
* 1.0.1
  Added MKL-DNN acceleration for CPU inference
......@@ -62,7 +62,7 @@ def detect_text(paths=[],
import paddlehub as hub
import cv2
text_detector = hub.Module(name="chinese_text_detection_db_mobile", enable_mk)
text_detector = hub.Module(name="chinese_text_detection_db_mobile", enable_mkldnn=True)
result = text_detector.detect_text(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
......
......@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
@moduleinfo(
name="chinese_text_detection_db_mobile",
version="1.0.2",
version="1.0.3",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
......@@ -73,7 +73,10 @@ class ChineseTextDetectionDB(hub.Module):
config.enable_use_gpu(8000, 0)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
......@@ -102,19 +105,18 @@ class ChineseTextDetectionDB(hub.Module):
images.append(img)
return images
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
left = int(np.min(box[:, 0]))
right = int(np.max(box[:, 0]))
top = int(np.min(box[:, 1]))
bottom = int(np.max(box[:, 1]))
bbox_height = bottom - top
bbox_width = right - left
diffh = math.fabs(box[0, 1] - box[1, 1])
diffw = math.fabs(box[0, 0] - box[3, 0])
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10:
......@@ -168,7 +170,7 @@ class ChineseTextDetectionDB(hub.Module):
"""
self.check_requirements()
from chinese_text_detection_db_mobile.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext
from chinese_text_detection_db_mobile.processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext
if use_gpu:
try:
......@@ -188,13 +190,20 @@ class ChineseTextDetectionDB(hub.Module):
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
preprocessor = DBPreProcess()
postprocessor = DBPostProcess(box_thresh)
preprocessor = DBProcessTest(params={'max_side_len': 960})
postprocessor = DBPostProcess(
params={
'thresh': 0.3,
'box_thresh': 0.5,
'max_candidates': 1000,
'unclip_ratio': 2.0
})
all_imgs = []
all_ratios = []
all_results = []
for original_image in predicted_data:
ori_im = original_image.copy()
im, ratio_list = preprocessor(original_image)
res = {'save_path': ''}
if im is None:
......@@ -202,11 +211,20 @@ class ChineseTextDetectionDB(hub.Module):
else:
im = im.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run()
data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(data_out, [ratio_list])
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
outs_dict = {}
outs_dict['maps'] = outputs[0]
# data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0]
boxes = self.filter_tag_det_res(dt_boxes_list[0],
original_image.shape)
res['data'] = boxes.astype(np.int).tolist()
......@@ -328,7 +346,7 @@ class ChineseTextDetectionDB(hub.Module):
if __name__ == '__main__':
db = ChineseTextDetectionDB()
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
]
......
......@@ -12,25 +12,43 @@ import numpy as np
import pyclipper
class DBPreProcess(object):
def __init__(self, max_side_len=960):
self.max_side_len = max_side_len
class DBProcessTest(object):
"""
DB pre-process for Test mode
"""
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'test_image_shape' in params:
self.image_shape = params['test_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400
def resize_image_type(self, im):
def resize_image_type0(self, im):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
max_side_len = self.max_side_len
h, w, _ = im.shape
resize_w = w
resize_h = h
# limit the max side
if max(resize_h, resize_w) > self.max_side_len:
if max(resize_h, resize_w) > max_side_len:
if resize_h > resize_w:
ratio = float(self.max_side_len) / resize_h
ratio = float(max_side_len) / resize_h
else:
ratio = float(self.max_side_len) / resize_w
ratio = float(max_side_len) / resize_w
else:
ratio = 1.
resize_h = int(resize_h * ratio)
......@@ -58,19 +76,34 @@ class DBPreProcess(object):
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
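The hunk elides the middle of `resize_image_type0`, which snaps both sides to a multiple of 32 as the docstring above states; a sketch of that convention, with the exact rounding rule an assumption:

```python
def snap_to_multiple_of_32(side):
    # assumed rounding: nearest multiple of 32, never below 32
    # (the DB backbone downsamples by a factor of 32)
    return max(32, int(round(side / 32.0) * 32))

print(snap_to_multiple_of_32(721))  # -> 736
```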
def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)
def normalize(self, im):
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im -= img_mean
im /= img_std
im[:, :, 0] -= img_mean[0]
im[:, :, 1] -= img_mean[1]
im[:, :, 2] -= img_mean[2]
im[:, :, 0] /= img_std[0]
im[:, :, 1] /= img_std[1]
im[:, :, 2] /= img_std[2]
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
return im
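The per-channel arithmetic above is equivalent to a single broadcast expression; a sketch on a dummy image:

```python
import numpy as np

img_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
img_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
im = np.random.rand(960, 960, 3).astype(np.float32)  # already scaled to [0, 1]
im = (im - img_mean) / img_std  # broadcasts over the channel axis
im = im.transpose((2, 0, 1))    # HWC -> CHW
```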
def __call__(self, im):
im, (ratio_h, ratio_w) = self.resize_image_type(im)
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)]
......@@ -81,10 +114,11 @@ class DBPostProcess(object):
The post process for Differentiable Binarization (DB).
"""
def __init__(self, thresh=0.3, box_thresh=0.5, max_candidates=1000):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
def __init__(self, params):
self.thresh = params['thresh']
self.box_thresh = params['box_thresh']
self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio']
self.min_size = 3
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
......@@ -134,7 +168,8 @@ class DBPostProcess(object):
scores[index] = score
return boxes, scores
def unclip(self, box, unclip_ratio=2.0):
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
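For reference, the unclip step as a standalone sketch: the shrunk polygon predicted by DB is dilated by `distance = area * unclip_ratio / perimeter` via a pyclipper offset (the box coordinates below are synthetic):

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

box = np.array([[10, 10], [110, 10], [110, 40], [10, 40]], dtype=np.float32)
unclip_ratio = 2.0
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length  # dilation offset
offset = pyclipper.PyclipperOffset()
offset.AddPath(box.astype(np.int32).tolist(), pyclipper.JT_ROUND,
               pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))  # dilated quadrilateral(s)
```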
......@@ -179,8 +214,10 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, predictions, ratio_list):
pred = predictions[:, 0, :, :]
def __call__(self, outs_dict, ratio_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
......
......@@ -125,3 +125,7 @@ pyclipper
* 1.0.0
  Initial release
* 1.0.2
  Added MKL-DNN acceleration for CPU inference