diff --git a/parl/remote/client.py b/parl/remote/client.py index 7e095ff25f3723e070081c73f80cf80c4cec5a2f..fd17adb75babfa6ed86e82d460fd3025ad538cc5 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -105,6 +105,9 @@ class Client(object): for file in distributed_files: assert os.path.exists(file) + assert not os.path.isabs( + file + ), "[XPARL] Please do not distribute a file with absolute path." with open(file, 'rb') as f: content = f.read() pyfiles['other_files'][file] = content diff --git a/parl/remote/job.py b/parl/remote/job.py index d2f2d54ac231e3a05ceff65371751abfbe37f392..17a280fb4e2cad7e87a01f18c26bc680198b09fb 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -36,6 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\ from parl.remote import remote_constants from parl.utils.exceptions import SerializeError, DeserializeError from parl.remote.message import InitializedJob +from parl.remote.utils import load_remote_class class Job(object): @@ -268,12 +269,15 @@ class Job(object): # create directory (i.e. ./rom_files/) if '/' in file: try: - os.makedirs(os.path.join(*file.rsplit('/')[:-1])) + sep = os.sep + recursive_dirs = os.path.join(*(file.split(sep)[:-1])) + recursive_dirs = os.path.join(envdir, recursive_dirs) + os.makedirs(recursive_dirs) except OSError as e: pass + file = os.path.join(envdir, file) with open(file, 'wb') as f: f.write(content) - logger.info('[job] reply') reply_socket.send_multipart([remote_constants.NORMAL_TAG]) return envdir else: @@ -301,12 +305,12 @@ class Job(object): if tag == remote_constants.INIT_OBJECT_TAG: try: - file_name, class_name = cloudpickle.loads(message[1]) + file_name, class_name, end_of_file = cloudpickle.loads( + message[1]) #/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent file_name = file_name.split(os.sep)[-1] + cls = load_remote_class(file_name, class_name, end_of_file) args, kwargs = cloudpickle.loads(message[2]) - mod = __import__(file_name) - cls = getattr(mod, class_name)._original obj = cls(*args, **kwargs) except Exception as e: traceback_str = str(traceback.format_exc()) @@ -349,13 +353,14 @@ class Job(object): # receive source code from the actor and append them to the environment variables. envdir = self.wait_for_files(reply_socket, job_address) sys.path.append(envdir) + os.chdir(envdir) obj = self.wait_for_connection(reply_socket) assert obj is not None self.single_task(obj, reply_socket, job_address) except Exception as e: logger.error( - "Error occurs when running a single task. We will reset this job. Reason:{}" + "Error occurs when running a single task. We will reset this job. \nReason:{}" .format(e)) traceback_str = str(traceback.format_exc()) logger.error("traceback:\n{}".format(traceback_str)) diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 9b31d8e370db42e291ce1ab08e7ef7542b447389..f4a498bf0169d1322f5a76c6b8b0978c8f61546e 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -75,6 +75,12 @@ def remote_class(*args, **kwargs): """ def decorator(cls): + # we are not going to create a remote actor in job.py + if 'XPARL' in os.environ and os.environ['XPARL'] == 'True': + logger.warning( + "Note: this object will be runnning as a local object") + return cls + class RemoteWrapper(object): """ Wrapper for remote class in client side. @@ -115,10 +121,12 @@ def remote_class(*args, **kwargs): self.send_file(self.job_socket) file_name = inspect.getfile(cls)[:-3] + cls_source = inspect.getsourcelines(cls) + end_of_file = cls_source[1] + len(cls_source[0]) class_name = cls.__name__ self.job_socket.send_multipart([ remote_constants.INIT_OBJECT_TAG, - cloudpickle.dumps([file_name, class_name]), + cloudpickle.dumps([file_name, class_name, end_of_file]), cloudpickle.dumps([args, kwargs]), ]) message = self.job_socket.recv_multipart() @@ -130,7 +138,10 @@ def remote_class(*args, **kwargs): def __del__(self): """Delete the remote class object and release remote resources.""" - self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000) + try: + self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000) + except AttributeError: + pass if not self.job_shutdown: try: self.job_socket.send_multipart( diff --git a/parl/remote/tests/local_actor_test.py b/parl/remote/tests/local_actor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0435ed233153ec9efee548e012eb70ead11e2dd5 --- /dev/null +++ b/parl/remote/tests/local_actor_test.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +os.environ['XPARL'] = 'True' +import parl +import unittest + + +@parl.remote_class(max_memory=350) +class Actor(object): + def __init__(self, x=10): + self.x = x + self.data = [] + + def add_500mb(self): + self.data.append(os.urandom(500 * 1024**2)) + self.x += 1 + return self.x + + +class TestLocalActor(unittest.TestCase): + def test_create_actors_without_pre_connection(self): + actor = Actor() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/tests/recursive_actor_test.py b/parl/remote/tests/recursive_actor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9613b6be23ae64bf2fba10df793cfa57dbeea1 --- /dev/null +++ b/parl/remote/tests/recursive_actor_test.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from parl.utils import logger +import parl +from parl.remote.client import disconnect +from parl.remote.master import Master +from parl.remote.worker import Worker +import time +import threading + +c = 10 +port = 3002 +if __name__ == '__main__': + master = Master(port=port) + th = threading.Thread(target=master.run) + th.setDaemon(True) + th.start() +time.sleep(5) +cluster_addr = 'localhost:{}'.format(port) +parl.connect(cluster_addr) +worker = Worker(cluster_addr, 1) + + +@parl.remote_class +class Actor(object): + def add(self, a, b): + return a + b + c + + +actor = Actor() + + +class TestRecursive_actor(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_global_running(self): + self.assertEqual(actor.add(1, 2), 13) + master.exit() + worker.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/tests/sync_config_file_test.py b/parl/remote/tests/sync_config_file_test.py index a4d131d5e13111a1c7faaa209aa2acb114e7c7c7..c8be19443e446e1d90819a63c2a64b471fb23e6d 100644 --- a/parl/remote/tests/sync_config_file_test.py +++ b/parl/remote/tests/sync_config_file_test.py @@ -17,12 +17,10 @@ import parl from parl.remote.master import Master from parl.remote.worker import Worker from parl.remote.client import disconnect - +import os import time import threading - import sys - import numpy as np import json @@ -65,7 +63,8 @@ class TestConfigfile(unittest.TestCase): parl.connect('localhost:1335', ['random.npy', 'config.json']) actor = Actor('random.npy', 'config.json') time.sleep(5) - + os.remove('./random.npy') + os.remove('./config.json') remote_sum = actor.random_sum() self.assertEqual(remote_sum, random_sum) time.sleep(10) diff --git a/parl/remote/utils.py b/parl/remote/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd36e5ca61c18c08d2fa03a40f188c066c78918 --- /dev/null +++ b/parl/remote/utils.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['load_remote_class'] + + +def simplify_code(code, end_of_file): + """ + @parl.remote_actor has to use this function to simplify the code. + To create a remote object, PARL has to import the module that contains the decorated class. + It may run some unnecessary code when importing the module, and this is why we use this function + to simplify the code. + + For example. + @parl.remote_actor + class A(object): + def add(self, a, b): + return a + b + def data_process(): + XXXX + ------------------> + The last two lines of the above code block will be removed as they are not class related. + """ + to_write_lines = [] + for i, line in enumerate(code): + if line.startswith('parl.connect'): + continue + if i < end_of_file - 1: + to_write_lines.append(line) + else: + break + return to_write_lines + + +def load_remote_class(file_name, class_name, end_of_file): + """ + load a class given its file_name and class_name. + + Args: + file_name: specify the file to load the class + class_name: specify the class to be loaded + end_of_file: line ID to indicate the last line that defines the class. + + Return: + cls: the class to load + """ + with open(file_name + '.py') as t_file: + code = t_file.readlines() + code = simplify_code(code, end_of_file) + module_name = 'xparl_' + file_name + tmp_file_name = 'xparl_' + file_name + '.py' + with open(tmp_file_name, 'w') as t_file: + for line in code: + t_file.write(line) + mod = __import__(module_name) + cls = getattr(mod, class_name) + return cls