提交 8733b5a1 编写于 作者: W wuzewu

Add module v1

上级 ce342854
......@@ -15,4 +15,6 @@
__version__ = '2.0.0a0'
from .module import Module
from paddlehub.module import Module
from paddlehub.compat.module.processor import BaseProcessor
此差异已折叠。
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
from typing import Tuple, List
import paddle
from paddlehub.compat import paddle_utils
from paddlehub.compat.module import module_v1_utils
from paddlehub.utils import utils, log
class ModuleV1(object):
    '''
    Compatibility wrapper for a module stored in the legacy (v1) on-disk layout.

    The following entries are read from the module directory: ``module_desc.pb``
    (module description), ``model`` (inference program), ``python`` (processor code)
    and ``assets`` (extra files).
    '''

    def __init__(self, name: str = None, directory: str = None, version: str = None):
        '''
        Args:
            name(str) : Accepted for signature compatibility, not used by this method
            directory(str) : Root directory of the module. If empty, nothing is loaded
            version(str) : Accepted for signature compatibility, not used by this method
        '''
        if not directory:
            return

        self.directory = directory
        desc_file = os.path.join(directory, 'module_desc.pb')
        self.desc = module_v1_utils.convert_module_desc(desc_file)

        # The program has to be loaded first: _load_parameters works on self.program.
        self._load_model()
        self._load_parameters()
        self._load_processor()
        self._load_assets()
        self._load_extra_info()
        self._load_signatures()

    def _load_processor(self):
        '''Import the python module named by `desc.processor_info` from `<directory>/python`.'''
        python_path = os.path.join(self.directory, 'python')
        processor_name = self.desc.processor_info
        self.processor = utils.load_py_module(python_path, processor_name)

    def _load_assets(self):
        '''Record the path of every entry found in `<directory>/assets` in `self.assets`.'''
        assets_path = os.path.join(self.directory, 'assets')
        self.assets = [os.path.join(assets_path, entry) for entry in os.listdir(assets_path)]

    def _load_parameters(self):
        '''
        Re-create, as parameters of the global block, the variables listed in
        `desc.param_attrs`. Entries whose prefixed name is not a variable of the
        global block are skipped.
        '''
        global_block = self.program.global_block()
        for param, attrs in self.desc.param_attrs.items():
            name = self.desc.name_prefix + param
            if name not in global_block.vars:
                continue

            var = global_block.vars[name]
            # Copy the properties of the existing variable and add the stored attributes.
            global_block.create_parameter(name=name,
                                          shape=var.shape,
                                          dtype=var.dtype,
                                          type=var.type,
                                          lod_level=var.lod_level,
                                          error_clip=var.error_clip,
                                          stop_gradient=var.stop_gradient,
                                          is_data=var.is_data,
                                          **attrs)

    def _load_extra_info(self):
        '''Expose every item of `desc.extra_info` as an instance attribute named `get_<key>`.'''
        for key, value in self.desc.extra_info.items():
            self.__dict__['get_{}'.format(key)] = value

    def _load_signatures(self):
        '''Expose every signature of the module as a callable instance attribute.'''
        for signature in self.desc.signatures:
            # Bind the signature positionally. Binding it by keyword makes
            # `module.some_signature(data)` fail, because the positional `data` would be
            # assigned to the `signature` parameter a second time.
            self.__dict__[signature] = functools.partial(self.__call__, signature)

    def _load_model(self):
        '''Load the inference program stored in `<directory>/model` into `self.program`.'''
        model_path = os.path.join(self.directory, 'model')
        exe = paddle.static.Executor(paddle.CPUPlace())
        self.program, _, _ = paddle.io.load_inference_model(model_path, executor=exe)

        # Clear the callstack since it may leak the privacy of the creator.
        for block in self.program.blocks:
            for op in block.ops:
                if 'op_callstack' in op.all_attrs():
                    op._set_attr('op_callstack', [''])

    def context(self, for_test: bool = False, trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]:
        '''
        Clone the loaded program and collect its feed and fetch variables.

        Args:
            for_test(bool) : Forwarded to `Program.clone`
            trainable(bool) : Value assigned to the `trainable` attribute of every parameter of the clone

        Returns:
            A tuple of (feed_dict, fetch_dict, program). Both dicts map the alias declared in the
            signatures to the matching variable of the cloned program.
        '''
        program = self.program.clone(for_test=for_test)
        paddle_utils.remove_feed_fetch_op(program)
        global_block = program.global_block()

        # generate feed vars and fetch vars from signatures
        feed_dict = {}
        fetch_dict = {}
        for info in self.desc.signatures.values():
            for feed_var in info.feed_vars:
                feed_dict[feed_var.alias] = global_block.vars[feed_var.name]

            for fetch_var in info.fetch_vars:
                fetch_dict[fetch_var.alias] = global_block.vars[fetch_var.name]

        # record num parameters loaded by PaddleHub
        num_param_loaded = 0
        for param in program.all_parameters():
            num_param_loaded += 1
            param.trainable = trainable

        log.logger.info('{} pretrained paramaters loaded by PaddleHub'.format(num_param_loaded))

        return feed_dict, fetch_dict, program

    def __call__(self, signature, data, use_gpu: bool = False, batch_size: int = 1, **kwargs):
        '''
        Entry point used by the per-signature attributes created in `_load_signatures`.

        The body is a placeholder in this revision: the call does nothing and returns None.
        '''
        ...

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''Return the python requirements of the module, which is always an empty list here.'''
        return []

    @classmethod
    def load(cls, desc_file):
        '''
        Build a module class carrying the information stored in a description file.

        Args:
            desc_file(str) : Path of the module description file

        Returns:
            A subclass of `cls` whose class attributes `author`, `author_email`, `summary`,
            `type`, `name` and `version` are filled from the description.
        '''
        desc = module_v1_utils.convert_module_desc(desc_file)

        # Attach the information to a dedicated subclass instead of `cls` itself. Otherwise
        # every loaded module shares one class object and each call overwrites the
        # information written by the previous one.
        module_cls = type(cls.__name__, (cls, ), {})
        module_cls.author = desc.module_info.author
        module_cls.author_email = desc.module_info.author_email
        module_cls.summary = desc.module_info.summary
        module_cls.type = desc.module_info.type
        module_cls.name = desc.module_info.name
        module_cls.version = utils.Version(desc.module_info.version)
        return module_cls
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from easydict import EasyDict
from paddlehub.compat.module import module_desc_pb2
def convert_module_desc(desc_file):
    '''
    Parse a serialized module description file and convert its content.

    Args:
        desc_file(str) : Path of the file holding the serialized `ModuleDesc` message

    Returns:
        The converted `attr` field of the message, with the converted `sign2var`
        field attached to it as `signatures`.
    '''
    module_desc = module_desc_pb2.ModuleDesc()
    with open(desc_file, 'rb') as stream:
        serialized = stream.read()
    module_desc.ParseFromString(serialized)

    converted = convert_attr(module_desc.attr)
    converted.signatures = convert_signatures(module_desc.sign2var)
    return converted
def convert_signatures(signmaps):
    '''
    Convert the signature map of a module description into nested EasyDicts.

    Args:
        signmaps : Mapping from signature name to an entry exposing `fetch_desc` and
            `feed_desc`, each an iterable of items with `var_name` and `alias`

    Returns:
        An EasyDict mapping every signature name to an EasyDict with two lists,
        `fetch_vars` and `feed_vars`. Every list item is an EasyDict with the keys
        `name` and `alias`. Both lists are always present and are empty when the
        signature declares no variable of that kind.
    '''
    _dict = EasyDict()
    for sign, var in signmaps.items():
        # Build each list completely before storing it. Re-creating the list inside the
        # per-variable loop would drop every variable except the last one.
        entry = EasyDict()
        entry.fetch_vars = [
            EasyDict(name=fetch_var.var_name, alias=fetch_var.alias) for fetch_var in var.fetch_desc
        ]
        entry.feed_vars = [EasyDict(name=feed_var.var_name, alias=feed_var.alias) for feed_var in var.feed_desc]
        _dict[sign] = entry

    return _dict
def convert_attr(module_attr):
    '''
    Recursively convert a module attribute into a plain value or an EasyDict.

    Args:
        module_attr : Object exposing `type` and, depending on it, one of the fields
            `i`, `f`, `s`, `b` or `map.data`

    Returns:
        The field `i`, `f`, `s` or `b` when `type` equals 1, 2, 3 or 4 respectively.
        For any other type, an EasyDict holding the converted items of `map.data`.
    '''
    # Type codes that map directly to a single field of the attribute.
    scalar_fields = ((1, 'i'), (2, 'f'), (3, 's'), (4, 'b'))
    for type_code, field_name in scalar_fields:
        if module_attr.type == type_code:
            return getattr(module_attr, field_name)

    converted = EasyDict()
    for key, val in module_attr.map.data.items():
        converted[key] = convert_attr(val)
    return converted
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
class BaseProcessor(object):
    '''
    Base class of module processors.

    `preprocess`, `postprocess` and `data_format` raise NotImplementedError here,
    only `configs` has a default result.
    '''

    def __init__(self, module):
        '''
        Args:
            module : Accepted but not used by the base class
        '''

    def configs(self) -> List:
        '''Return the configs of the processor, an empty list for the base class.'''
        return []

    def preprocess(self, signature: str, data: dict):
        '''Always raises NotImplementedError in the base class.'''
        raise NotImplementedError("BaseProcessor' preprocess should not be called!")

    def postprocess(self, signature: str, data_out: dict, data_info: dict, **kwargs):
        '''Always raises NotImplementedError in the base class.'''
        raise NotImplementedError("BaseProcessor' postprocess should not be called!")

    def data_format(self, signature: str):
        '''Always raises NotImplementedError in the base class.'''
        raise NotImplementedError("BaseProcessor' data_format should not be called!")
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def remove_feed_fetch_op(program: paddle.static.Program):
    '''
    Strip the feed and fetch operators, and the variables whose name ends with
    "feed" or "fetch", from the global block of the program (used for fine-tuning).

    Args:
        program(paddle.static.Program) : Program to modify in place
    '''
    global_block = program.global_block()

    # Collect the indices first, then delete from the highest index down.
    feed_fetch_indices = [idx for idx, op in enumerate(global_block.ops) if op.type in ('feed', 'fetch')]
    for idx in reversed(feed_fetch_indices):
        global_block._remove_op(idx)

    # Collect the names before removing anything, so the variable collection is not
    # modified while it is being iterated.
    stale_var_names = [name for name in global_block.vars if name.endswith(('feed', 'fetch'))]
    for name in stale_var_names:
        global_block._remove_var(name)

    program.desc.flush()
......@@ -203,7 +203,7 @@ class LocalModuleManager(object):
shutil.copytree(directory, os.path.join(self.home, hub_module_cls.name))
self._local_modules[hub_module_cls.name] = hub_module_cls
for py_req in hub_module_cls.get_py_requirments():
for py_req in hub_module_cls.get_py_requirements():
log.logger.info('Installing dependent packages: {}'.format(py_req))
result = pypi.install(py_req)
if result:
......@@ -221,5 +221,5 @@ class LocalModuleManager(object):
for path, ds, ts in xarfile.unarchive_with_progress(archive, _tdir):
bar.update(float(ds) / ts)
path = os.path.split(path)[0]
path = path.split(os.sep)[0]
return self._install_from_directory(os.path.join(_tdir, path))
......@@ -20,6 +20,7 @@ import sys
from typing import Callable, List, Optional, Generic
from paddlehub.utils import utils
from paddlehub.compat.module.module_v1 import ModuleV1
class InvalidHubModule(Exception):
......@@ -55,6 +56,8 @@ def serving(func: Callable) -> Callable:
class Module(object):
'''
'''
def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx')
......@@ -65,7 +68,6 @@ class Module(object):
else:
module = object.__new__(cls)
module.directory = directory
return module
@classmethod
......@@ -73,6 +75,11 @@ class Module(object):
if directory.endswith(os.sep):
directory = directory[:-1]
# if module description file existed, try to load as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
return ModuleV1.load(desc_file)
basename = os.path.split(directory)[-1]
dirname = os.path.join(*list(os.path.split(directory)[:-1]))
......@@ -99,7 +106,8 @@ class Module(object):
if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name, version)
return user_module_cls(**kwargs)
directory = manager._get_normalized_path(name)
return user_module_cls(directory=directory, **kwargs)
@classmethod
def init_with_directory(cls, directory: str, **kwargs):
......@@ -107,7 +115,7 @@ class Module(object):
return user_module_cls(**kwargs)
@classmethod
def get_py_requirments(cls):
def get_py_requirements(cls):
req_file = os.path.join(cls.directory, 'requirements.txt')
if not os.path.exists(req_file):
return []
......
......@@ -16,12 +16,14 @@
import base64
import contextlib
import cv2
import importlib
import math
import os
import requests
import sys
import time
import tempfile
import types
import numpy as np
from typing import Generator
from urllib.parse import urlparse
......@@ -33,7 +35,6 @@ import paddlehub.env as hubenv
class Version(packaging.version.Version):
'''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool:
'''
Determine whether the given condition are met
......@@ -78,7 +79,6 @@ class Version(packaging.version.Version):
class Timer(object):
'''Calculate runing speed and estimated time of arrival(ETA)'''
def __init__(self, total_step: int):
self.total_step = total_step
self.last_start_step = 0
......@@ -202,3 +202,18 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in
_file.write(data)
download_size += len(data)
yield savename, download_size, total_size
def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
    '''
    Load the specified python module.

    The directory is only added to `sys.path` for the duration of the import and is
    removed again afterwards, even when the import fails.

    Args:
        python_path(str) : The directory where the python module is located
        py_module_name(str) : Module name to be loaded

    Returns:
        The imported module object.

    Raises:
        ImportError : If the module cannot be imported. Any other exception raised
            while the module body is executed is propagated unchanged.
    '''
    sys.path.insert(0, python_path)
    try:
        return importlib.import_module(py_module_name)
    finally:
        # Remove the entry by value instead of popping index 0: the imported module may
        # itself have prepended entries to sys.path.
        try:
            sys.path.remove(python_path)
        except ValueError:
            # The imported code already removed the entry, nothing left to clean up.
            pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册