From 80476dde1cf186e69fdf990ee76e1a8aec354482 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 23 Sep 2020 21:04:38 +0800 Subject: [PATCH] Add run command --- paddlehub/commands/install.py | 5 +- paddlehub/commands/run.py | 84 ++++++++++++++++++++++++++++ paddlehub/compat/module/module_v1.py | 5 ++ paddlehub/module/manager.py | 6 +- paddlehub/module/module.py | 8 +++ 5 files changed, 104 insertions(+), 4 deletions(-) diff --git a/paddlehub/commands/install.py b/paddlehub/commands/install.py index 1aeb9c51..592e5253 100644 --- a/paddlehub/commands/install.py +++ b/paddlehub/commands/install.py @@ -36,5 +36,8 @@ class InstallCommand: elif os.path.exists(_arg) and xarfile.is_xarfile(_arg): manager.install(archive=_arg) else: - manager.install(name=_arg) + _arg = _arg.split('==') + name = _arg[0] + version = None if len(_arg) == 1 else _arg[1] + manager.install(name=name, version=version) return True diff --git a/paddlehub/commands/run.py b/paddlehub/commands/run.py index e69de29b..c3ce409b 100644 --- a/paddlehub/commands/run.py +++ b/paddlehub/commands/run.py @@ -0,0 +1,84 @@ +# coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import ast +import os +from typing import Any, List + +from paddlehub.compat.module.module_v1 import ModuleV1 +from paddlehub.commands import register +from paddlehub.module.manager import LocalModuleManager +from paddlehub.module.module import Module, InvalidHubModule + + +@register(name='hub.run', description='Run the specific module.') +class RunCommand: + def execute(self, argv: List) -> bool: + if not argv: + print('ERROR: You must give one module to run.') + return False + module_name = argv[0] + + if os.path.exists(module_name) and os.path.isdir(module_name): + try: + module = Module.load(module_name) + except InvalidHubModule: + print('{} is not a valid HubModule'.format(module_name)) + return False + except: + print('Some exception occurred while loading the {}'.format(module_name)) + return False + else: + module = Module(name=module_name) + + if not module.is_runnable: + print('ERROR! Module {} is not executable.'.format(module_name)) + return False + + if isinstance(module, ModuleV1): + result = self.run_module_v1(module, argv[1:]) + else: + result = module._run_func(argv[1:]) + + print(result) + return True + + def run_module_v1(self, module, argv: List) -> Any: + parser = argparse.ArgumentParser(prog='hub run {}'.format(module.name), add_help=False) + + arg_input_group = parser.add_argument_group(title='Input options', description='Data feed into the module.') + arg_config_group = parser.add_argument_group( + title='Config options', description='Run configuration for controlling module behavior, optional.') + + arg_config_group.add_argument( + '--use_gpu', type=ast.literal_eval, default=False, help='whether use GPU for prediction') + arg_config_group.add_argument('--batch_size', type=int, default=1, help='batch size for prediction') + + module_type = module.type.lower() + if module_type.startswith('cv'): + arg_input_group.add_argument( + '--input_path', type=str, default=None, help='path of image/video to predict', required=True) + else: + arg_input_group.add_argument('--input_text', type=str, default=None, help='text to predict', required=True) + + args = parser.parse_args(argv) + + except_data_format = module.processor.data_format(module.default_signature) + key = list(except_data_format.keys())[0] + input_data = {key: [args.input_path] if module_type.startswith('cv') else [args.input_text]} + + return module( + sign_name=module.default_signature, data=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size) diff --git a/paddlehub/compat/module/module_v1.py b/paddlehub/compat/module/module_v1.py index cbde15cd..ea2f555a 100644 --- a/paddlehub/compat/module/module_v1.py +++ b/paddlehub/compat/module/module_v1.py @@ -37,6 +37,7 @@ class ModuleV1(object): self.desc = module_v1_utils.convert_module_desc(desc_file) self.helper = self self.signatures = self.desc.signatures + self.default_signature = self.desc.default_signature self.directory = directory self._load_model() @@ -196,3 +197,7 @@ class ModuleV1(object): def assets_path(self): return os.path.join(self.directory, 'assets') + + @property + def is_runnable(self): + return self.default_signature != None diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index 83bb3bc5..b1130440 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -58,7 +58,7 @@ class HubModuleNotFoundError(Exception): hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version']) table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^']) - tips += ', \n{}'.format(table) + tips += ':\n{}'.format(table) return tips @@ -104,7 +104,7 @@ class EnvironmentMismatchError(Exception): table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^']) - tips += ', \n{}'.format(table) + tips += ':\n{}'.format(table) return tips @@ -250,7 +250,7 @@ class LocalModuleManager(object): if utils.Version(_ver).match(version): valid_infos[_ver] = _info else: - valid_infos = list(module_infos.keys()) + valid_infos = module_infos.copy() # Cannot find a HubModule that meets the version if valid_infos: diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index c4176006..f52c5179 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -148,6 +148,10 @@ class Module(object): user_module = user_module_cls(directory=directory) user_module._initialize(**kwargs) return user_module + + if user_module_cls == ModuleV1: + return user_module_cls(directory=directory, **kwargs) + user_module_cls.directory = directory return user_module_cls(**kwargs) @@ -166,6 +170,10 @@ class Module(object): user_module = user_module_cls(directory=directory) user_module._initialize(**kwargs) return user_module + + if user_module_cls == ModuleV1: + return user_module_cls(directory=directory, **kwargs) + user_module_cls.directory = directory return user_module_cls(**kwargs) -- GitLab