#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import sys
import threading
import time

import cloudpickle
import numpy as np
import zmq

from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
    RemoteDeserializeError, RemoteSerializeError, ResourceError
from parl.remote.client import get_global_client
from parl.remote.utils import locate_remote_file


def remote_class(*args, **kwargs):
    """A Python decorator that enables a class to run all its functions
    remotely.

    Each instance of the remote class can be seen as a task submitted
    to the cluster by the global client, which is created automatically
    when we call parl.connect(master_address). After the global client
    submits the task, the master node will send an available job address
    to this remote instance. Then the remote object will send local python
    files, class definition and initialization arguments to the related job.

    In this way, we can run distributed applications easily and efficiently.

    .. code-block:: python

        @parl.remote_class
        class Actor(object):
            def __init__(self, x):
                self.x = x

            def step(self):
                self.x += 1
                return self.x

        actor = Actor(1)
        actor.step()

        # Set maximum memory usage to 300 MB for each object.
        @parl.remote_class(max_memory=300)
        class LimitedActor(object):
            ...

    The decorator can be used bare (``@parl.remote_class``) or with keyword
    arguments (``@parl.remote_class(max_memory=300)``).

    Args:
        max_memory (float): Maximum memory (in MB) that each remote instance
                            may use. Defaults to None (unlimited).

    Returns:
        A remote wrapper for the remote class.

    Raises:
        Exception: An exception is raised if the client is not created
                   by `parl.connect(master_address)` beforehand, or (when
                   instantiating) if the master is no longer alive.
        ResourceError: (when instantiating) if no job could be obtained from
                       the master after repeated attempts.
        RemoteError: (when instantiating) if the remote ``__init__`` fails.
    """
    # Read by RemoteWrapper.__init__ through the closure.
    max_memory = kwargs.get('max_memory')

    def decorator(cls):
        # XPARL=True marks code running inside a job (job.py); there the
        # class is returned unwrapped and runs as a local object.
        if os.environ.get('XPARL') == 'True':
            logger.warning(
                "Note: this object will be runnning as a local object")
            return cls

        class RemoteWrapper(object):
            """
            Wrapper for remote class in client side.
            """

            def __init__(self, *args, **kwargs):
                """Obtain a job from the master and create the remote object.

                Args:
                    args, kwargs: arguments for the initialization of the
                        unwrapped class.

                Raises:
                    Exception: if the master is no longer alive.
                    ResourceError: if no job address could be obtained.
                    FileNotFoundError: if the class's file is not .py/.pyc.
                    RemoteError: if the remote ``__init__`` raised.
                """
                self.GLOBAL_CLIENT = get_global_client()

                self.ctx = self.GLOBAL_CLIENT.ctx

                # GLOBAL_CLIENT will set `master_is_alive` to False when the
                # heartbeat finds the master is dead.
                if not self.GLOBAL_CLIENT.master_is_alive:
                    raise Exception("Can not submit job to the master. "
                                    "Please check if master is still alive.")
                job_address = self.request_cpu_resource(
                    self.GLOBAL_CLIENT, max_memory)

                if job_address is None:
                    raise ResourceError("Cannot submit the job to the master. "
                                        "Please add more CPU resources to the "
                                        "master or try again later.")

                # Guards each send/recv exchange on the REQ socket below.
                self.internal_lock = threading.Lock()

                # Send actor commands like `init` and `call` to the job.
                self.job_socket = self.ctx.socket(zmq.REQ)
                self.job_socket.linger = 0
                self.job_socket.connect("tcp://{}".format(job_address))
                self.job_address = job_address
                self.job_shutdown = False

                self.send_file(self.job_socket)

                # Strip the extension (".pyc" before ".py") from the path of
                # the file that defines `cls`.
                module_path = inspect.getfile(cls)
                if module_path.endswith('.pyc'):
                    module_path = module_path[:-4]
                elif module_path.endswith('.py'):
                    module_path = module_path[:-3]
                else:
                    raise FileNotFoundError(
                        "cannot not find the module:{}".format(module_path))
                file_path = locate_remote_file(module_path)

                # getsourcelines returns (lines, first_line_number), so this
                # is the line number just past the end of the class body.
                cls_source = inspect.getsourcelines(cls)
                end_of_file = cls_source[1] + len(cls_source[0])
                class_name = cls.__name__
                self.job_socket.send_multipart([
                    remote_constants.INIT_OBJECT_TAG,
                    cloudpickle.dumps([file_path, class_name, end_of_file]),
                    cloudpickle.dumps([args, kwargs]),
                ])
                message = self.job_socket.recv_multipart()
                tag = message[0]
                if tag == remote_constants.EXCEPTION_TAG:
                    traceback_str = to_str(message[1])
                    self.job_shutdown = True
                    raise RemoteError('__init__', traceback_str)

            def __del__(self):
                """Delete the remote class object and release remote resources.

                This is a best-effort method. It may run on a partially
                constructed object (``__init__`` raised early) or at
                interpreter exit, so failures are swallowed. AttributeError
                and TypeError are listed before ZMQError because module
                globals such as ``zmq`` may already be torn down at that point.
                """
                job_socket = getattr(self, 'job_socket', None)
                if job_socket is None:
                    # __init__ failed before the socket was created.
                    return
                try:
                    try:
                        # Bound the wait for the job's acknowledgement.
                        job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000)
                        if not getattr(self, 'job_shutdown', True):
                            job_socket.send_multipart(
                                [remote_constants.KILLJOB_TAG])
                            _ = job_socket.recv_multipart()
                    finally:
                        # Close even if the job was already shut down or the
                        # acknowledgement timed out, so the socket is not
                        # leaked.
                        job_socket.close(0)
                except AttributeError:
                    pass
                except TypeError:
                    pass
                except zmq.error.ZMQError:
                    pass

            def send_file(self, socket):
                """Send the client's python files to the job and await a reply.

                The payload is ``self.GLOBAL_CLIENT.pyfiles``. A
                ``zmq.error.Again`` is logged rather than raised.
                """
                try:
                    socket.send_multipart([
                        remote_constants.SEND_FILE_TAG,
                        self.GLOBAL_CLIENT.pyfiles
                    ])
                    _ = socket.recv_multipart()
                except zmq.error.Again:
                    logger.error("Send python files failed.")

            def request_cpu_resource(self, global_client, max_memory):
                """Ask `global_client` for a job address, up to 300 times.

                This loop does not sleep between attempts. Any pacing comes
                from ``global_client.submit_job`` itself. A warning is logged
                every 30 attempts.

                Returns:
                    The job address, or None if every attempt failed.
                """
                for remaining in range(300, 0, -1):
                    job_address = global_client.submit_job(max_memory)
                    if job_address is not None:
                        return job_address
                    if remaining % 30 == 0:
                        logger.warning(
                            "No vacant cpu resources at the moment, "
                            "will try {} times later.".format(remaining))
                return None

            def _ensure_job_alive(self, attr):
                """Raise RemoteError if the job has been marked as shut down.

                Once ``job_shutdown`` is set, the job is presumed gone. No
                receive timeout is set outside ``__del__``, so sending to it
                would most likely block forever in ``recv_multipart``.
                """
                if self.job_shutdown:
                    raise RemoteError(
                        attr, "This actor losts connection with the job.")

            def check_attribute(self, attr):
                '''Ask the job whether `attr` should be read as an attribute.

                Returns the job's (deserialized) answer: truthy when `attr` is
                to be fetched as a value, falsy when it is to be called.
                Raises NotImplementedError (and marks the job shut down) on a
                non-normal reply.
                '''
                self._ensure_job_alive(attr)
                # `with` releases the lock even if send/recv raises.
                with self.internal_lock:
                    self.job_socket.send_multipart(
                        [remote_constants.CHECK_ATTRIBUTE,
                         to_byte(attr)])
                    message = self.job_socket.recv_multipart()
                tag = message[0]
                if tag == remote_constants.NORMAL_TAG:
                    return loads_return(message[1])
                self.job_shutdown = True
                raise NotImplementedError()

            def set_remote_attr(self, attr, value):
                """Set `attr` to `value` on the remote object.

                Raises NotImplementedError (and marks the job shut down) on a
                non-normal reply.
                """
                self._ensure_job_alive(attr)
                with self.internal_lock:
                    self.job_socket.send_multipart([
                        remote_constants.SET_ATTRIBUTE,
                        to_byte(attr),
                        dumps_return(value)
                    ])
                    message = self.job_socket.recv_multipart()
                tag = message[0]
                if tag != remote_constants.NORMAL_TAG:
                    self.job_shutdown = True
                    raise NotImplementedError()

            def get_remote_attr(self, attr):
                """Resolve `attr` on the remote object.

                Returns the remote value directly when the job reports `attr`
                as an attribute. Otherwise returns a callable that invokes the
                remote method with the arguments it receives.
                """
                # check if attr is an attribute or a function
                is_attribute = self.check_attribute(attr)

                def wrapper(*args, **kwargs):
                    # The returned callable may be kept and called after the
                    # job has failed, so re-check on every call.
                    self._ensure_job_alive(attr)
                    # The lock is released on every exit path, including the
                    # error raises below; previously those leaked it and the
                    # next remote access deadlocked.
                    with self.internal_lock:
                        if is_attribute:
                            self.job_socket.send_multipart(
                                [remote_constants.GET_ATTRIBUTE,
                                 to_byte(attr)])
                        else:
                            data = dumps_argument(*args, **kwargs)
                            self.job_socket.send_multipart(
                                [remote_constants.CALL_TAG,
                                 to_byte(attr), data])

                        message = self.job_socket.recv_multipart()
                        tag = message[0]

                        if tag == remote_constants.NORMAL_TAG:
                            return loads_return(message[1])

                        # Any non-normal reply marks the job as shut down.
                        self.job_shutdown = True
                        if tag == remote_constants.EXCEPTION_TAG:
                            raise RemoteError(attr, to_str(message[1]))
                        if tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
                            raise RemoteAttributeError(attr,
                                                       to_str(message[1]))
                        if tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
                            raise RemoteSerializeError(attr,
                                                       to_str(message[1]))
                        if tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
                            raise RemoteDeserializeError(attr,
                                                         to_str(message[1]))
                        raise NotImplementedError()

                return wrapper() if is_attribute else wrapper

        def proxy_wrapper_func(remote_wrapper):
            '''Build a thin proxy class around `remote_wrapper`.

            The proxy stores the RemoteWrapper instance under the single local
            attribute 'xparl_remote_wrapper_obj'. Every other attribute read
            or write is forwarded to the remote object. This lets the remote
            class define attributes (or methods) whose names match
            RemoteWrapper's own, e.g. 'job_socket' or 'send_file'.
            '''

            class ProxyWrapper(object):
                def __init__(self, *args, **kwargs):
                    self.xparl_remote_wrapper_obj = remote_wrapper(
                        *args, **kwargs)

                def __getattr__(self, attr):
                    # Only reached when normal lookup fails. If the wrapped
                    # object was never stored (e.g. the instance was created
                    # via __new__ by copy/pickle), looking it up here would
                    # recurse forever, so fail like a normal missing attribute.
                    if attr == 'xparl_remote_wrapper_obj':
                        raise AttributeError(attr)
                    return self.xparl_remote_wrapper_obj.get_remote_attr(attr)

                def __setattr__(self, attr, value):
                    if attr == 'xparl_remote_wrapper_obj':
                        super(ProxyWrapper, self).__setattr__(attr, value)
                    else:
                        self.xparl_remote_wrapper_obj.set_remote_attr(
                            attr, value)

            return ProxyWrapper

        RemoteWrapper._original = cls
        proxy_wrapper = proxy_wrapper_func(RemoteWrapper)
        return proxy_wrapper

    # Bare usage `@remote_class` passes the class itself as the only
    # positional argument; `@remote_class(...)` returns the decorator.
    if len(args) == 1 and callable(args[0]):
        return decorator(args[0])
    return decorator