diff --git a/CMakeLists.txt b/CMakeLists.txt index a77bd684aecc364bf0053e36724fcf0fe880d2f0..435e27f2e0a3ed24964a639236a66de1f7a69f75 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,10 +30,20 @@ function(py_test TARGET_NAME) set(oneValueArgs "") set(multiValueArgs SRCS DEPS ARGS ENVS) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - add_test(NAME ${TARGET_NAME} - COMMAND python -u ${py_test_SRCS} ${py_test_ARGS} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300) + if (${FILE_NAME} MATCHES ".*abs_test.py") + add_test(NAME ${TARGET_NAME}"_with_abs_path" + COMMAND python -u ${py_test_SRCS} ${py_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 300) + else() + get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY) + get_filename_component(FILE_NAME ${py_test_SRCS} NAME) + get_filename_component(COMBINED_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${WORKING_DIR} ABSOLUTE) + add_test(NAME ${TARGET_NAME} + COMMAND python -u ${FILE_NAME} ${py_test_ARGS} + WORKING_DIRECTORY ${COMBINED_PATH}) + set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300) + endif() endfunction() function(import_test TARGET_NAME) diff --git a/parl/remote/client.py b/parl/remote/client.py index 379459c5768914a012cf89182724f1233cbf1329..11865f8ff2afe9368768c701a0c00d232a54c4e1 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -50,7 +50,6 @@ class Client(object): distributed_files (list): A list of files to be distributed at all remote instances(e,g. the configuration file for initialization) . - """ self.master_address = master_address self.process_id = process_id @@ -95,34 +94,39 @@ class Client(object): pyfiles['python_files'] = {} pyfiles['other_files'] = {} - code_files = filter(lambda x: x.endswith('.py'), os.listdir('./')) - - try: - for file in code_files: - assert os.path.exists(file) - with open(file, 'rb') as code_file: - code = code_file.read() - pyfiles['python_files'][file] = code - - for file in distributed_files: - assert os.path.exists(file) - assert not os.path.isabs( - file - ), "[XPARL] Please do not distribute a file with absolute path." - with open(file, 'rb') as f: - content = f.read() - pyfiles['other_files'][file] = content - # append entry file to code list - main_file = sys.argv[0] - with open(main_file, 'rb') as code_file: + main_file = sys.argv[0] + main_folder = './' + sep = os.sep + if sep in main_file: + main_folder = sep.join(main_file.split(sep)[:-1]) + code_files = filter(lambda x: x.endswith('.py'), + os.listdir(main_folder)) + + for file_name in code_files: + file_path = os.path.join(main_folder, file_name) + assert os.path.exists(file_path) + with open(file_path, 'rb') as code_file: code = code_file.read() - # parl/remote/remote_decorator.py -> remote_decorator.py - file_name = main_file.split(os.sep)[-1] pyfiles['python_files'][file_name] = code - except AssertionError as e: - raise Exception( - 'Failed to create the client, the file {} does not exist.'. - format(file)) + # append entry file to code list + assert os.path.isfile( + main_file + ), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format( + main_file, os.getcwd()) + with open(main_file, 'rb') as code_file: + code = code_file.read() + # parl/remote/remote_decorator.py -> remote_decorator.py + file_name = main_file.split(os.sep)[-1] + pyfiles['python_files'][file_name] = code + + for file_name in distributed_files: + assert os.path.exists(file_name) + assert not os.path.isabs( + file_name + ), "[XPARL] Please do not distribute a file with absolute path." + with open(file_name, 'rb') as f: + content = f.read() + pyfiles['other_files'][file_name] = content return cloudpickle.dumps(pyfiles) def _create_sockets(self, master_address): diff --git a/parl/remote/job.py b/parl/remote/job.py index d835e5389aa447bb69567b61f6f1c60b9cf99d58..a84b852053005430357275951eb284882bed4612 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -311,8 +311,6 @@ class Job(object): try: file_name, class_name, end_of_file = cloudpickle.loads( message[1]) - #/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent - file_name = file_name.split(os.sep)[-1] cls = load_remote_class(file_name, class_name, end_of_file) args, kwargs = cloudpickle.loads(message[2]) logfile_path = os.path.join(self.log_dir, 'stdout.log') diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index a066abc40832fdce00fd00d1784aa75c60925e00..cff791dcfdf80a1a9c27ac0b9e3c706a8c5d1313 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -19,6 +19,7 @@ import time import zmq import numpy as np import inspect +import sys from parl.utils import get_ip_address, logger, to_str, to_byte from parl.utils.communication import loads_argument, loads_return,\ @@ -27,6 +28,7 @@ from parl.remote import remote_constants from parl.remote.exceptions import RemoteError, RemoteAttributeError,\ RemoteDeserializeError, RemoteSerializeError, ResourceError from parl.remote.client import get_global_client +from parl.remote.utils import locate_remote_file def remote_class(*args, **kwargs): @@ -120,13 +122,22 @@ def remote_class(*args, **kwargs): self.job_shutdown = False self.send_file(self.job_socket) - file_name = inspect.getfile(cls)[:-3] + module_path = inspect.getfile(cls) + if module_path.endswith('pyc'): + module_path = module_path[:-4] + elif module_path.endswith('py'): + module_path = module_path[:-3] + else: + raise FileNotFoundError( + "cannot not find the module:{}".format(module_path)) + res = inspect.getfile(cls) + file_path = locate_remote_file(module_path) cls_source = inspect.getsourcelines(cls) end_of_file = cls_source[1] + len(cls_source[0]) class_name = cls.__name__ self.job_socket.send_multipart([ remote_constants.INIT_OBJECT_TAG, - cloudpickle.dumps([file_name, class_name, end_of_file]), + cloudpickle.dumps([file_path, class_name, end_of_file]), cloudpickle.dumps([args, kwargs]), ]) message = self.job_socket.recv_multipart() diff --git a/parl/remote/tests/log_server_test.py b/parl/remote/tests/log_server_test.py index b41fbbd444e2f9602107d2d7b4634c227ad655c8..02815ea1a50387462b6072c47e4e1bc1efb130a8 100644 --- a/parl/remote/tests/log_server_test.py +++ b/parl/remote/tests/log_server_test.py @@ -24,6 +24,7 @@ import time import unittest import requests +requests.adapters.DEFAULT_RETRIES = 5 import parl from parl.remote.client import disconnect, get_global_client @@ -128,7 +129,6 @@ class TestLogServer(unittest.TestCase): monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py') monitor_file = monitor_file.replace('log_server_test.py', '../monitor.py') - command = [ sys.executable, monitor_file, "--monitor_port", str(monitor_port), "--address", "localhost:" + str(master_port) diff --git a/parl/remote/tests/test_import_module/Module2.py b/parl/remote/tests/test_import_module/Module2.py new file mode 100644 index 0000000000000000000000000000000000000000..b04fc66a5ad2006df414480aa32bee364ecba375 --- /dev/null +++ b/parl/remote/tests/test_import_module/Module2.py @@ -0,0 +1,20 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import parl + + +@parl.remote_class +class B(object): + def add_sum(self, a, b): + return a + b diff --git a/parl/remote/tests/test_import_module/main_abs_test.py b/parl/remote/tests/test_import_module/main_abs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c77dd3a516f136e1031a68d835477dfa3bd40712 --- /dev/null +++ b/parl/remote/tests/test_import_module/main_abs_test.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import parl +import time +import threading +from parl.remote.master import Master +from parl.remote.worker import Worker +from parl.remote.client import disconnect + + +class TestImport(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_import_local_module(self): + from Module2 import B + port = 8448 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + time.sleep(1) + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(10) + parl.connect("localhost:8448") + obj = B() + res = obj.add_sum(10, 5) + self.assertEqual(res, 15) + worker.exit() + master.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/tests/test_import_module/main_test.py b/parl/remote/tests/test_import_module/main_test.py new file mode 100644 index 0000000000000000000000000000000000000000..128c47c8ca5a406670c34486dd4fb7bd6d6b63a8 --- /dev/null +++ b/parl/remote/tests/test_import_module/main_test.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import parl +import time +import threading +from parl.remote.master import Master +from parl.remote.worker import Worker +from parl.remote.client import disconnect + + +class TestImport(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_import_local_module(self): + from Module2 import B + port = 8442 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + time.sleep(1) + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(10) + parl.connect("localhost:8442") + obj = B() + res = obj.add_sum(10, 5) + self.assertEqual(res, 15) + worker.exit() + master.exit() + + def test_import_subdir_module_0(self): + from subdir import Module + port = 8443 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + time.sleep(1) + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(10) + parl.connect( + "localhost:8443", + distributed_files=['./subdir/Module.py', './subdir/__init__.py']) + obj = Module.A() + res = obj.add_sum(10, 5) + self.assertEqual(res, 15) + worker.exit() + master.exit() + + def test_import_subdir_module_1(self): + from subdir.Module import A + port = 8444 + master = Master(port=port) + th = threading.Thread(target=master.run) + th.start() + time.sleep(1) + worker = Worker('localhost:{}'.format(port), 1) + time.sleep(10) + parl.connect( + "localhost:8444", + distributed_files=['./subdir/Module.py', './subdir/__init__.py']) + obj = A() + res = obj.add_sum(10, 5) + self.assertEqual(res, 15) + worker.exit() + master.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/tests/test_import_module/subdir/Module.py b/parl/remote/tests/test_import_module/subdir/Module.py new file mode 100644 index 0000000000000000000000000000000000000000..c06ba3bfe46d28476ab2d6eb94d0f724cab63851 --- /dev/null +++ b/parl/remote/tests/test_import_module/subdir/Module.py @@ -0,0 +1,20 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import parl + + +@parl.remote_class +class A(object): + def add_sum(self, a, b): + return a + b diff --git a/parl/remote/tests/test_import_module/subdir/__init__.py b/parl/remote/tests/test_import_module/subdir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..847ddc47ac89114f2012bc6b9990a69abfe39fb3 --- /dev/null +++ b/parl/remote/tests/test_import_module/subdir/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/parl/remote/utils.py b/parl/remote/utils.py index 9a2ece8686ff7de73c8164565f34281e412aa4ee..63c94c1a8022256bec382a348a11466c93d0ecc8 100644 --- a/parl/remote/utils.py +++ b/parl/remote/utils.py @@ -13,8 +13,11 @@ # limitations under the License. import sys from contextlib import contextmanager +import os -__all__ = ['load_remote_class', 'redirect_stdout_to_file'] +__all__ = [ + 'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file' +] def simplify_code(code, end_of_file): @@ -32,7 +35,7 @@ def simplify_code(code, end_of_file): def data_process(): XXXX ------------------> - The last two lines of the above code block will be removed as they are not class related. + The last two lines of the above code block will be removed as they are not class-related. """ to_write_lines = [] for i, line in enumerate(code): @@ -60,12 +63,18 @@ def load_remote_class(file_name, class_name, end_of_file): with open(file_name + '.py') as t_file: code = t_file.readlines() code = simplify_code(code, end_of_file) - module_name = 'xparl_' + file_name - tmp_file_name = 'xparl_' + file_name + '.py' + #folder/xx.py -> folder/xparl_xx.py + file_name = file_name.split(os.sep) + prefix = os.sep.join(file_name[:-1]) + if prefix == "": + prefix = '.' + module_name = prefix + os.sep + 'xparl_' + file_name[-1] + tmp_file_name = module_name + '.py' with open(tmp_file_name, 'w') as t_file: for line in code: t_file.write(line) - mod = __import__(module_name) + module_name = module_name.lstrip('.' + os.sep).replace(os.sep, '.') + mod = __import__(module_name, globals(), locals(), [class_name], 0) cls = getattr(mod, class_name) return cls @@ -74,6 +83,9 @@ def load_remote_class(file_name, class_name, end_of_file): def redirect_stdout_to_file(file_path): """Redirect stdout (e.g., `print`) to specified file. + Args: + file_path: Path of the file to output the stdout. + Example: >>> print('test') test @@ -81,10 +93,6 @@ def redirect_stdout_to_file(file_path): ... print('test') # Output nothing, `test` is printed to `test.log`. >>> print('test') test - - Args: - file_path: Path of the file to output the stdout. - """ tmp = sys.stdout f = open(file_path, 'a') @@ -94,3 +102,33 @@ def redirect_stdout_to_file(file_path): finally: sys.stdout = tmp f.close() + + +def locate_remote_file(module_path): + """xparl has to locate the file that has the class decorated by parl.remote_class. + This function returns the relative path between this file and the entry file. + + Args: + module_path: Absolute path of the module. + + Example: + module_path: /home/user/dir/subdir/my_module + entry_file: /home/user/dir/main.py + --------> relative_path: subdir/my_module + """ + entry_file = sys.argv[0] + entry_file = entry_file.split(os.sep)[-1] + entry_path = None + for path in sys.path: + to_check_path = os.path.join(path, entry_file) + if os.path.isfile(to_check_path): + entry_path = path + break + if entry_path is None or \ + (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]): + raise FileNotFoundError("cannot locate the remote file") + if module_path.startswith(os.sep): + relative_module_path = '.' + module_path[len(entry_path):] + else: + relative_module_path = module_path + return relative_module_path