#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

# These variables are set before any other module is imported, so that
# everything imported below already sees them: no CUDA device is visible
# to the job, and the XPARL flag is on (presumably read by parl to detect
# that it runs inside an xparl job -- verify against parl).
os.environ.update({'CUDA_VISIBLE_DEVICES': '', 'XPARL': 'True'})
import argparse
import cloudpickle
import pickle
import psutil
import re
import sys
import tempfile
import threading
import time
import traceback
import zmq
from multiprocessing import Process, Pipe
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob


class Job(object):
    """A job that hosts one remote class instance.

    After establishing connection with the remote object, the job will
    create a remote class instance locally and enter an infinite loop
    in a separate process, waiting for commands from the remote object.

    The main process only does bookkeeping: it registers the job with the
    worker, answers heartbeats from the worker and the client, and enforces
    the optional memory limit. The remote instance itself lives in
    ``run_job_process`` (see ``run``).
    """

    def __init__(self, worker_address):
        """Start the job and block until it finishes.

        The constructor never returns normally. It starts ``run`` in
        ``run_job_process``, sets up the sockets and heartbeat threads in
        this (main) process, and waits for ``run_job_process`` to end. It
        then asks the worker to kill the job and calls ``os._exit(1)``.

        Args:
            worker_address (str): address of the worker, used to send job
                information (e.g. the pid) to it.

        Attributes:
            pid (int): ID of the main job process.
            max_memory (float): Maximum memory (MB) that can be used by each
                remote instance; ``None`` means no limit. Set from the
                client's ping message (see ``_reply_ping``).
            init_memory (float): Memory (MB) used by the job once it is
                initialized; ``max_memory`` is counted on top of it.
                ``None`` until it has been measured.
        """
        self.max_memory = None
        # Stays None until measured below; _check_used_memory skips the
        # limit check until then (the heartbeat threads may already run).
        self.init_memory = None

        self.job_address_receiver, job_address_sender = Pipe()

        self.worker_address = worker_address
        self.job_ip = get_ip_address()
        self.pid = os.getpid()
        self.lock = threading.Lock()

        # The remote instance is created and served in this child process.
        self.run_job_process = Process(
            target=self.run, args=(job_address_sender, ))
        self.run_job_process.start()

        self._create_sockets()

        self.init_memory = self._get_used_memory()

        # run() returns once the remote object is done (or failed); the
        # whole job is then torn down.
        self.run_job_process.join()
        self._notify_worker_and_exit()

    def _create_sockets(self):
        """Create five sockets for each job in main process.

        (1) job_socket(functional socket): sends job_address and heartbeat_address to worker.
        (2) ping_heartbeat_socket: replies ping message of client.
        (3) worker_heartbeat_socket: replies heartbeat message of worker.
        (4) client_heartbeat_socket: replies heartbeat message of client.
        (5) kill_job_socket: sends a command to the corresponding worker to kill the job.

        Blocks until ``run_job_process`` has sent the address of its reply
        socket through the pipe, and until the worker has answered the
        job registration.
        """
        # wait for the job process (run) to create its reply socket
        self.job_address = self.job_address_receiver.recv()

        self.ctx = zmq.Context()

        # create the job_socket
        self.job_socket = self.ctx.socket(zmq.REQ)
        self.job_socket.connect("tcp://{}".format(self.worker_address))

        # a thread that replies ping signals from the client
        ping_heartbeat_socket, ping_heartbeat_address = \
            self._create_heartbeat_server(timeout=False)
        ping_thread = threading.Thread(
            target=self._reply_ping, args=(ping_heartbeat_socket, ))
        ping_thread.daemon = True

        # a thread that replies heartbeat signals from the worker
        worker_heartbeat_socket, worker_heartbeat_address = \
            self._create_heartbeat_server()
        worker_thread = threading.Thread(
            target=self._reply_worker_heartbeat,
            args=(worker_heartbeat_socket, ))
        worker_thread.daemon = True

        # a thread that replies heartbeat signals from the client; it is
        # started by _reply_ping once the client has pinged this job
        client_heartbeat_socket, client_heartbeat_address = \
            self._create_heartbeat_server()
        self.client_thread = threading.Thread(
            target=self._reply_client_heartbeat,
            args=(client_heartbeat_socket, ))
        self.client_thread.daemon = True

        # _reply_ping starts self.client_thread, so the ping thread is only
        # started once that attribute exists. The ping socket is already
        # bound, so an early ping is queued rather than lost.
        ping_thread.start()

        # send job information to the worker
        initialized_job = InitializedJob(
            self.job_address, worker_heartbeat_address,
            client_heartbeat_address, ping_heartbeat_address, None, self.pid)
        self.job_socket.send_multipart(
            [remote_constants.NORMAL_TAG,
             cloudpickle.dumps(initialized_job)])
        message = self.job_socket.recv_multipart()
        worker_thread.start()

        tag = message[0]
        assert tag == remote_constants.NORMAL_TAG

        # create the kill_job_socket; the worker's reply carries its address
        kill_job_address = to_str(message[1])
        self.kill_job_socket = self.ctx.socket(zmq.REQ)
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.kill_job_socket.connect("tcp://{}".format(kill_job_address))

    def _get_used_memory(self):
        """Return the resident memory (MB) used by this job.

        The remote instance is created in ``run_job_process``, not in the
        main process. Both processes are therefore measured; measuring only
        ``self.pid`` would never see the remote object's allocations.
        """
        pids = (self.pid, self.run_job_process.pid)
        used_bytes = sum(psutil.Process(pid).memory_info()[0] for pid in pids)
        return float(used_bytes) / (1024**2)

    def _check_used_memory(self):
        """Check if the memory used by this job exceeds self.max_memory.

        Returns:
            bool: True if the job should be stopped. Always False when no
            limit is set, or when the baseline ``init_memory`` has not been
            measured yet.
        """
        if self.max_memory is None or self.init_memory is None:
            return False
        try:
            used_memory = self._get_used_memory()
        except psutil.NoSuchProcess:
            # run_job_process has already exited; __init__ (blocked on its
            # join) tears the job down, so there is nothing to enforce.
            return False
        return used_memory > self.max_memory + self.init_memory

    def _reply_ping(self, socket):
        """Create a socket server that reply the ping signal from client.
        This signal is used to make sure that the job is still alive.

        Exactly one ping is handled. Its second frame carries the client's
        memory limit ('None' for no limit). After replying, this method
        starts the client heartbeat thread and closes the ping socket.
        """
        message = socket.recv_multipart()
        max_memory = to_str(message[1])
        if max_memory != 'None':
            self.max_memory = float(max_memory)
        socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        self.client_thread.start()
        socket.close(0)

    def _create_heartbeat_server(self, timeout=True):
        """Create a REP socket bound to a random port on all interfaces.

        Args:
            timeout (bool): if True, receiving on the socket raises
                ``zmq.error.Again`` after ``HEARTBEAT_RCVTIMEO_S`` seconds
                without a message.

        Returns:
            (socket, "ip:port" address of the socket)
        """
        heartbeat_socket = self.ctx.socket(zmq.REP)
        if timeout:
            heartbeat_socket.setsockopt(
                zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_socket.linger = 0
        heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
        heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
        return heartbeat_socket, heartbeat_address

    def _reply_client_heartbeat(self, socket):
        """Create a socket that replies heartbeat signals from the client.

        Each reply tells the client whether the memory limit is exceeded.
        If it is, the job exits. If the job loses connection with the
        client (receive timeout), it asks the worker to kill it and exits
        too.
        """
        while True:
            try:
                socket.recv_multipart()
                stop_job = self._check_used_memory()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(stop_job)),
                    to_byte(self.job_address)
                ])
                if stop_job:
                    logger.error(
                        "Memory used by this job exceeds {}. This job will exit."
                        .format(self.max_memory))
                    # presumably gives the client time to receive the reply
                    # above before the process disappears
                    time.sleep(5)
                    socket.close(0)
                    os._exit(1)
            except zmq.error.Again:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                break
        socket.close(0)
        self._notify_worker_and_exit(
            "[Job]lost connection with the client, will exit")

    def _notify_worker_and_exit(self, message=None):
        """Ask the worker to kill this job, then terminate the process.

        ``self.lock`` serializes use of the REQ ``kill_job_socket`` between
        the main thread and the client-heartbeat thread. It is held until
        ``os._exit`` so that the socket is never used by a second caller.

        Args:
            message (str): optional warning logged right before exiting.
        """
        with self.lock:
            self.kill_job_socket.send_multipart(
                [remote_constants.KILLJOB_TAG,
                 to_byte(self.job_address)])
            try:
                self.kill_job_socket.recv_multipart()
            except zmq.error.Again:
                # no answer within HEARTBEAT_TIMEOUT_S; exit regardless
                pass
            if message is not None:
                logger.warning(message)
            os._exit(1)

    def _reply_worker_heartbeat(self, socket):
        """Create a socket that replies heartbeat signals from the worker.
        If the worker has exited (receive timeout), the job exits
        automatically.
        """
        while True:
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                logger.warning("[Job] Cannot connect to the worker{}. ".format(
                    self.worker_address) + "Job will quit.")
                break
        socket.close(0)
        os._exit(1)

    def wait_for_files(self, reply_socket, job_address):
        """Wait for python files from remote object.

        When a remote object receives the allocated job address, it will send
        the python files to the job. Later, the job will save these files to a
        temporary directory and add the temporary directory to Python's working
        directory.

        Args:
            reply_socket (socket): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.

        Returns:
            A temporary directory containing the python files.

        Raises:
            NotImplementedError: if the first message is not SEND_FILE_TAG.
        """
        message = reply_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.SEND_FILE_TAG:
            logger.error("NotImplementedError:{}, received tag:{}".format(
                job_address, tag))
            raise NotImplementedError

        # NOTE(review): the payload arrives over the network and is
        # unpickled, and the file names are used as paths without checks.
        # Only safe with trusted clients.
        pyfiles = pickle.loads(message[1])

        # save python files to temporary directory
        envdir = tempfile.mkdtemp()
        for file_name, code in pyfiles['python_files'].items():
            with open(os.path.join(envdir, file_name), 'wb') as code_file:
                code_file.write(code)

        # save other files to current directory
        for file_name, content in pyfiles['other_files'].items():
            # create the parent directory (e.g. ./rom_files/) if needed
            dir_name = os.path.dirname(file_name)
            if dir_name:
                try:
                    os.makedirs(dir_name)
                except OSError:
                    # most likely it already exists; a real failure
                    # surfaces in open() below
                    pass
            with open(file_name, 'wb') as f:
                f.write(content)
        logger.info('[job] reply')
        reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return envdir

    def wait_for_connection(self, reply_socket):
        """Wait for connection from the remote object.

        The remote object will send its class information and initialization
        arguments to the job, these parameters are then used to create a
        local instance in the job process.

        Args:
            reply_socket (socket): main socket to accept commands of remote object.

        Returns:
            A local instance of the remote class object, or None if creating
            it raised (the error is then sent back with EXCEPTION_TAG).

        Raises:
            NotImplementedError: if the message is not INIT_OBJECT_TAG.
        """
        message = reply_socket.recv_multipart()
        tag = message[0]
        obj = None

        if tag == remote_constants.INIT_OBJECT_TAG:
            try:
                # NOTE(review): cloudpickle.loads on data received from the
                # network; only safe with trusted clients.
                cls = cloudpickle.loads(message[1])
                args, kwargs = cloudpickle.loads(message[2])
                obj = cls(*args, **kwargs)
            except Exception as e:
                traceback_str = str(traceback.format_exc())
                error_str = str(e)
                logger.error("traceback:\n{}".format(traceback_str))
                reply_socket.send_multipart([
                    remote_constants.EXCEPTION_TAG,
                    to_byte(error_str + "\ntraceback:\n" + traceback_str)
                ])
                return None
            reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        else:
            logger.error("Message from job {}".format(message))
            reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                b"[job]Unkonwn tag when tried to receive the class definition"
            ])
            raise NotImplementedError

        return obj

    def run(self, job_address_sender):
        """Serve one remote object; runs in ``run_job_process``.

        Creates the reply socket and hands its address to the main process.
        It then receives the source files and the class definition, and
        serves calls until ``single_task`` returns. Any exception is logged.
        When this method returns, the process ends, and ``__init__``
        (blocked on its join) tears the job down.

        Args:
            job_address_sender(sending end of multiprocessing.Pipe): send job address of reply_socket to main process.
        """
        ctx = zmq.Context()

        # create the reply_socket
        reply_socket = ctx.socket(zmq.REP)
        job_port = reply_socket.bind_to_random_port(addr="tcp://*")
        reply_socket.linger = 0
        job_ip = get_ip_address()
        job_address = "{}:{}".format(job_ip, job_port)

        # _create_sockets in the main process is blocked waiting for this
        job_address_sender.send(job_address)

        try:
            # receive source code from the actor and append them to the environment variables.
            envdir = self.wait_for_files(reply_socket, job_address)
            sys.path.append(envdir)

            obj = self.wait_for_connection(reply_socket)
            assert obj is not None
            self.single_task(obj, reply_socket, job_address)
        except Exception as e:
            logger.error(
                "Error occurs when running a single task. We will reset this job. Reason:{}"
                .format(e))
            traceback_str = str(traceback.format_exc())
            logger.error("traceback:\n{}".format(traceback_str))

    def single_task(self, obj, reply_socket, job_address):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function, job will run the
           function on the local instance and return the results to the
           remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        AttributeError, SerializeError and DeserializeError (exact types)
        are reported with a dedicated tag and then re-raised. Any other
        exception is reported with its traceback and ends the loop.

        Args:
            obj: the local instance of the remote class.
            reply_socket (socket): main socket to accept commands of remote object.
            job_address (String): address of reply_socket.
        """
        while True:
            message = reply_socket.recv_multipart()
            tag = message[0]

            if tag == remote_constants.CALL_TAG:
                try:
                    function_name = to_str(message[1])
                    data = message[2]
                    args, kwargs = loads_argument(data)
                    ret = getattr(obj, function_name)(*args, **kwargs)
                    ret = dumps_return(ret)

                    reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG, ret])

                except Exception as e:
                    # reset the job
                    error_str = str(e)
                    logger.error(error_str)

                    # exact-type lookup, matching the original type(e) == X
                    # checks (subclasses fall through to the generic case)
                    error_tags = {
                        AttributeError:
                        remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                        SerializeError:
                        remote_constants.SERIALIZE_EXCEPTION_TAG,
                        DeserializeError:
                        remote_constants.DESERIALIZE_EXCEPTION_TAG,
                    }
                    error_tag = error_tags.get(type(e))
                    if error_tag is not None:
                        reply_socket.send_multipart(
                            [error_tag, to_byte(error_str)])
                        # re-raise the original exception so run() logs
                        # its message, not an empty instance
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" + traceback_str)
                    ])
                    break

            # receive KILLJOB_TAG from the actor: acknowledge and stop serving
            elif tag == remote_constants.KILLJOB_TAG:
                reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                logger.warning("An actor exits and this job {} will exit.".
                               format(job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError


if __name__ == "__main__":
    # Command-line entry point: read the worker address and start the job.
    # Job() does not return normally; it ends the process with os._exit.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--worker_address",
        type=str,
        required=True,
        help="worker_address",
    )
    cli_args = cli_parser.parse_args()
    job = Job(cli_args.worker_address)