提交 6acb2dd4 编写于 作者: W wuzewu

Fix module compat bug

上级 b33e4d14
......@@ -13,9 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
__version__ = '2.0.0a0'
from paddlehub.utils import log, parser, utils
from paddlehub.module import Module
# To stay compatible with older versions, the compatibility code lives in the
# paddlehub/compat package, and some module paths referenced by older versions
# are mapped to their current counterparts below.
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.processor import BaseProcessor
from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule
from paddlehub.compat.type import DataType
# Register the old-version module paths in sys.modules as aliases of the
# modules imported above (parser, log, paddle_utils, utils).
sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddlehub.module.desc;
enum DataType {
// Value kinds referenced by ModuleAttr.type and by the values of KVData.key_type.
// NONE is the zero value, i.e. the proto3 default when the field is unset.
NONE = 0;
INT = 1;
FLOAT = 2;
STRING = 3;
BOOLEAN = 4;
LIST = 5;
MAP = 6;
SET = 7;
OBJECT = 8;
}
message KVData {
// Container used by ModuleAttr for its map, list, set and object fields.
// Maps a string key to a DataType.
// NOTE(review): presumably the type of the matching key in `data` - confirm.
map<string, DataType> key_type = 1;
// Maps a string key to the attribute stored under it.
map<string, ModuleAttr> data = 2;
}
message ModuleAttr {
// A typed attribute value. `type` names the kind of value; the remaining
// fields provide one slot per kind.
// NOTE(review): presumably only the slot matching `type` is populated - confirm.
// Basic type
DataType type = 1;
int64 i = 2;
double f = 3;
bool b = 4;
string s = 5;
// Composite values, each stored as a KVData.
KVData map = 6;
KVData list = 7;
KVData set = 8;
KVData object = 9;
// Name and info strings attached to the attribute (usage not visible here).
string name = 10;
string info = 11;
}
// Feed Variable Description
message FeedDesc {
// Describes one feed variable by name, plus an alias string.
// NOTE(review): presumably `alias` is the user-facing name - confirm.
string var_name = 1;
string alias = 2;
};
// Fetch Variable Description
message FetchDesc {
// Describes one fetch variable by name, plus an alias string.
// NOTE(review): presumably `alias` is the user-facing name - confirm.
string var_name = 1;
string alias = 2;
};
// Module Variable
message ModuleVar {
// Groups the fetch and feed variable descriptions; used as the value type
// of ModuleDesc.sign2var.
repeated FetchDesc fetch_desc = 1;
repeated FeedDesc feed_desc = 2;
}
// A Hub Module is stored in a directory with a file 'module_desc.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
// Field number 1 is not used by this message.
// signature to module variable
map<string, ModuleVar> sign2var = 2;
// Attribute attached to the module as a whole.
ModuleAttr attr = 3;
};
......@@ -47,6 +47,10 @@ class ModuleV1(object):
self._generate_func()
def _load_processor(self):
# Some modules do not have a processor (e.g. ernie)
if not 'processor_info' in self.desc:
return
python_path = os.path.join(self.directory, 'python')
processor_name = self.desc.processor_info
self.processor = utils.load_py_module(python_path, processor_name)
......
......@@ -13,8 +13,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Callable, List
import paddle
from paddlehub.utils.utils import Version
# Maps Paddle VarType members to the dtype name used in variable-info dicts.
# Each key appears once; the original listed INT16 twice with the same value.
dtype_map = {
    paddle.device.core.VarDesc.VarType.FP32: "float32",
    paddle.device.core.VarDesc.VarType.FP64: "float64",
    paddle.device.core.VarDesc.VarType.FP16: "float16",
    paddle.device.core.VarDesc.VarType.INT32: "int32",
    paddle.device.core.VarDesc.VarType.INT16: "int16",
    paddle.device.core.VarDesc.VarType.INT64: "int64",
    paddle.device.core.VarDesc.VarType.BOOL: "bool",
    paddle.device.core.VarDesc.VarType.UINT8: "uint8",
    paddle.device.core.VarDesc.VarType.INT8: "int8",
}
def convert_dtype_to_string(dtype: paddle.device.core.VarDesc.VarType) -> str:
    '''
    Translate a Paddle VarType member into its dtype name.

    Args:
        dtype: A key of ``dtype_map``.

    Returns:
        str: The dtype name stored in ``dtype_map`` for ``dtype``, e.g. "float32".

    Raises:
        TypeError: If ``dtype`` is not a key of ``dtype_map``.
    '''
    if dtype in dtype_map:
        return dtype_map[dtype]
    raise TypeError("dtype should be in %s" % list(dtype_map.keys()))
def get_variable_info(var: paddle.Variable) -> dict:
'''
Collect attributes of ``var`` into a plain dict.

_copy_vars_and_ops_in_blocks passes the returned dict as keyword arguments to
``create_parameter`` / ``create_var``.

Args:
var(paddle.Variable): The variable to describe.

Returns:
dict: The collected attributes, keyed by attribute name.

Raises:
TypeError: If ``var`` is not a paddle.Variable.
'''
if not isinstance(var, paddle.Variable):
raise TypeError("var shoule be an instance of paddle.Variable")
# Attributes read unconditionally.
var_info = {
'name': var.name,
'stop_gradient': var.stop_gradient,
'is_data': var.is_data,
'error_clip': var.error_clip,
'type': var.type
}
# dtype, lod_level and shape are added on a best-effort basis.
# NOTE(review): the bare except swallows every exception, including
# KeyboardInterrupt; `except Exception` would be narrower.
try:
var_info['dtype'] = convert_dtype_to_string(var.dtype)
var_info['lod_level'] = var.lod_level
var_info['shape'] = var.shape
except:
pass
# Parameter-specific attributes. gradient_clip_attr is read only when the
# installed paddle version is lower than 1.8.
# NOTE(review): indentation is missing in this view, so it is not visible
# whether do_model_average sits inside the version check, nor which `if` the
# `else` that sets 'persistable' belongs to - confirm against the real file.
if isinstance(var, paddle.device.framework.Parameter):
var_info['trainable'] = var.trainable
var_info['optimize_attr'] = var.optimize_attr
var_info['regularizer'] = var.regularizer
if Version(paddle.__version__) < '1.8':
var_info['gradient_clip_attr'] = var.gradient_clip_attr
var_info['do_model_average'] = var.do_model_average
else:
var_info['persistable'] = var.persistable
return var_info
def remove_feed_fetch_op(program: paddle.static.Program):
'''Remove feed and fetch operator and variable for fine-tuning.'''
......@@ -39,3 +95,103 @@ def remove_feed_fetch_op(program: paddle.static.Program):
block._remove_var(var)
program.desc.flush()
def rename_var(block: paddle.device.framework.Block, old_name: str, new_name: str):
    '''
    Rename a variable of ``block`` and patch the operators that reference it.

    Args:
        block: Block whose operators and variable table are updated.
        old_name(str): Current name of the variable.
        new_name(str): Name the variable is given.
    '''
    for operator in block.ops:
        # Patch every input argument that refers to the old name.
        for arg_name in operator.input_arg_names:
            if arg_name != old_name:
                continue
            operator._rename_input(old_name, new_name)

        # Same for the output arguments.
        for arg_name in operator.output_arg_names:
            if arg_name != old_name:
                continue
            operator._rename_output(old_name, new_name)

    # Finally rename the variable itself in the block.
    block._rename_var(old_name, new_name)
def add_vars_prefix(program: paddle.static.Program,
                    prefix: str,
                    vars: List[str] = None,
                    excludes: List[str] = None):
    '''
    Prepend ``prefix`` to variable names in the global block of ``program``.

    Args:
        program(paddle.static.Program): Program whose global block is modified in place.
        prefix(str): Text placed in front of each selected name.
        vars(list[str]): Names to rename. When empty or None, every variable
            name of the global block is selected.
        excludes(list[str]): Names to leave untouched. Ignored when empty or None.
    '''
    block = program.global_block()

    # Snapshot the names first: renaming changes block.vars while we work.
    names = list(vars) if vars else list(block.vars.keys())
    if excludes:
        names = [name for name in names if name not in excludes]

    for name in names:
        rename_var(block, name, prefix + name)
def remove_vars_prefix(program: paddle.static.Program,
                       prefix: str,
                       vars: List[str] = None,
                       excludes: List[str] = None):
    '''
    Strip a leading ``prefix`` from variable names in the global block of ``program``.

    Only names that start with ``prefix`` are renamed, and only the first
    occurrence of ``prefix`` in each name is removed.

    Args:
        program(paddle.static.Program): Program whose global block is modified in place.
        prefix(str): Text to remove from the front of each selected name.
        vars(list[str]): Candidate names. When empty or None, every variable
            name of the global block is a candidate.
        excludes(list[str]): Names to leave untouched. Ignored when empty or None.
    '''
    block = program.global_block()

    # Materialise the selection before renaming, since renaming changes block.vars.
    candidates = vars if vars else block.vars.keys()
    names = [name for name in candidates if name.startswith(prefix)]
    if excludes:
        names = [name for name in names if name not in excludes]

    for name in names:
        rename_var(block, name, name.replace(prefix, '', 1))
def clone_program(origin_program: paddle.static.Program, for_test: bool = False) -> paddle.static.Program:
    '''
    Build a copy of ``origin_program``.

    The global block is copied into a new Program, which is then cloned with
    the given ``for_test`` flag. When ``for_test`` is False, the
    ``stop_gradient`` flag of every variable in the origin's global block is
    written onto the variable of the same name in the copy.

    Args:
        origin_program(paddle.static.Program): Program to copy.
        for_test(bool): Forwarded to ``Program.clone``.

    Returns:
        paddle.static.Program: The copied program.
    '''
    source_block = origin_program.global_block()

    dest_program = paddle.static.Program()
    _copy_vars_and_ops_in_blocks(source_block, dest_program.global_block())
    dest_program = dest_program.clone(for_test=for_test)

    if for_test:
        return dest_program

    for name, var in source_block.vars.items():
        dest_program.global_block().vars[name].stop_gradient = var.stop_gradient
    return dest_program
def _copy_vars_and_ops_in_blocks(from_block: paddle.device.framework.Block, to_block: paddle.device.framework.Block):
'''
Recreate the variables and operators of ``from_block`` inside ``to_block``.

Operators that carry a 'sub_block' attribute are handled recursively: a new
block is created on the destination program and the sub-block is copied into it.
'''
# Variables first: each one is rebuilt from a deep copy of its info dict,
# parameters through create_parameter and everything else through create_var.
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, paddle.device.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
# Then the operators, in their original order.
for op in from_block.ops:
all_attrs = op.all_attrs()
if 'sub_block' in all_attrs:
# Create a block on the destination program, copy the sub-block into it,
# then call _rollback().
# NOTE(review): presumably _rollback() restores the program's current block - confirm.
_sub_block = to_block.program._create_block()
_copy_vars_and_ops_in_blocks(all_attrs['sub_block'], _sub_block)
to_block.program._rollback()
# The new block replaces the 'sub_block' attribute; all other attributes are deep-copied.
new_attrs = {'sub_block': _sub_block}
for key, value in all_attrs.items():
if key == 'sub_block':
continue
new_attrs[key] = copy.deepcopy(value)
else:
new_attrs = copy.deepcopy(all_attrs)
# Inputs and outputs are given by name on the source op and resolved to
# variables through to_block._find_var_recursive.
op_info = {
'type': op.type,
'inputs':
{input: [to_block._find_var_recursive(var) for var in op.input(input)]
for input in op.input_names},
'outputs':
{output: [to_block._find_var_recursive(var) for var in op.output(output)]
for output in op.output_names},
'attrs': new_attrs
}
to_block.append_op(**op_info)
def set_op_attr(program: paddle.static.Program, is_test: bool = False):
    '''
    Set the 'is_test' attribute on every operator of ``program`` that has one.

    Args:
        program(paddle.static.Program): Program whose blocks are walked.
        is_test(bool): Value written to the attribute.
    '''
    for block in program.blocks:
        for operator in block.ops:
            # Operators without the attribute are left untouched.
            if operator.has_attr('is_test'):
                operator._set_attr('is_test', is_test)
......@@ -43,14 +43,14 @@ class HubModuleNotFoundError(Exception):
class LocalModuleManager(object):
"""
'''
LocalModuleManager is used to manage PaddleHub's local Module, which supports the installation, uninstallation,
and search of HubModule. LocalModuleManager is a singleton object related to the path, in other words, when the
LocalModuleManager object of the same home directory is generated multiple times, the same object is returned.
Args:
home (str): The directory where PaddleHub modules are stored, the default is ~/.paddlehub/modules
"""
'''
_instance_map = {}
def __new__(cls, home: str = MODULE_HOME):
......
......@@ -17,9 +17,9 @@ import inspect
import importlib
import os
import sys
from typing import Callable, List, Optional, Generic
from typing import Callable, Generic, List, Optional
from paddlehub.utils import utils
from paddlehub.utils import log, utils
from paddlehub.compat.module.module_v1 import ModuleV1
......@@ -58,9 +58,10 @@ def serving(func: Callable) -> Callable:
class Module(object):
'''
'''
def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx')
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name:
module = cls.init_with_name(name=name, version=version, **kwargs)
elif directory:
......@@ -72,19 +73,19 @@ class Module(object):
@classmethod
def load(cls, directory: str) -> Generic:
'''
'''
if directory.endswith(os.sep):
directory = directory[:-1]
# if module description file existed, try to load as ModuleV1
# If module description file existed, try to load as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
return ModuleV1.load(desc_file)
basename = os.path.split(directory)[-1]
dirname = os.path.join(*list(os.path.split(directory)[:-1]))
sys.path.insert(0, dirname)
py_module = importlib.import_module('{}.module'.format(basename))
py_module = utils.load_py_module(dirname, '{}.module'.format(basename))
for _item, _cls in inspect.getmembers(py_module, inspect.isclass):
_item = py_module.__dict__[_item]
......@@ -93,13 +94,14 @@ class Module(object):
break
else:
raise InvalidHubModule(directory)
sys.path.pop(0)
user_module_cls.directory = directory
return user_module_cls
@classmethod
def init_with_name(cls, name: str, version: str = None, **kwargs):
'''
'''
from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager()
user_module_cls = manager.search(name)
......@@ -107,15 +109,39 @@ class Module(object):
user_module_cls = manager.install(name, version)
directory = manager._get_normalized_path(name)
# The HubModule in the old version will use the _initialize method to initialize,
# this function will be obsolete in a future version
if hasattr(user_module_cls, '_initialize'):
log.logger.warning(
'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
)
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
return user_module_cls(directory=directory, **kwargs)
@classmethod
def init_with_directory(cls, directory: str, **kwargs):
    '''
    Load the module class stored in ``directory`` and return an instance of it.

    Args:
        directory(str): Directory that holds the module.
        **kwargs: Initialization arguments. They go to ``_initialize`` for
            old-style modules and to the constructor otherwise.

    Returns:
        The instantiated module.
    '''
    user_module_cls = cls.load(directory)

    # The HubModule in the old version will use the _initialize method to initialize,
    # this function will be obsolete in a future version.
    # A stale unconditional `return user_module_cls(**kwargs)` used to sit here
    # and made this compatibility path unreachable.
    if hasattr(user_module_cls, '_initialize'):
        log.logger.warning(
            'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
        )
        user_module = user_module_cls(directory=directory)
        user_module._initialize(**kwargs)
        return user_module

    return user_module_cls(directory=directory, **kwargs)
@classmethod
def get_py_requirements(cls):
'''
'''
req_file = os.path.join(cls.directory, 'requirements.txt')
if not os.path.exists(req_file):
return []
......@@ -125,6 +151,9 @@ class Module(object):
class RunModule(object):
'''
'''
def __init__(self, *args, **kwargs):
# Avoid module being initialized multiple times
if '_is_initialize' in self.__dict__ and self._is_initialize:
......@@ -149,6 +178,8 @@ class RunModule(object):
@classmethod
def get_py_requirements(cls) -> List[str]:
'''
'''
py_module = sys.modules[cls.__module__]
directory = os.path.dirname(py_module.__file__)
req_file = os.path.join(directory, 'requirements.txt')
......@@ -172,6 +203,9 @@ def moduleinfo(name: str,
summary: str = None,
type: str = None,
meta=None) -> Callable:
'''
'''
def _wrapper(cls: Generic) -> Generic:
wrap_cls = cls
_meta = RunModule if not meta else meta
......
......@@ -170,8 +170,8 @@ class FormattedText(object):
self.width = width
def __repr__(self) -> str:
    '''
    Return the text padded/aligned to ``self.width``, wrapped in colour codes if a colour is set.

    Returns:
        str: ``self.text`` formatted with the spec ``{:<align><width>}``. When
        ``self.color`` is truthy, the result is prefixed with it and followed
        by ``Fore.RESET``.
    '''
    # Build the format spec once; the doubled braces produce literal '{' and '}'.
    # The earlier duplicate computation of form/text was dead code and is dropped.
    form = '{{:{}{}}}'.format(self.align, self.width)
    text = form.format(self.text)
    if not self.color:
        return text
    return self.color + text + Fore.RESET
......
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import sys
from typing import List
import yaml
from paddlehub.utils.utils import sys_stdin_encoding
class CSVFileParser(object):
    '''
    Minimal comma-separated file reader that returns the data column by column.

    Quoting and escaping are not interpreted: every line is split on each ','.
    '''

    def parse(self, csv_file: str) -> dict:
        '''
        Read ``csv_file`` and group its values by column title.

        The first line supplies the column titles; empty lines are skipped.
        The titles are kept in ``self.title`` and the result in ``self.content``.

        Args:
            csv_file(str): Path of the file, decoded with the encoding
                returned by sys_stdin_encoding().

        Returns:
            dict: Maps each title to the list of values of that column, in file order.

        Raises:
            IndexError: If a row has more fields than there are titles.
        '''
        with codecs.open(csv_file, 'r', sys_stdin_encoding()) as file:
            raw = file.read()

        # codecs.open does no newline translation, so files with CRLF line
        # endings would otherwise keep a trailing carriage return on the last
        # field of every row and turn blank lines into data.
        rows = [row.rstrip('\r') for row in raw.split('\n')]

        self.title = rows[0].split(',')
        self.content = {key: [] for key in self.title}

        for row in rows[1:]:
            if row == '':
                continue
            for index, item in enumerate(row.split(',')):
                self.content[self.title[index]].append(item)

        return self.content
class YAMLFileParser(object):
    '''Reads a YAML file from disk and parses it with yaml.BaseLoader.'''

    def parse(self, yaml_file: str) -> dict:
        '''
        Parse ``yaml_file`` and return the loaded document.

        Args:
            yaml_file(str): Path of the file, decoded with the encoding
                returned by sys_stdin_encoding().

        Returns:
            dict: The object produced by ``yaml.load`` with ``Loader=yaml.BaseLoader``.
        '''
        with codecs.open(yaml_file, 'r', sys_stdin_encoding()) as stream:
            document_text = stream.read()
        parsed = yaml.load(document_text, Loader=yaml.BaseLoader)
        return parsed
class TextFileParser(object):
    '''Reads a text file into a list of lines, trying UTF-8 first and GBK second.'''

    def parse(self, txt_file: str, use_strip: bool = True) -> List:
        '''
        Return the non-empty lines of ``txt_file``.

        The file is decoded as UTF-8. If that fails with a decoding error, the
        whole file is read again as GBK, and lines gathered before the failure
        are discarded so nothing is returned twice.

        Args:
            txt_file(str): Path of the file to read.
            use_strip(bool): Strip surrounding whitespace from each line before
                testing whether it is empty. When False, lines keep their line ending.

        Returns:
            list: The retained lines, in file order.

        Raises:
            UnicodeDecodeError: If the file is valid neither as UTF-8 nor as GBK.
        '''
        try:
            return self._read_lines(txt_file, 'utf8', use_strip)
        except UnicodeDecodeError:
            # Only a decoding failure justifies a second attempt; other errors
            # (e.g. a missing file) propagate straight away.
            return self._read_lines(txt_file, 'gbk', use_strip)

    @staticmethod
    def _read_lines(txt_file: str, encoding: str, use_strip: bool) -> List:
        '''Read ``txt_file`` with ``encoding`` into a fresh list of non-empty lines.'''
        contents = []
        with codecs.open(txt_file, 'r', encoding=encoding) as file:
            for line in file:
                if use_strip:
                    line = line.strip()
                if line:
                    contents.append(line)
        return contents
# Module-level parser instances, one per supported file format.
csv_parser = CSVFileParser()
yaml_parser = YAMLFileParser()
txt_parser = TextFileParser()
......@@ -31,10 +31,12 @@ from urllib.parse import urlparse
import packaging.version
import paddlehub.env as hubenv
import paddlehub.utils as utils
class Version(packaging.version.Version):
'''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool:
'''
Determine whether the given condition are met
......@@ -76,9 +78,35 @@ class Version(packaging.version.Version):
return _comp(Version(version))
def __lt__(self, other):
    '''Less-than comparison; a str operand is parsed into a Version first.'''
    rhs = Version(other) if isinstance(other, str) else other
    return super().__lt__(rhs)
def __le__(self, other):
    '''Less-than-or-equal comparison; a str operand is parsed into a Version first.'''
    rhs = Version(other) if isinstance(other, str) else other
    return super().__le__(rhs)
def __gt__(self, other):
    '''Greater-than comparison; a str operand is parsed into a Version first.'''
    rhs = Version(other) if isinstance(other, str) else other
    return super().__gt__(rhs)
def __ge__(self, other):
    '''Greater-than-or-equal comparison; a str operand is parsed into a Version first.'''
    rhs = Version(other) if isinstance(other, str) else other
    return super().__ge__(rhs)
def __eq__(self, other):
    '''Equality comparison; a str operand is parsed into a Version first.'''
    # NOTE(review): a class that defines __eq__ without __hash__ gets
    # __hash__ = None, so instances of this subclass are unhashable - confirm
    # that this is intended.
    rhs = Version(other) if isinstance(other, str) else other
    return super().__eq__(rhs)
class Timer(object):
'''Calculate running speed and estimated time of arrival (ETA)'''
def __init__(self, total_step: int):
self.total_step = total_step
self.last_start_step = 0
......@@ -217,3 +245,35 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
sys.path.pop(0)
return py_module
def get_platform_default_encoding() -> str:
    '''
    Return the fallback text encoding for the current platform.

    Returns:
        str: 'gbk' when ``utils.platform.is_windows()`` is true, otherwise 'utf8'.
    '''
    return 'gbk' if utils.platform.is_windows() else 'utf8'
def sys_stdin_encoding() -> str:
    '''
    Return the encoding to use for text that comes from standard input.

    Falls back to ``sys.getdefaultencoding()`` and then to the platform
    default when the stream does not report an encoding.

    Returns:
        str: The name of the encoding.
    '''
    # sys.stdin can be None (e.g. a process started without a console) or a
    # replacement object without an ``encoding`` attribute.
    encoding = getattr(sys.stdin, 'encoding', None)
    if encoding is None:
        encoding = sys.getdefaultencoding()
    if encoding is None:
        encoding = get_platform_default_encoding()
    return encoding
def sys_stdout_encoding() -> str:
    '''
    Return the encoding to use for text that is written to standard output.

    Falls back to ``sys.getdefaultencoding()`` and then to the platform
    default when the stream does not report an encoding.

    Returns:
        str: The name of the encoding.
    '''
    # sys.stdout can be None (e.g. a process started without a console) or a
    # replacement object without an ``encoding`` attribute.
    encoding = getattr(sys.stdout, 'encoding', None)
    if encoding is None:
        encoding = sys.getdefaultencoding()
    if encoding is None:
        encoding = get_platform_default_encoding()
    return encoding
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册