提交 8917fb23 编写于 作者: W wuzewu

Add search command

上级 80476dde
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import List
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.server.server import module_server
from paddlehub.utils import log, platform
@register(name='hub.search', description='Search PaddleHub pretrained model through model keywords.')
class SearchCommand:
def execute(self, argv: List) -> bool:
argv = '.*' if not argv else argv[0]
widths = [20, 8, 30] if platform.is_windows() else [30, 8, 40]
table = log.Table(widths=widths)
table.append(*['ModuleName', 'Version', 'Summary'], aligns=['^', '^', '^'], colors=["blue", "blue", "blue"])
results = module_server.search_module(name=argv)
for result in results:
table.append(result['name'], result['version'], result['summary'])
print(table)
return True
...@@ -238,28 +238,29 @@ class LocalModuleManager(object): ...@@ -238,28 +238,29 @@ class LocalModuleManager(object):
return self._local_modules[name] return self._local_modules[name]
result = module_server.search_module(name=name, version=version, source=source) result = module_server.search_module(name=name, version=version, source=source)
if not result: for item in result:
module_infos = module_server.get_module_info(name=name, source=source) if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
# The HubModule with the specified name cannot be found if source or 'source' in item:
if not module_infos: return self._install_from_source(result)
raise HubModuleNotFoundError(name=name, version=version, source=source) return self._install_from_url(item['url'])
valid_infos = {} module_infos = module_server.get_module_info(name=name, source=source)
if version: # The HubModule with the specified name cannot be found
for _ver, _info in module_infos.items(): if not module_infos:
if utils.Version(_ver).match(version): raise HubModuleNotFoundError(name=name, version=version, source=source)
valid_infos[_ver] = _info
else: valid_infos = {}
valid_infos = module_infos.copy() if version:
for _ver, _info in module_infos.items():
# Cannot find a HubModule that meets the version if utils.Version(_ver).match(version):
if valid_infos: valid_infos[_ver] = _info
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) else:
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source) valid_infos = module_infos.copy()
if source or 'source' in result: # Cannot find a HubModule that meets the version
return self._install_from_source(result) if valid_infos:
return self._install_from_url(result['url']) raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
def _install_from_source(self, source: str) -> HubModule: def _install_from_source(self, source: str) -> HubModule:
'''Install a HubModule from Git Repo''' '''Install a HubModule from Git Repo'''
......
...@@ -18,6 +18,7 @@ import importlib ...@@ -18,6 +18,7 @@ import importlib
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import List
from urllib.parse import urlparse from urllib.parse import urlparse
import git import git
...@@ -74,7 +75,7 @@ class GitSource(object): ...@@ -74,7 +75,7 @@ class GitSource(object):
log.logger.warning('An error occurred while loading {}'.format(self.path)) log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path) sys.path.remove(self.path)
def search_module(self, name: str, version: str = None) -> dict: def search_module(self, name: str, version: str = None) -> List[dict]:
''' '''
Search PaddleHub module Search PaddleHub module
...@@ -84,7 +85,7 @@ class GitSource(object): ...@@ -84,7 +85,7 @@ class GitSource(object):
''' '''
return self.search_resource(type='module', name=name, version=version) return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -95,13 +96,13 @@ class GitSource(object): ...@@ -95,13 +96,13 @@ class GitSource(object):
''' '''
module = self.hub_modules.get(name, None) module = self.hub_modules.get(name, None)
if module and module.version.match(version): if module and module.version.match(version):
return { return [{
'version': module.version, 'version': module.version,
'name': module.name, 'name': module.name,
'path': self.path, 'path': self.path,
'class': module.__name__, 'class': module.__name__,
'source': self.url 'source': self.url
} }]
return None return None
@classmethod @classmethod
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from typing import List
from paddlehub.server import ServerSource, GitSource from paddlehub.server import ServerSource, GitSource
...@@ -44,7 +45,7 @@ class HubServer(object): ...@@ -44,7 +45,7 @@ class HubServer(object):
'''Remove a module source''' '''Remove a module source'''
self.sources.pop(key) self.sources.pop(key)
def search_module(self, name: str, version: str = None, source: str = None) -> dict: def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]:
''' '''
Search PaddleHub module Search PaddleHub module
...@@ -54,7 +55,7 @@ class HubServer(object): ...@@ -54,7 +55,7 @@ class HubServer(object):
''' '''
return self.search_resource(type='module', name=name, version=version, source=source) return self.search_resource(type='module', name=name, version=version, source=source)
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -68,7 +69,7 @@ class HubServer(object): ...@@ -68,7 +69,7 @@ class HubServer(object):
result = source.search_resource(name=name, type=type, version=version) result = source.search_resource(name=name, type=type, version=version)
if result: if result:
return result return result
return {} return []
def get_module_info(self, name: str, source: str = None) -> dict: def get_module_info(self, name: str, source: str = None) -> dict:
''' '''
......
...@@ -43,7 +43,7 @@ class ServerSource(object): ...@@ -43,7 +43,7 @@ class ServerSource(object):
self._url = url self._url = url
self._timeout = timeout self._timeout = timeout
def search_module(self, name: str, version: str = None) -> dict: def search_module(self, name: str, version: str = None) -> List[dict]:
''' '''
Search PaddleHub module Search PaddleHub module
...@@ -53,7 +53,7 @@ class ServerSource(object): ...@@ -53,7 +53,7 @@ class ServerSource(object):
''' '''
return self.search_resource(type='module', name=name, version=version) return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict: def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -76,9 +76,7 @@ class ServerSource(object): ...@@ -76,9 +76,7 @@ class ServerSource(object):
result = self.request(path='search', params=params) result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0: if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']: return result['data']
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return None return None
def get_module_info(self, name: str) -> dict: def get_module_info(self, name: str) -> dict:
......
...@@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None): ...@@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None):
if os.path.exists(file): if os.path.exists(file):
return return
resource = module_server.search_resouce(name=name, version=version, type='Model') resources = module_server.search_resouce(name=name, version=version, type='Model')
if not resource: if not resources:
raise ResourceNotFoundError(name, version)
for item in resources:
if item['name'] == name and utils.Version(item['version']).match(version):
url = item['url']
break
else:
raise ResourceNotFoundError(name, version) raise ResourceNotFoundError(name, version)
url = resource['url']
with utils.generate_tempdir() as _dir: with utils.generate_tempdir() as _dir:
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册