From d6e82f0145b6a0e926280cd6aa2ae2cb20eb9b9d Mon Sep 17 00:00:00 2001 From: liuyuecheng-github <52879090+liuyuecheng-github@users.noreply.github.com> Date: Fri, 14 Aug 2020 11:39:58 +0800 Subject: [PATCH] add the function to get and set attributes of remote models (#381) * add support of RegExp input of distributed_files[] * add support of RegExp input of distributed_files[] * add the function to get and set attributes of remote models * get_set_attributes, not finished * get_set_atributes, not finieshed yet * add function to get and set attributes of remote model * add function to get and set attributes of remote models * add more unnitest cases, together with several comments * make codes reusable --- parl/core/fluid/model.py | 2 +- parl/remote/job.py | 67 ++++++-- parl/remote/remote_constants.py | 3 + parl/remote/remote_decorator.py | 92 +++++++++-- parl/remote/tests/get_set_attribute_test.py | 162 ++++++++++++++++++++ parl/remote/worker.py | 1 - 6 files changed, 296 insertions(+), 31 deletions(-) create mode 100644 parl/remote/tests/get_set_attribute_test.py diff --git a/parl/core/fluid/model.py b/parl/core/fluid/model.py index bf7069a..80f7486 100644 --- a/parl/core/fluid/model.py +++ b/parl/core/fluid/model.py @@ -53,7 +53,7 @@ class Model(ModelBase): copied_policy = copy.deepcopy(model) Attributes: - model_id(str): each model instance has its uniqe model_id. + model_id(str): each model instance has its unique model_id. Public Functions: - ``sync_weights_to``: synchronize parameters of the current model to another model. diff --git a/parl/remote/job.py b/parl/remote/job.py index a84b852..b13c868 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -395,24 +395,59 @@ class Job(object): while True: message = reply_socket.recv_multipart() - tag = message[0] - - if tag == remote_constants.CALL_TAG: + if tag in [ + remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE, + remote_constants.SET_ATTRIBUTE, + remote_constants.CHECK_ATTRIBUTE + ]: try: - function_name = to_str(message[1]) - data = message[2] - args, kwargs = loads_argument(data) - - # Redirect stdout to stdout.log temporarily - logfile_path = os.path.join(self.log_dir, 'stdout.log') - with redirect_stdout_to_file(logfile_path): - ret = getattr(obj, function_name)(*args, **kwargs) - - ret = dumps_return(ret) - - reply_socket.send_multipart( - [remote_constants.NORMAL_TAG, ret]) + if tag == remote_constants.CHECK_ATTRIBUTE: + attr = to_str(message[1]) + if attr in obj.__dict__: + reply_socket.send_multipart([ + remote_constants.NORMAL_TAG, + dumps_return(True) + ]) + else: + reply_socket.send_multipart([ + remote_constants.NORMAL_TAG, + dumps_return(False) + ]) + + elif tag == remote_constants.CALL_TAG: + function_name = to_str(message[1]) + data = message[2] + args, kwargs = loads_argument(data) + + # Redirect stdout to stdout.log temporarily + logfile_path = os.path.join(self.log_dir, 'stdout.log') + with redirect_stdout_to_file(logfile_path): + ret = getattr(obj, function_name)(*args, **kwargs) + + ret = dumps_return(ret) + + reply_socket.send_multipart( + [remote_constants.NORMAL_TAG, ret]) + + elif tag == remote_constants.GET_ATTRIBUTE: + attribute_name = to_str(message[1]) + logfile_path = os.path.join(self.log_dir, 'stdout.log') + with redirect_stdout_to_file(logfile_path): + ret = getattr(obj, attribute_name) + ret = dumps_return(ret) + reply_socket.send_multipart( + [remote_constants.NORMAL_TAG, ret]) + elif tag == remote_constants.SET_ATTRIBUTE: + attribute_name = to_str(message[1]) + attribute_value = loads_return(message[2]) + logfile_path = os.path.join(self.log_dir, 'stdout.log') + with redirect_stdout_to_file(logfile_path): + setattr(obj, attribute_name, attribute_value) + reply_socket.send_multipart( + [remote_constants.NORMAL_TAG]) + else: + pass except Exception as e: # reset the job diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py index 8f49da5..db7a86e 100644 --- a/parl/remote/remote_constants.py +++ b/parl/remote/remote_constants.py @@ -29,6 +29,9 @@ NEW_JOB_TAG = b'[NEW_JOB]' INIT_OBJECT_TAG = b'[INIT_OBJECT]' CALL_TAG = b'[CALL]' +GET_ATTRIBUTE = b'[GET_ATTRIBUTE]' +SET_ATTRIBUTE = b'[SET_ATTRIBUTE]' +CHECK_ATTRIBUTE = b'[CHECK_ATTRIBUTE]' EXCEPTION_TAG = b'[EXCEPTION]' ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]' diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index cff791d..47909b1 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -190,25 +190,66 @@ def remote_class(*args, **kwargs): cnt -= 1 return None - def __getattr__(self, attr): + def check_attribute(self, attr): + '''checkout if attr is a attribute or a function''' + self.internal_lock.acquire() + self.job_socket.send_multipart( + [remote_constants.CHECK_ATTRIBUTE, + to_byte(attr)]) + message = self.job_socket.recv_multipart() + self.internal_lock.release() + tag = message[0] + if tag == remote_constants.NORMAL_TAG: + return loads_return(message[1]) + else: + self.job_shutdown = True + raise NotImplementedError() + + def set_remote_attr(self, attr, value): + self.internal_lock.acquire() + self.job_socket.send_multipart([ + remote_constants.SET_ATTRIBUTE, + to_byte(attr), + dumps_return(value) + ]) + message = self.job_socket.recv_multipart() + tag = message[0] + self.internal_lock.release() + if tag == remote_constants.NORMAL_TAG: + pass + else: + self.job_shutdown = True + raise NotImplementedError() + return + + def get_remote_attr(self, attr): """Call the function of the unwrapped class.""" + #check if attr is a attribute or a function + is_attribute = self.check_attribute(attr) def wrapper(*args, **kwargs): - if self.job_shutdown: - raise RemoteError( - attr, "This actor losts connection with the job.") self.internal_lock.acquire() - data = dumps_argument(*args, **kwargs) - - self.job_socket.send_multipart( - [remote_constants.CALL_TAG, - to_byte(attr), data]) + if is_attribute: + self.job_socket.send_multipart( + [remote_constants.GET_ATTRIBUTE, + to_byte(attr)]) + else: + if self.job_shutdown: + raise RemoteError( + attr, + "This actor losts connection with the job.") + data = dumps_argument(*args, **kwargs) + self.job_socket.send_multipart( + [remote_constants.CALL_TAG, + to_byte(attr), data]) message = self.job_socket.recv_multipart() tag = message[0] if tag == remote_constants.NORMAL_TAG: ret = loads_return(message[1]) + self.internal_lock.release() + return ret elif tag == remote_constants.EXCEPTION_TAG: error_str = to_str(message[1]) @@ -234,13 +275,38 @@ def remote_class(*args, **kwargs): self.job_shutdown = True raise NotImplementedError() - self.internal_lock.release() - return ret + return wrapper() if is_attribute else wrapper + + def proxy_wrapper_func(remote_wrapper): + ''' + The 'proxy_wrapper_func' is defined on the top of class 'RemoteWrapper' + in order to set and get attributes of 'remoted_wrapper' and the corresponding + remote models individually. + + With 'proxy_wrapper_func', it is allowed to define a attribute (or method) of + the same name in 'RemoteWrapper' and remote models. + ''' + + class ProxyWrapper(object): + def __init__(self, *args, **kwargs): + self.xparl_remote_wrapper_obj = remote_wrapper( + *args, **kwargs) + + def __getattr__(self, attr): + return self.xparl_remote_wrapper_obj.get_remote_attr(attr) + + def __setattr__(self, attr, value): + if attr == 'xparl_remote_wrapper_obj': + super(ProxyWrapper, self).__setattr__(attr, value) + else: + self.xparl_remote_wrapper_obj.set_remote_attr( + attr, value) - return wrapper + return ProxyWrapper RemoteWrapper._original = cls - return RemoteWrapper + proxy_wrapper = proxy_wrapper_func(RemoteWrapper) + return proxy_wrapper max_memory = kwargs.get('max_memory') if len(args) == 1 and callable(args[0]): diff --git a/parl/remote/tests/get_set_attribute_test.py b/parl/remote/tests/get_set_attribute_test.py new file mode 100644 index 0000000..a233822 --- /dev/null +++ b/parl/remote/tests/get_set_attribute_test.py @@ -0,0 +1,162 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import parl +import numpy as np +from parl.remote.client import disconnect +from parl.utils import logger +from parl.remote.master import Master +from parl.remote.worker import Worker +import time +import threading +import random + + +@parl.remote_class +class Actor(object): + def __init__(self, arg1, arg2, arg3, arg4): + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 + self.GLOBAL_CLIENT = arg4 + + def arg1(self, x, y): + time.sleep(0.2) + return x + y + + def arg5(self): + return 100 + + +class Test_get_and_set_attribute(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_get_attribute(self): + port1 = random.randint(6100, 6200) + logger.info("running:test_get_attirbute") + master = Master(port=port1) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:{}'.format(port1), 1) + arg1 = np.random.randint(100) + arg2 = np.random.randn() + arg3 = np.random.randn(3, 3) + arg4 = 100 + parl.connect('localhost:{}'.format(port1)) + actor = Actor(arg1, arg2, arg3, arg4) + self.assertTrue(arg1 == actor.arg1) + self.assertTrue(arg2 == actor.arg2) + self.assertTrue((arg3 == actor.arg3).all()) + self.assertTrue(arg4 == actor.GLOBAL_CLIENT) + master.exit() + worker1.exit() + + def test_set_attribute(self): + port2 = random.randint(6200, 6300) + logger.info("running:test_set_attirbute") + master = Master(port=port2) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:{}'.format(port2), 1) + arg1 = 3 + arg2 = 3.5 + arg3 = np.random.randn(3, 3) + arg4 = 100 + parl.connect('localhost:{}'.format(port2)) + actor = Actor(arg1, arg2, arg3, arg4) + actor.arg1 = arg1 + actor.arg2 = arg2 + actor.arg3 = arg3 + actor.GLOBAL_CLIENT = arg4 + self.assertTrue(arg1 == actor.arg1) + self.assertTrue(arg2 == actor.arg2) + self.assertTrue((arg3 == actor.arg3).all()) + self.assertTrue(arg4 == actor.GLOBAL_CLIENT) + master.exit() + worker1.exit() + + def test_create_new_attribute_same_with_wrapper(self): + port3 = random.randint(6400, 6500) + logger.info("running:test_create_new_attribute_same_with_wrapper") + master = Master(port=port3) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:{}'.format(port3), 1) + arg1 = np.random.randint(100) + arg2 = np.random.randn() + arg3 = np.random.randn(3, 3) + arg4 = 100 + parl.connect('localhost:{}'.format(port3)) + actor = Actor(arg1, arg2, arg3, arg4) + + actor.internal_lock = 50 + self.assertTrue(actor.internal_lock == 50) + master.exit() + worker1.exit() + + def test_same_name_of_attribute_and_method(self): + port4 = random.randint(6500, 6600) + logger.info("running:test_same_name_of_attribute_and_method") + master = Master(port=port4) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:{}'.format(port4), 1) + arg1 = np.random.randint(100) + arg2 = np.random.randn() + arg3 = np.random.randn(3, 3) + arg4 = 100 + parl.connect('localhost:{}'.format(port4)) + actor = Actor(arg1, arg2, arg3, arg4) + self.assertEqual(arg1, actor.arg1) + + def call_method(): + return actor.arg1(1, 2) + + self.assertRaises(TypeError, call_method) + master.exit() + worker1.exit() + + def test_non_existing_attribute_same_with_existing_method(self): + port5 = random.randint(6600, 6700) + logger.info( + "running:test_non_existing_attribute_same_with_existing_method") + master = Master(port=port5) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + worker1 = Worker('localhost:{}'.format(port5), 1) + arg1 = np.random.randint(100) + arg2 = np.random.randn() + arg3 = np.random.randn(3, 3) + arg4 = 100 + parl.connect('localhost:{}'.format(port5)) + actor = Actor(arg1, arg2, arg3, arg4) + self.assertTrue(callable(actor.arg5)) + + def call_non_existing_method(): + return actor.arg2(10) + + self.assertRaises(TypeError, call_non_existing_method) + master.exit() + worker1.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/worker.py b/parl/remote/worker.py index eec5598..d1a0333 100644 --- a/parl/remote/worker.py +++ b/parl/remote/worker.py @@ -72,7 +72,6 @@ class Worker(object): self.master_is_alive = True self.worker_is_alive = True self.worker_status = None # initialized at `self._create_jobs` - self.lock = threading.Lock() self._set_cpu_num(cpu_num) self.job_buffer = queue.Queue(maxsize=self.cpu_num) self._create_sockets() -- GitLab