#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XPARL'] = 'True'
import argparse
import cloudpickle
import pickle
import psutil
import re
import sys
import tempfile
import threading
import time
import traceback
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob


class Job(object):
    """Base class for the job.

    After establishing connection with the remote object, the job will
    create a remote class instance locally and enter an infinite loop,
    waiting for commands from the remote object.
    """

    def __init__(self, worker_address):
        """
        Args:
            worker_address(str): worker_address for sending job information(e.g, pid)

        Attributes:
            pid (int): Job process ID.
            max_memory (float): Maximum memory (MB) that can be used by each
                remote instance; ``None`` means no limit is enforced.
            lock (threading.Lock): serializes use of ``kill_job_socket``,
                which is shared by the main thread and the client-heartbeat
                thread.
        """
        self.job_is_alive = True
        self.worker_address = worker_address
        self.pid = os.getpid()
        self.max_memory = None
        self.lock = threading.Lock()
        self._create_sockets()

    def _create_sockets(self):
        """Create the sockets and heartbeat threads of this job.

        (1) reply_socket (main socket): receives the command (i.e. the function
            name and args) from the actual class instance, completes the
            computation, and returns the result of the function.
        (2) job_socket (functional socket): sends job_address and the heartbeat
            addresses to the worker.
        (3) kill_job_socket: sends a command to the corresponding worker to
            kill the job.

        Three heartbeat servers are created as well: the ping thread starts
        immediately, the worker-heartbeat thread starts once the worker has
        answered the registration message, and the client-heartbeat thread is
        started later by ``run``.
        """
        self.ctx = zmq.Context()

        # create the reply_socket
        self.reply_socket = self.ctx.socket(zmq.REP)
        job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
        self.reply_socket.linger = 0
        self.job_ip = get_ip_address()
        self.job_address = "{}:{}".format(self.job_ip, job_port)

        # create the job_socket
        self.job_socket = self.ctx.socket(zmq.REQ)
        self.job_socket.connect("tcp://{}".format(self.worker_address))

        # a thread that replies to ping signals from the client
        # (no receive timeout: it simply blocks until the next ping arrives)
        ping_heartbeat_socket, ping_heartbeat_address = \
            self._create_heartbeat_server(timeout=False)
        ping_thread = threading.Thread(
            target=self._reply_ping, args=(ping_heartbeat_socket, ))
        ping_thread.daemon = True
        ping_thread.start()
        self.ping_heartbeat_address = ping_heartbeat_address

        # a thread that replies to heartbeat signals from the worker
        worker_heartbeat_socket, worker_heartbeat_address = \
            self._create_heartbeat_server()
        worker_thread = threading.Thread(
            target=self._reply_worker_heartbeat,
            args=(worker_heartbeat_socket, ))
        worker_thread.daemon = True

        # a thread that replies to heartbeat signals from the client
        client_heartbeat_socket, client_heartbeat_address = \
            self._create_heartbeat_server()
        self.client_thread = threading.Thread(
            target=self._reply_client_heartbeat,
            args=(client_heartbeat_socket, ))
        self.client_thread.daemon = True

        # send the job information to the worker
        initialized_job = InitializedJob(
            self.job_address, worker_heartbeat_address,
            client_heartbeat_address, self.ping_heartbeat_address, None,
            self.pid)
        self.job_socket.send_multipart(
            [remote_constants.NORMAL_TAG,
             cloudpickle.dumps(initialized_job)])
        message = self.job_socket.recv_multipart()
        # The worker has registered this job, so start answering its heartbeats.
        worker_thread.start()

        tag = message[0]
        assert tag == remote_constants.NORMAL_TAG
        # create the kill_job_socket; the worker's reply carries its address
        kill_job_address = to_str(message[1])
        self.kill_job_socket = self.ctx.socket(zmq.REQ)
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.kill_job_socket.connect("tcp://{}".format(kill_job_address))

    def _check_used_memory(self):
        """Check if the memory used by this job exceeds self.max_memory.

        Returns:
            bool: True if ``max_memory`` is set and the resident memory of this
            process (in MB) is above it, False otherwise.
        """
        if self.max_memory is None:
            return False
        process = psutil.Process(self.pid)
        used_memory = float(process.memory_info().rss) / (1024**2)
        return used_memory > self.max_memory

    def _send_kill_job_request(self):
        """Ask the worker, through ``kill_job_socket``, to kill this job.

        The caller must hold ``self.lock``. ``kill_job_socket`` is a REQ socket
        used from more than one thread, and a REQ socket must strictly alternate
        send and receive. A missing reply (receive timeout) is ignored because
        the caller exits right afterwards.
        """
        self.kill_job_socket.send_multipart(
            [remote_constants.KILLJOB_TAG,
             to_byte(self.job_address)])
        try:
            self.kill_job_socket.recv_multipart()
        except zmq.error.Again:
            pass

    def _reply_ping(self, socket):
        """Reply to the ping signals from the client.

        This signal is used to make sure that the job is still alive.
        """
        while self.job_is_alive:
            socket.recv_multipart()
            socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        socket.close(0)

    def _create_heartbeat_server(self, timeout=True):
        """Create a REP socket bound to a random port on this host.

        Args:
            timeout (bool): if True, receiving on the socket raises
                ``zmq.error.Again`` after ``HEARTBEAT_RCVTIMEO_S`` seconds
                without a message.

        Returns:
            (socket, str): the socket and its "ip:port" address.
        """
        heartbeat_socket = self.ctx.socket(zmq.REP)
        if timeout:
            heartbeat_socket.setsockopt(
                zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_socket.linger = 0
        heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
        heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
        return heartbeat_socket, heartbeat_address

    def _reply_client_heartbeat(self, socket):
        """Reply to the heartbeat signals from the client.

        Each reply also reports whether the job exceeded ``max_memory``; in
        that case the process exits immediately. If the job loses the
        connection with the client (or the client is marked dead elsewhere),
        it asks the worker to kill it and then exits.
        """
        self.client_is_alive = True
        while self.client_is_alive and self.job_is_alive:
            try:
                socket.recv_multipart()
                stop_job = self._check_used_memory()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(stop_job)),
                    to_byte(self.job_address)
                ])
                if stop_job:
                    socket.close(0)
                    os._exit(1)
            except zmq.error.Again:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                self.client_is_alive = False
        socket.close(0)
        with self.lock:
            self._send_kill_job_request()
        logger.warning("[Job]lost connection with the client, will exit")
        os._exit(1)

    def _reply_worker_heartbeat(self, socket):
        """Reply to the heartbeat signals from the worker.

        If the worker stops sending heartbeats (receive timeout), the job
        exits automatically.
        """
        # flags deciding when to exit the heartbeat loop
        self.worker_is_alive = True
        while self.worker_is_alive and self.job_is_alive:
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                logger.warning("[Job] Cannot connect to the worker{}. ".format(
                    self.worker_address) + "Job will quit.")
                self.worker_is_alive = False
                self.job_is_alive = False
        socket.close(0)
        os._exit(1)

    def wait_for_files(self):
        """Wait for python files from remote object.

        When a remote object receives the allocated job address, it will send
        the python files to the job. Later, the job will save these files to a
        temporary directory and add the temporary directory to Python's working
        directory.

        Returns:
            A temporary directory containing the python files.

        Raises:
            NotImplementedError: if the first message is not SEND_FILE_TAG.
        """
        while True:
            message = self.reply_socket.recv_multipart()
            tag = message[0]
            if tag == remote_constants.SEND_FILE_TAG:
                # NOTE(security): unpickling data received over the network can
                # execute arbitrary code; only trusted clients must reach this
                # socket.
                pyfiles = pickle.loads(message[1])
                envdir = tempfile.mkdtemp()
                for file_name in pyfiles:
                    code = pyfiles[file_name]
                    # NOTE(review): file_name comes from the remote side and is
                    # joined unchecked (an absolute path or ".." escapes envdir).
                    file_path = os.path.join(envdir, file_name)
                    with open(file_path, 'wb') as code_file:
                        code_file.write(code)
                self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                return envdir
            else:
                # The original passed one argument for two placeholders, which
                # raised IndexError here instead of logging.
                logger.error("NotImplementedError:{}, received tag:{}".format(
                    self.job_address, tag))
                raise NotImplementedError

    def wait_for_connection(self):
        """Wait for connection from the remote object.

        The remote object will send its class information and initialization
        arguments to the job, these parameters are then used to create a
        local instance in the job process.

        Returns:
            A local instance of the remote class object, or None if the
            constructor raised (the error is sent back to the client and
            ``client_is_alive`` is cleared).

        Raises:
            NotImplementedError: if the message is not INIT_OBJECT_TAG.
        """
        message = self.reply_socket.recv_multipart()
        tag = message[0]
        obj = None
        if tag == remote_constants.INIT_OBJECT_TAG:
            cls = cloudpickle.loads(message[1])
            args, kwargs = cloudpickle.loads(message[2])
            max_memory = to_str(message[3])
            if max_memory != 'None':
                self.max_memory = float(max_memory)

            try:
                obj = cls(*args, **kwargs)
            except Exception as e:
                traceback_str = str(traceback.format_exc())
                error_str = str(e)
                logger.error("traceback:\n{}".format(traceback_str))
                self.reply_socket.send_multipart([
                    remote_constants.EXCEPTION_TAG,
                    to_byte(error_str + "\ntraceback:\n" + traceback_str)
                ])
                self.client_is_alive = False
                return None

            self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        else:
            logger.error("Message from job {}".format(message))
            self.reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                b"[job]Unkonwn tag when tried to receive the class definition"
            ])
            raise NotImplementedError

        return obj

    def run(self):
        """Serve one remote object, then ask the worker to kill this job.

        Receives the source files, creates the remote instance, and handles
        its calls in ``single_task``. Whatever happens, the process finally
        informs the worker and exits.
        """
        # receive source code from the actor and add it to the import path
        envdir = self.wait_for_files()
        sys.path.append(envdir)
        self.client_is_alive = True
        self.client_thread.start()

        try:
            obj = self.wait_for_connection()
            # Explicit check instead of `assert`, which is stripped under -O.
            if obj is None:
                raise RuntimeError(
                    "Failed to create the remote class instance.")
            self.single_task(obj)
        except Exception as e:
            logger.error(
                "Error occurs when running a single task. We will reset this job. Reason:{}"
                .format(e))
            traceback_str = str(traceback.format_exc())
            logger.error("traceback:\n{}".format(traceback_str))
        with self.lock:
            self._send_kill_job_request()
            # Exit while still holding the lock so that no other thread can
            # send a second kill request.
            os._exit(1)

    def single_task(self, obj):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function, job will run the
           function on the local instance and return the results to the
           remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        Raises:
            AttributeError, SerializeError, DeserializeError: re-raised after
                the error has been reported to the client.
            NotImplementedError: on an unknown message tag.
        """
        while self.job_is_alive and self.client_is_alive:
            message = self.reply_socket.recv_multipart()
            tag = message[0]

            if tag == remote_constants.CALL_TAG:
                try:
                    function_name = to_str(message[1])
                    data = message[2]
                    args, kwargs = loads_argument(data)
                    ret = getattr(obj, function_name)(*args, **kwargs)
                    ret = dumps_return(ret)

                    self.reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG, ret])

                except Exception as e:
                    # reset the job
                    self.client_is_alive = False

                    error_str = str(e)
                    logger.error(error_str)

                    # These exact exception types are reported with dedicated
                    # tags; everything else goes out as EXCEPTION_TAG with the
                    # traceback. Matching is on the exact type, not subclasses.
                    dedicated_tags = {
                        AttributeError:
                        remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                        SerializeError:
                        remote_constants.SERIALIZE_EXCEPTION_TAG,
                        DeserializeError:
                        remote_constants.DESERIALIZE_EXCEPTION_TAG,
                    }
                    exception_tag = dedicated_tags.get(type(e))
                    if exception_tag is not None:
                        self.reply_socket.send_multipart(
                            [exception_tag, to_byte(error_str)])
                        # Re-raise the original exception (not a new,
                        # message-less instance) so run() logs the real reason.
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    self.reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" + traceback_str)
                    ])
                    break

            # receive KILLJOB_TAG from the actor, and stop serving it
            elif tag == remote_constants.KILLJOB_TAG:
                self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                self.client_is_alive = False
                logger.warning(
                    "An actor exits and this job {} will exit.".format(
                        self.job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError


if __name__ == "__main__":
    # Command-line entry point: register with the given worker and serve
    # remote calls until Job.run() terminates the process.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--worker_address",
        required=True,
        type=str,
        help="worker_address",
    )
    options = cli.parse_args()
    Job(options.worker_address).run()