提交 9a1eac7b 编写于 作者: W wuzewu

Add module install tips

上级 613375e5
...@@ -28,20 +28,83 @@ from paddlehub.utils import xarfile, log, utils, pypi ...@@ -28,20 +28,83 @@ from paddlehub.utils import xarfile, log, utils, pypi
class HubModuleNotFoundError(Exception): class HubModuleNotFoundError(Exception):
def __init__(self, name, version=None, source=None): def __init__(self, name: str, info: dict = None, version: str = None, source: str = None):
self.name = name self.name = name
self.version = version self.version = version
self.info = info
self.source = source self.source = source
def __str__(self): def __str__(self):
msg = '{}'.format(self.name) msg = '{}'.format(self.name)
if self.version: if self.version:
msg += '-{}'.format(self.version) msg += '-{}'.format(self.version)
if self.source: if self.source:
msg += ' from {}'.format(self.source) msg += ' from {}'.format(self.source)
tips = 'No HubModule named {} was found'.format(msg) tips = 'No HubModule named {} was found'.format(log.FormattedText(text=msg, color='red'))
if self.info:
sort_infos = sorted(self.info.items(), key=lambda x: utils.Version(x[0]))
table = log.Table()
table.append(
*['Name', 'Version', 'PaddlePaddle Version Required', 'PaddleHub Version Required'],
widths=[15, 10, 35, 35],
aligns=['^', '^', '^', '^'],
colors=['cyan', 'cyan', 'cyan', 'cyan'])
for _ver, info in sort_infos:
paddle_version = 'Any' if not info['paddle_version'] else ', '.join(info['paddle_version'])
hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
return tips
class EnvironmentMismatchError(Exception):
def __init__(self, name: str, info: dict, version: str = None):
self.name = name
self.version = version
self.info = info
def __str__(self):
msg = '{}'.format(self.name)
if self.version:
msg += '-{}'.format(self.version)
tips = '{} cannot be installed because some conditions are not met'.format(
log.FormattedText(text=msg, color='red'))
if self.info:
sort_infos = sorted(self.info.items(), key=lambda x: utils.Version(x[0]))
table = log.Table()
table.append(
*['Name', 'Version', 'PaddlePaddle Version Required', 'PaddleHub Version Required'],
widths=[15, 10, 35, 35],
aligns=['^', '^', '^', '^'],
colors=['cyan', 'cyan', 'cyan', 'cyan'])
import paddle
import paddlehub
for _ver, info in sort_infos:
paddle_version = 'Any' if not info['paddle_version'] else ', '.join(info['paddle_version'])
for version in info['paddle_version']:
if not utils.Version(paddle.__version__).match(version):
paddle_version = '{}(Mismatch)'.format(paddle_version)
break
hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
for version in info['hub_version']:
if not utils.Version(paddlehub.__version__).match(version):
hub_version = '{}(Mismatch)'.format(hub_version)
break
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
return tips return tips
...@@ -176,7 +239,23 @@ class LocalModuleManager(object): ...@@ -176,7 +239,23 @@ class LocalModuleManager(object):
result = module_server.search_module(name=name, version=version, source=source) result = module_server.search_module(name=name, version=version, source=source)
if not result: if not result:
raise HubModuleNotFoundError(name, version, source) module_infos = module_server.get_module_info(name=name, source=source)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source)
valid_infos = {}
if version:
for _ver, _info in module_infos.items():
if utils.Version(_ver).match(version):
valid_infos[_ver] = _info
else:
valid_infos = list(module_infos.keys())
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
if source or 'source' in result: if source or 'source' in result:
return self._install_from_source(result) return self._install_from_source(result)
......
...@@ -82,9 +82,9 @@ class GitSource(object): ...@@ -82,9 +82,9 @@ class GitSource(object):
name(str) : PaddleHub module name name(str) : PaddleHub module name
version(str) : PaddleHub module version version(str) : PaddleHub module version
''' '''
return self.search_resouce(type='module', name=name, version=version) return self.search_resource(type='module', name=name, version=version)
def search_resouce(self, type: str, name: str, version: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None) -> dict:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
......
...@@ -22,6 +22,7 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' ...@@ -22,6 +22,7 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
class HubServer(object): class HubServer(object):
'''PaddleHub server''' '''PaddleHub server'''
def __init__(self): def __init__(self):
self.sources = OrderedDict() self.sources = OrderedDict()
...@@ -51,9 +52,9 @@ class HubServer(object): ...@@ -51,9 +52,9 @@ class HubServer(object):
name(str) : PaddleHub module name name(str) : PaddleHub module name
version(str) : PaddleHub module version version(str) : PaddleHub module version
''' '''
return self.search_resouce(type='module', name=name, version=version, source=source) return self.search_resource(type='module', name=name, version=version, source=source)
def search_resouce(self, type: str, name: str, version: str = None, source: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -64,10 +65,20 @@ class HubServer(object): ...@@ -64,10 +65,20 @@ class HubServer(object):
''' '''
sources = self.sources.values() if not source else [self._generate_source(source)] sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources: for source in sources:
result = source.search_resouce(name=name, type=type, version=version) result = source.search_resource(name=name, type=type, version=version)
if result:
return result
return {}
def get_module_info(self, name: str, source: str = None) -> dict:
'''
'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
result = source.get_module_info(name=name)
if result: if result:
return result return result
return None return {}
module_server = HubServer() module_server = HubServer()
......
...@@ -14,12 +14,11 @@ ...@@ -14,12 +14,11 @@
# limitations under the License. # limitations under the License.
import json import json
import platform
import requests import requests
import sys from typing import List
import paddlehub import paddlehub
from paddlehub.utils import utils from paddlehub.utils import utils, platform
class ServerConnectionError(Exception): class ServerConnectionError(Exception):
...@@ -52,9 +51,9 @@ class ServerSource(object): ...@@ -52,9 +51,9 @@ class ServerSource(object):
name(str) : PaddleHub module name name(str) : PaddleHub module name
version(str) : PaddleHub module version version(str) : PaddleHub module version
''' '''
return self.search_resouce(type='module', name=name, version=version) return self.search_resource(type='module', name=name, version=version)
def search_resouce(self, type: str, name: str, version: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None) -> dict:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -63,36 +62,64 @@ class ServerSource(object): ...@@ -63,36 +62,64 @@ class ServerSource(object):
name(str) : Resource name name(str) : Resource name
version(str) : Resource version version(str) : Resource version
''' '''
payload = {'environments': {}} params = {'environments': platform.get_platform_info()}
payload['word'] = name params['word'] = name
payload['type'] = type params['type'] = type
if version: if version:
payload['version'] = version params['version'] = version
# Delay module loading to improve command line speed # Delay module loading to improve command line speed
import paddle import paddle
payload['environments']['hub_version'] = paddlehub.__version__ params['hub_version'] = paddlehub.__version__
payload['environments']['paddle_version'] = paddle.__version__ params['paddle_version'] = paddle.__version__
payload['environments']['python_version'] = '.'.join(map(str, sys.version_info[0:3]))
payload['environments']['platform_version'] = platform.version()
payload['environments']['platform_system'] = platform.system()
payload['environments']['platform_architecture'] = platform.architecture()
payload['environments']['platform_type'] = platform.platform()
api = '{}/search'.format(self._url) result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return None
def get_module_info(self, name: str) -> dict:
'''
'''
def _convert_version(version: str) -> List:
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
if version[0] != '-1.0.0':
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
params = {'name': name}
result = self.request(path='info', params=params)
if result['status'] == 0 and len(result['data']) > 0:
infos = {}
for _info in result['data']['info']:
infos[_info['version']] = {
'url': _info['url'],
'paddle_version': _convert_version(_info['paddle_version']),
'hub_version': _convert_version(_info['hub_version'])
}
return infos
return {}
def request(self, path: str, params: dict) -> dict:
'''
'''
api = '{}/{}'.format(self._url, path)
try: try:
result = requests.get(api, payload, timeout=self._timeout) result = requests.get(api, params, timeout=self._timeout)
result = result.json() return result.json()
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
else:
print(result)
return None
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url) raise ServerConnectionError(self._url)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册