未验证 提交 c322e6fb 编写于 作者: B Bo Zhou 提交者: GitHub

importting submodule in xparl (#373)

* fix #370

* add unit tests for importting submodule in xparl

* add comment

* revert scripts.py

* yapf

* unit test

* fix relative path bug

* read files using relative path

* remove print

* update CMake

* add try times

* fix the bug in log_server_test.py

* fix relative path bug&add unit test

* update utils.py

* update Cmake

* fix pyc bug

* Update utils.py
上级 28ea49fc
...@@ -30,10 +30,20 @@ function(py_test TARGET_NAME) ...@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS) set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME} if (${FILE_NAME} MATCHES ".*abs_test.py")
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS} add_test(NAME ${TARGET_NAME}"_with_abs_path"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 300)
else()
get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY)
get_filename_component(FILE_NAME ${py_test_SRCS} NAME)
get_filename_component(COMBINED_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${WORKING_DIR} ABSOLUTE)
add_test(NAME ${TARGET_NAME}
COMMAND python -u ${FILE_NAME} ${py_test_ARGS}
WORKING_DIRECTORY ${COMBINED_PATH})
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
endif()
endfunction() endfunction()
function(import_test TARGET_NAME) function(import_test TARGET_NAME)
......
...@@ -50,7 +50,6 @@ class Client(object): ...@@ -50,7 +50,6 @@ class Client(object):
distributed_files (list): A list of files to be distributed at all distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration remote instances(e,g. the configuration
file for initialization) . file for initialization) .
""" """
self.master_address = master_address self.master_address = master_address
self.process_id = process_id self.process_id = process_id
...@@ -95,34 +94,39 @@ class Client(object): ...@@ -95,34 +94,39 @@ class Client(object):
pyfiles['python_files'] = {} pyfiles['python_files'] = {}
pyfiles['other_files'] = {} pyfiles['other_files'] = {}
code_files = filter(lambda x: x.endswith('.py'), os.listdir('./')) main_file = sys.argv[0]
main_folder = './'
try: sep = os.sep
for file in code_files: if sep in main_file:
assert os.path.exists(file) main_folder = sep.join(main_file.split(sep)[:-1])
with open(file, 'rb') as code_file: code_files = filter(lambda x: x.endswith('.py'),
code = code_file.read() os.listdir(main_folder))
pyfiles['python_files'][file] = code
for file_name in code_files:
for file in distributed_files: file_path = os.path.join(main_folder, file_name)
assert os.path.exists(file) assert os.path.exists(file_path)
assert not os.path.isabs( with open(file_path, 'rb') as code_file:
file
), "[XPARL] Please do not distribute a file with absolute path."
with open(file, 'rb') as f:
content = f.read()
pyfiles['other_files'][file] = content
# append entry file to code list
main_file = sys.argv[0]
with open(main_file, 'rb') as code_file:
code = code_file.read() code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code pyfiles['python_files'][file_name] = code
except AssertionError as e: # append entry file to code list
raise Exception( assert os.path.isfile(
'Failed to create the client, the file {} does not exist.'. main_file
format(file)) ), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format(
main_file, os.getcwd())
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
for file_name in distributed_files:
assert os.path.exists(file_name)
assert not os.path.isabs(
file_name
), "[XPARL] Please do not distribute a file with absolute path."
with open(file_name, 'rb') as f:
content = f.read()
pyfiles['other_files'][file_name] = content
return cloudpickle.dumps(pyfiles) return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address): def _create_sockets(self, master_address):
......
...@@ -311,8 +311,6 @@ class Job(object): ...@@ -311,8 +311,6 @@ class Job(object):
try: try:
file_name, class_name, end_of_file = cloudpickle.loads( file_name, class_name, end_of_file = cloudpickle.loads(
message[1]) message[1])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file) cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2]) args, kwargs = cloudpickle.loads(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log') logfile_path = os.path.join(self.log_dir, 'stdout.log')
......
...@@ -19,6 +19,7 @@ import time ...@@ -19,6 +19,7 @@ import time
import zmq import zmq
import numpy as np import numpy as np
import inspect import inspect
import sys
from parl.utils import get_ip_address, logger, to_str, to_byte from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\ from parl.utils.communication import loads_argument, loads_return,\
...@@ -27,6 +28,7 @@ from parl.remote import remote_constants ...@@ -27,6 +28,7 @@ from parl.remote import remote_constants
from parl.remote.exceptions import RemoteError, RemoteAttributeError,\ from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
RemoteDeserializeError, RemoteSerializeError, ResourceError RemoteDeserializeError, RemoteSerializeError, ResourceError
from parl.remote.client import get_global_client from parl.remote.client import get_global_client
from parl.remote.utils import locate_remote_file
def remote_class(*args, **kwargs): def remote_class(*args, **kwargs):
...@@ -120,13 +122,22 @@ def remote_class(*args, **kwargs): ...@@ -120,13 +122,22 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False self.job_shutdown = False
self.send_file(self.job_socket) self.send_file(self.job_socket)
file_name = inspect.getfile(cls)[:-3] module_path = inspect.getfile(cls)
if module_path.endswith('pyc'):
module_path = module_path[:-4]
elif module_path.endswith('py'):
module_path = module_path[:-3]
else:
raise FileNotFoundError(
"cannot not find the module:{}".format(module_path))
res = inspect.getfile(cls)
file_path = locate_remote_file(module_path)
cls_source = inspect.getsourcelines(cls) cls_source = inspect.getsourcelines(cls)
end_of_file = cls_source[1] + len(cls_source[0]) end_of_file = cls_source[1] + len(cls_source[0])
class_name = cls.__name__ class_name = cls.__name__
self.job_socket.send_multipart([ self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG, remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps([file_name, class_name, end_of_file]), cloudpickle.dumps([file_path, class_name, end_of_file]),
cloudpickle.dumps([args, kwargs]), cloudpickle.dumps([args, kwargs]),
]) ])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
......
...@@ -24,6 +24,7 @@ import time ...@@ -24,6 +24,7 @@ import time
import unittest import unittest
import requests import requests
requests.adapters.DEFAULT_RETRIES = 5
import parl import parl
from parl.remote.client import disconnect, get_global_client from parl.remote.client import disconnect, get_global_client
...@@ -128,7 +129,6 @@ class TestLogServer(unittest.TestCase): ...@@ -128,7 +129,6 @@ class TestLogServer(unittest.TestCase):
monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py') monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py')
monitor_file = monitor_file.replace('log_server_test.py', monitor_file = monitor_file.replace('log_server_test.py',
'../monitor.py') '../monitor.py')
command = [ command = [
sys.executable, monitor_file, "--monitor_port", sys.executable, monitor_file, "--monitor_port",
str(monitor_port), "--address", "localhost:" + str(master_port) str(monitor_port), "--address", "localhost:" + str(master_port)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class B(object):
def add_sum(self, a, b):
return a + b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
def tearDown(self):
disconnect()
def test_import_local_module(self):
from Module2 import B
port = 8448
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect("localhost:8448")
obj = B()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
def tearDown(self):
disconnect()
def test_import_local_module(self):
from Module2 import B
port = 8442
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect("localhost:8442")
obj = B()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
def test_import_subdir_module_0(self):
from subdir import Module
port = 8443
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect(
"localhost:8443",
distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
obj = Module.A()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
def test_import_subdir_module_1(self):
from subdir.Module import A
port = 8444
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
time.sleep(10)
parl.connect(
"localhost:8444",
distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
obj = A()
res = obj.add_sum(10, 5)
self.assertEqual(res, 15)
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class A(object):
def add_sum(self, a, b):
return a + b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -13,8 +13,11 @@ ...@@ -13,8 +13,11 @@
# limitations under the License. # limitations under the License.
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import os
__all__ = ['load_remote_class', 'redirect_stdout_to_file'] __all__ = [
'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file'
]
def simplify_code(code, end_of_file): def simplify_code(code, end_of_file):
...@@ -32,7 +35,7 @@ def simplify_code(code, end_of_file): ...@@ -32,7 +35,7 @@ def simplify_code(code, end_of_file):
def data_process(): def data_process():
XXXX XXXX
------------------> ------------------>
The last two lines of the above code block will be removed as they are not class related. The last two lines of the above code block will be removed as they are not class-related.
""" """
to_write_lines = [] to_write_lines = []
for i, line in enumerate(code): for i, line in enumerate(code):
...@@ -60,12 +63,18 @@ def load_remote_class(file_name, class_name, end_of_file): ...@@ -60,12 +63,18 @@ def load_remote_class(file_name, class_name, end_of_file):
with open(file_name + '.py') as t_file: with open(file_name + '.py') as t_file:
code = t_file.readlines() code = t_file.readlines()
code = simplify_code(code, end_of_file) code = simplify_code(code, end_of_file)
module_name = 'xparl_' + file_name #folder/xx.py -> folder/xparl_xx.py
tmp_file_name = 'xparl_' + file_name + '.py' file_name = file_name.split(os.sep)
prefix = os.sep.join(file_name[:-1])
if prefix == "":
prefix = '.'
module_name = prefix + os.sep + 'xparl_' + file_name[-1]
tmp_file_name = module_name + '.py'
with open(tmp_file_name, 'w') as t_file: with open(tmp_file_name, 'w') as t_file:
for line in code: for line in code:
t_file.write(line) t_file.write(line)
mod = __import__(module_name) module_name = module_name.lstrip('.' + os.sep).replace(os.sep, '.')
mod = __import__(module_name, globals(), locals(), [class_name], 0)
cls = getattr(mod, class_name) cls = getattr(mod, class_name)
return cls return cls
...@@ -74,6 +83,9 @@ def load_remote_class(file_name, class_name, end_of_file): ...@@ -74,6 +83,9 @@ def load_remote_class(file_name, class_name, end_of_file):
def redirect_stdout_to_file(file_path): def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file. """Redirect stdout (e.g., `print`) to specified file.
Args:
file_path: Path of the file to output the stdout.
Example: Example:
>>> print('test') >>> print('test')
test test
...@@ -81,10 +93,6 @@ def redirect_stdout_to_file(file_path): ...@@ -81,10 +93,6 @@ def redirect_stdout_to_file(file_path):
... print('test') # Output nothing, `test` is printed to `test.log`. ... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test') >>> print('test')
test test
Args:
file_path: Path of the file to output the stdout.
""" """
tmp = sys.stdout tmp = sys.stdout
f = open(file_path, 'a') f = open(file_path, 'a')
...@@ -94,3 +102,33 @@ def redirect_stdout_to_file(file_path): ...@@ -94,3 +102,33 @@ def redirect_stdout_to_file(file_path):
finally: finally:
sys.stdout = tmp sys.stdout = tmp
f.close() f.close()
def locate_remote_file(module_path):
"""xparl has to locate the file that has the class decorated by parl.remote_class.
This function returns the relative path between this file and the entry file.
Args:
module_path: Absolute path of the module.
Example:
module_path: /home/user/dir/subdir/my_module
entry_file: /home/user/dir/main.py
--------> relative_path: subdir/my_module
"""
entry_file = sys.argv[0]
entry_file = entry_file.split(os.sep)[-1]
entry_path = None
for path in sys.path:
to_check_path = os.path.join(path, entry_file)
if os.path.isfile(to_check_path):
entry_path = path
break
if entry_path is None or \
(module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]):
raise FileNotFoundError("cannot locate the remote file")
if module_path.startswith(os.sep):
relative_module_path = '.' + module_path[len(entry_path):]
else:
relative_module_path = module_path
return relative_module_path
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册