提交 821f47d0 编写于 作者: J jiangjiajun

modify thread_num for preprocess

上级 d4b71967
......@@ -45,7 +45,7 @@ predict(image, topk=1)
### batch_predict 接口
```
batch_predict(image_list, topk=1, thread_num=2)
batch_predict(image_list, topk=1)
```
批量图片预测接口。
......@@ -53,4 +53,3 @@ batch_predict(image_list, topk=1, thread_num=2)
>
> > * **image_list** (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径或numpy数组(HWC排列,BGR格式)。
> > * **topk** (int): 图像分类时使用的参数,表示预测前topk个可能的分类。
> > * **thread_num** (int): 并发执行各图像预处理时的线程数。
......@@ -62,7 +62,7 @@ evaluate(self, eval_dataset, batch_size=1, epoch_id=None, return_details=False)
### predict
```python
predict(self, img_file, transforms=None, topk=5)
predict(self, img_file, transforms=None, topk=1)
```
> 分类模型预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`ResNet50.test_transforms`和`ResNet50.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`predict`接口时,用户需要再重新定义test_transforms传入给`predict`接口。
......@@ -81,7 +81,7 @@ predict(self, img_file, transforms=None, topk=5)
### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, topk=5, thread_num=2)
batch_predict(self, img_file_list, transforms=None, topk=1)
```
> 分类模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`ResNet50.test_transforms`和`ResNet50.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义test_transforms传入给`batch_predict`接口。
......@@ -91,7 +91,6 @@ batch_predict(self, img_file_list, transforms=None, topk=5, thread_num=2)
> > - **img_file_list** (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.cls.transforms): 数据预处理操作。
> > - **topk** (int): 预测时前k个最大值。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
> **返回值**
>
......
......@@ -108,7 +108,7 @@ predict(self, img_file, transforms=None)
### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, thread_num=2)
batch_predict(self, img_file_list, transforms=None)
```
> PPYOLO模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义`test_transforms`传入给`batch_predict`接口
......@@ -117,7 +117,6 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2)
>
> > - **img_file_list** (str|np.ndarray): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
......@@ -222,7 +221,7 @@ predict(self, img_file, transforms=None)
### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, thread_num=2)
batch_predict(self, img_file_list, transforms=None)
```
> YOLOv3模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`YOLOv3.test_transforms`和`YOLOv3.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义`test_transforms`传入给`batch_predict`接口
......@@ -231,7 +230,6 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2)
>
> > - **img_file_list** (str|np.ndarray): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
......@@ -327,7 +325,7 @@ predict(self, img_file, transforms=None)
### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, thread_num=2)
batch_predict(self, img_file_list, transforms=None)
```
> FasterRCNN模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`FasterRCNN.test_transforms`和`FasterRCNN.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义test_transforms传入给`batch_predict`接口。
......@@ -336,7 +334,6 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2)
>
> > - **img_file_list** (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
......
......@@ -88,7 +88,7 @@ predict(self, img_file, transforms=None)
#### batch_predict
```python
batch_predict(self, img_file_list, transforms=None, thread_num=2)
batch_predict(self, img_file_list, transforms=None)
```
> MaskRCNN模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`FasterRCNN.test_transforms`和`FasterRCNN.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义test_transforms传入给`batch_predict`接口。
......@@ -97,7 +97,6 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2)
>
> > - **img_file_list** (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.det.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
>
> **返回值**
>
......
......@@ -95,7 +95,7 @@ predict(self, img_file, transforms=None):
### batch_predict
```
batch_predict(self, img_file_list, transforms=None, thread_num=2):
batch_predict(self, img_file_list, transforms=None):
```
> DeepLabv3p模型批量预测接口。需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`DeepLabv3p.test_transforms`和`DeepLabv3p.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`batch_predict`接口时,用户需要再重新定义test_transforms传入给`batch_predict`接口。
......@@ -104,7 +104,6 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2):
> >
> > - **img_file_list** (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **transforms** (paddlex.seg.transforms): 数据预处理操作。
> > - **thread_num** (int): 并发执行各图像预处理时的线程数。
> **返回值**
> >
......
......@@ -70,7 +70,6 @@ cd PaddleX/examples/meter_reader/
| save_dir | 保存可视化结果的路径, 默认值为"output"|
| score_threshold | 检测模型输出结果中,预测得分低于该阈值的框将被滤除,默认值为0.5|
| seg_batch_size | 分割的批量大小,默认为2 |
| seg_thread_num | 分割预测的线程数,默认为cpu处理器个数 |
| use_camera | 是否使用摄像头采集图片,默认为False |
| camera_id | 摄像头设备ID,默认值为0 |
| use_erode | 是否使用图像腐蚀对分割预测图进行细分,默认为False |
......
......@@ -79,7 +79,6 @@ cd PaddleX/examples/meter_reader/
| save_dir | 保存可视化结果的路径, 默认值为"output"|
| score_threshold | 检测模型输出结果中,预测得分低于该阈值的框将被滤除,默认值为0.5|
| seg_batch_size | 分割的批量大小,默认为2 |
| seg_thread_num | 分割预测的线程数,默认为cpu处理器个数 |
| use_camera | 是否使用摄像头采集图片,默认为False |
| camera_id | 摄像头设备ID,默认值为0 |
| use_erode | 是否使用图像腐蚀对分割预测图进行细分,默认为False |
......
......@@ -105,12 +105,6 @@ def parse_args():
help="Segmentation batch size",
type=int,
default=2)
parser.add_argument(
'--seg_thread_num',
dest='seg_thread_num',
help="Thread number of segmentation preprocess",
type=int,
default=2)
return parser.parse_args()
......@@ -143,8 +137,7 @@ class MeterReader:
use_erode=True,
erode_kernel=4,
score_threshold=0.5,
seg_batch_size=2,
seg_thread_num=2):
seg_batch_size=2):
if isinstance(im_file, str):
im = cv2.imread(im_file).astype('float32')
else:
......@@ -190,8 +183,7 @@ class MeterReader:
meter_images.append(resized_meters[j - i])
result = self.segmenter.batch_predict(
transforms=self.seg_transforms,
img_file_list=meter_images,
thread_num=seg_thread_num)
img_file_list=meter_images)
if use_erode:
kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
for i in range(len(result)):
......@@ -334,7 +326,7 @@ def infer(args):
for im_file in image_lists:
meter_reader.predict(im_file, args.save_dir, args.use_erode,
args.erode_kernel, args.score_threshold,
args.seg_batch_size, args.seg_thread_num)
args.seg_batch_size)
elif args.use_camera:
cap_video = cv2.VideoCapture(args.camera_id)
if not cap_video.isOpened():
......@@ -347,7 +339,7 @@ def infer(args):
if ret:
meter_reader.predict(frame, args.save_dir, args.use_erode,
args.erode_kernel, args.score_threshold,
args.seg_batch_size, args.seg_thread_num)
args.seg_batch_size)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
......
......@@ -105,12 +105,6 @@ def parse_args():
help="Segmentation batch size",
type=int,
default=2)
parser.add_argument(
'--seg_thread_num',
dest='seg_thread_num',
help="Thread number of segmentation preprocess",
type=int,
default=2)
return parser.parse_args()
......@@ -143,8 +137,7 @@ class MeterReader:
use_erode=True,
erode_kernel=4,
score_threshold=0.5,
seg_batch_size=2,
seg_thread_num=2):
seg_batch_size=2):
if isinstance(im_file, str):
im = cv2.imread(im_file).astype('float32')
else:
......@@ -190,8 +183,7 @@ class MeterReader:
meter_images.append(resized_meters[j - i])
result = self.segmenter.batch_predict(
transforms=self.seg_transforms,
img_file_list=meter_images,
thread_num=seg_thread_num)
img_file_list=meter_images)
if use_erode:
kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
for i in range(len(result)):
......@@ -334,7 +326,7 @@ def infer(args):
for im_file in image_lists:
meter_reader.predict(im_file, args.save_dir, args.use_erode,
args.erode_kernel, args.score_threshold,
args.seg_batch_size, args.seg_thread_num)
args.seg_batch_size)
elif args.use_camera:
cap_video = cv2.VideoCapture(args.camera_id)
if not cap_video.isOpened():
......@@ -347,7 +339,7 @@ def infer(args):
if ret:
meter_reader.predict(frame, args.save_dir, args.use_erode,
args.erode_kernel, args.score_threshold,
args.seg_batch_size, args.seg_thread_num)
args.seg_batch_size)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
......
......@@ -56,4 +56,4 @@ log_level = 2
from . import interpret
__version__ = '1.1.1'
__version__ = '1.1.4'
......@@ -23,6 +23,7 @@ import yaml
import copy
import json
import functools
import multiprocessing as mp
import paddlex.utils.logging as logging
from paddlex.utils import seconds_to_hms
from paddlex.utils.utils import EarlyStop
......@@ -76,6 +77,16 @@ class BaseAPI:
self.completed_epochs = 0
self.scope = fluid.global_scope()
# 线程池,在模型在预测时用于对输入数据以图片为单位进行并行处理
# 主要用于batch_predict接口
thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
self.thread_pool = mp.pool.ThreadPool(thread_num)
def reset_thread_pool(self, thread_num):
self.thread_pool.close()
self.thread_pool.join()
self.thread_pool = mp.pool.ThreadPool(thread_num)
def _get_single_card_bs(self, batch_size):
if batch_size % len(self.places) == 0:
return int(batch_size // len(self.places))
......@@ -356,23 +367,13 @@ class BaseAPI:
]
test_outputs = list(self.test_outputs.values())
with fluid.scope_guard(self.scope):
if self.__class__.__name__ == 'MaskRCNN':
from paddlex.utils.save import save_mask_inference_model
save_mask_inference_model(
dirname=save_dir,
executor=self.exe,
params_filename='__params__',
feeded_var_names=test_input_names,
target_vars=test_outputs,
main_program=self.test_prog)
else:
fluid.io.save_inference_model(
dirname=save_dir,
executor=self.exe,
params_filename='__params__',
feeded_var_names=test_input_names,
target_vars=test_outputs,
main_program=self.test_prog)
fluid.io.save_inference_model(
dirname=save_dir,
executor=self.exe,
params_filename='__params__',
feeded_var_names=test_input_names,
target_vars=test_outputs,
main_program=self.test_prog)
model_info = self.get_model_info()
model_info['status'] = 'Infer'
......
......@@ -279,16 +279,18 @@ class BaseClassifier(BaseAPI):
return metrics
@staticmethod
def _preprocess(images, transforms, model_type, class_name, thread_num=1):
def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
arrange_transforms(
model_type=model_type,
class_name=class_name,
transforms=transforms,
mode='test')
pool = ThreadPool(thread_num)
batch_data = pool.map(transforms, images)
pool.close()
pool.join()
if thread_pool is not None:
batch_data = thread_pool.map(transforms, images)
else:
batch_data = list()
for image in images:
batch_data.append(transforms(image))
padding_batch = generate_minibatch(batch_data)
im = np.array([data[0] for data in padding_batch])
......@@ -344,15 +346,13 @@ class BaseClassifier(BaseAPI):
def batch_predict(self,
img_file_list,
transforms=None,
topk=1,
thread_num=2):
topk=1):
"""预测。
Args:
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
transforms (paddlex.cls.transforms): 数据预处理操作。
topk (int): 预测时前k个最大值。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测列表中,其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
分别对应预测类别id、预测类别标签、预测得分。
......@@ -367,7 +367,7 @@ class BaseClassifier(BaseAPI):
transforms = self.test_transforms
im = BaseClassifier._preprocess(img_file_list, transforms,
self.model_type,
self.__class__.__name__, thread_num)
self.__class__.__name__, self.thread_pool)
with fluid.scope_guard(self.scope):
result = self.exe.run(self.test_prog,
......
......@@ -443,16 +443,18 @@ class DeepLabv3p(BaseAPI):
return metrics
@staticmethod
def _preprocess(images, transforms, model_type, class_name, thread_num=1):
def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
arrange_transforms(
model_type=model_type,
class_name=class_name,
transforms=transforms,
mode='test')
pool = ThreadPool(thread_num)
batch_data = pool.map(transforms, images)
pool.close()
pool.join()
if thread_pool is not None:
batch_data = thread_pool.map(transforms, images)
else:
batch_data = list()
for image in images:
batch_data.append(transforms(image))
padding_batch = generate_minibatch(batch_data)
im = np.array(
[data[0] for data in padding_batch],
......@@ -517,13 +519,12 @@ class DeepLabv3p(BaseAPI):
preds = DeepLabv3p._postprocess(result, im_info)
return preds[0]
def batch_predict(self, img_file_list, transforms=None, thread_num=2):
def batch_predict(self, img_file_list, transforms=None):
"""预测。
Args:
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
transforms(paddlex.cv.transforms): 数据预处理操作。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
list: 每个元素都为列表,表示各图像的预测结果。各图像的预测结果用字典表示,包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
......@@ -538,7 +539,7 @@ class DeepLabv3p(BaseAPI):
transforms = self.test_transforms
im, im_info = DeepLabv3p._preprocess(
img_file_list, transforms, self.model_type,
self.__class__.__name__, thread_num)
self.__class__.__name__, self.thread_pool)
with fluid.scope_guard(self.scope):
result = self.exe.run(self.test_prog,
......
......@@ -376,16 +376,18 @@ class FasterRCNN(BaseAPI):
return metrics
@staticmethod
def _preprocess(images, transforms, model_type, class_name, thread_num=1):
def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
arrange_transforms(
model_type=model_type,
class_name=class_name,
transforms=transforms,
mode='test')
pool = ThreadPool(thread_num)
batch_data = pool.map(transforms, images)
pool.close()
pool.join()
if thread_pool is not None:
batch_data = thread_pool.map(transforms, images)
else:
batch_data = list()
for image in images:
batch_data.append(transforms(image))
padding_batch = generate_minibatch(batch_data)
im = np.array([data[0] for data in padding_batch])
im_resize_info = np.array([data[1] for data in padding_batch])
......@@ -453,14 +455,13 @@ class FasterRCNN(BaseAPI):
return preds[0]
def batch_predict(self, img_file_list, transforms=None, thread_num=2):
def batch_predict(self, img_file_list, transforms=None):
"""预测。
Args:
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
transforms (paddlex.det.transforms): 数据预处理操作。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
......@@ -477,7 +478,7 @@ class FasterRCNN(BaseAPI):
transforms = self.test_transforms
im, im_resize_info, im_shape = FasterRCNN._preprocess(
img_file_list, transforms, self.model_type,
self.__class__.__name__, thread_num)
self.__class__.__name__, self.thread_pool)
with fluid.scope_guard(self.scope):
result = self.exe.run(self.test_prog,
......
......@@ -408,14 +408,13 @@ class MaskRCNN(FasterRCNN):
return preds[0]
def batch_predict(self, img_file_list, transforms=None, thread_num=2):
def batch_predict(self, img_file_list, transforms=None):
"""预测。
Args:
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
transforms (paddlex.det.transforms): 数据预处理操作。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
dict: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、预测框类别名称、
预测框坐标(坐标格式为[xmin, ymin, w, h])、
......@@ -432,7 +431,7 @@ class MaskRCNN(FasterRCNN):
transforms = self.test_transforms
im, im_resize_info, im_shape = FasterRCNN._preprocess(
img_file_list, transforms, self.model_type,
self.__class__.__name__, thread_num)
self.__class__.__name__, self.thread_pool)
with fluid.scope_guard(self.scope):
result = self.exe.run(self.test_prog,
......
......@@ -447,16 +447,18 @@ class PPYOLO(BaseAPI):
return evaluate_metrics
@staticmethod
def _preprocess(images, transforms, model_type, class_name, thread_num=1):
def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
arrange_transforms(
model_type=model_type,
class_name=class_name,
transforms=transforms,
mode='test')
pool = ThreadPool(thread_num)
batch_data = pool.map(transforms, images)
pool.close()
pool.join()
if thread_pool is not None:
batch_data = thread_pool.map(transforms, images)
else:
batch_data = list()
for image in images:
batch_data.append(transforms(image))
padding_batch = generate_minibatch(batch_data)
im = np.array(
[data[0] for data in padding_batch],
......@@ -520,14 +522,13 @@ class PPYOLO(BaseAPI):
len(images), self.num_classes, self.labels)
return preds[0]
def batch_predict(self, img_file_list, transforms=None, thread_num=2):
def batch_predict(self, img_file_list, transforms=None):
"""预测。
Args:
img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
且类型为float32且为BGR格式的数组。
transforms (paddlex.det.transforms): 数据预处理操作。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
......@@ -543,7 +544,7 @@ class PPYOLO(BaseAPI):
transforms = self.test_transforms
im, im_size = PPYOLO._preprocess(img_file_list, transforms,
self.model_type,
self.__class__.__name__, thread_num)
self.__class__.__name__, self.thread_pool)
with fluid.scope_guard(self.scope):
result = self.exe.run(self.test_prog,
......
......@@ -16,6 +16,7 @@ import os.path as osp
import cv2
import numpy as np
import yaml
import multiprocessing as mp
import paddlex
import paddle.fluid as fluid
from paddlex.cv.transforms import build_transforms
......@@ -79,6 +80,15 @@ class Predictor:
self.predictor = self.create_predictor(use_gpu, gpu_id, use_mkl,
mkl_thread_num, use_trt,
use_glog, memory_optimize)
# 线程池,在模型在预测时用于对输入数据以图片为单位进行并行处理
# 主要用于batch_predict接口
thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
self.thread_pool = mp.pool.ThreadPool(thread_num)
def reset_thread_pool(self, thread_num):
self.thread_pool.close()
self.thread_pool.join()
self.thread_pool = mp.pool.ThreadPool(thread_num)
def create_predictor(self,
use_gpu=True,
......@@ -114,7 +124,7 @@ class Predictor:
predictor = fluid.core.create_paddle_predictor(config)
return predictor
def preprocess(self, image, thread_num=1):
def preprocess(self, image, thread_pool=None):
""" 对图像做预处理
Args:
......@@ -128,7 +138,7 @@ class Predictor:
self.transforms,
self.model_type,
self.model_name,
thread_num=thread_num)
thread_pool=thread_pool)
res['image'] = im
elif self.model_type == "detector":
if self.model_name in ["PPYOLO", "YOLOv3"]:
......@@ -137,7 +147,7 @@ class Predictor:
self.transforms,
self.model_type,
self.model_name,
thread_num=thread_num)
thread_pool=thread_pool)
res['image'] = im
res['im_size'] = im_size
if self.model_name.count('RCNN') > 0:
......@@ -146,7 +156,7 @@ class Predictor:
self.transforms,
self.model_type,
self.model_name,
thread_num=thread_num)
thread_pool=thread_pool)
res['image'] = im
res['im_info'] = im_resize_info
res['im_shape'] = im_shape
......@@ -156,7 +166,7 @@ class Predictor:
self.transforms,
self.model_type,
self.model_name,
thread_num=thread_num)
thread_pool=thread_pool)
res['image'] = im
res['im_info'] = im_info
return res
......@@ -253,17 +263,16 @@ class Predictor:
return results[0]
def batch_predict(self, image_list, topk=1, thread_num=2):
def batch_predict(self, image_list, topk=1):
""" 图片预测
Args:
image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
thread_num (int): 并发执行各图像预处理时的线程数。
topk(int): 分类预测时使用,表示预测前topk的结果
"""
preprocessed_input = self.preprocess(image_list)
preprocessed_input = self.preprocess(image_list, self.thread_pool)
model_pred = self.raw_predict(preprocessed_input)
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
'im_shape']
......
......@@ -19,7 +19,7 @@ long_description = "PaddlePaddle Entire Process Development Toolkit"
setuptools.setup(
name="paddlex",
version='1.1.1',
version='1.1.4',
author="paddlex",
author_email="paddlex@baidu.com",
description=long_description,
......@@ -30,7 +30,7 @@ setuptools.setup(
setup_requires=['cython', 'numpy'],
install_requires=[
"pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm',
'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.6.2',
'paddleslim==1.0.1', 'visualdl>=2.0.0b', 'paddlehub>=1.8.2',
'shapely>=1.7.0', "opencv-python"
],
classifiers=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册