#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XPARL'] = 'True'
import argparse
import cloudpickle
import pickle
import psutil
import re
import sys
import tempfile
import threading
import time
import traceback
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
    dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob


class Job(object):
    """Base class for the job.

    After establishing connection with the remote object, the job will
    create a remote class instance locally and enter an infinite loop,
    waiting for commands from the remote object.
    """

    def __init__(self, worker_address):
        """
        Args:
            worker_address(str): worker_address for sending job information(e.g, pid)

        Attributes:
            pid (int): Job process ID.
            max_memory (float): Maximum memory (MB) can be used by each remote instance.
            init_memory (float): Memory (MB) used by this process once the
                sockets have been created. It is added to ``max_memory``
                when checking the memory limit.
        """
        self.job_is_alive = True
        self.worker_address = worker_address
        self.pid = os.getpid()
        self.max_memory = None
        # Serializes access to `kill_job_socket`, which is used both by the
        # main thread (`run`) and by the client heartbeat thread.
        self.lock = threading.Lock()
        self._create_sockets()

        self.init_memory = self._used_memory_mb()

    def _used_memory_mb(self):
        """Return the memory currently used by this process, in MB.

        The value is the first field of ``psutil.Process.memory_info()``
        divided by 1024**2.
        """
        process = psutil.Process(self.pid)
        return float(process.memory_info()[0]) / (1024**2)

    def _create_sockets(self):
        """Create the sockets and heartbeat threads used by the job.

        (1) reply_socket(main socket): receives the command(i.e, the function name and args)
            from the actual class instance, completes the computation, and returns the result of
            the function.
        (2) job_socket(functional socket): sends job_address and heartbeat_address to worker.
        (3) kill_job_socket: sends a command to the corresponding worker to kill the job.

        Three heartbeat servers are created as well: one answering ping
        signals from the client (thread started here), one answering the
        worker (thread started once the worker has replied) and one
        answering the client (thread stored in ``self.client_thread`` and
        started later by ``run``).

        Raises:
            AssertionError: if the worker does not reply with NORMAL_TAG.
        """

        self.ctx = zmq.Context()

        # create the reply_socket
        self.reply_socket = self.ctx.socket(zmq.REP)
        job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
        self.reply_socket.linger = 0
        self.job_ip = get_ip_address()
        self.job_address = "{}:{}".format(self.job_ip, job_port)

        # create the job_socket
        self.job_socket = self.ctx.socket(zmq.REQ)
        self.job_socket.connect("tcp://{}".format(self.worker_address))

        # a thread that reply ping signals from the client
        ping_heartbeat_socket, ping_heartbeat_address = \
            self._create_heartbeat_server(timeout=False)
        ping_thread = threading.Thread(
            target=self._reply_ping, args=(ping_heartbeat_socket, ))
        # `daemon` attribute instead of the deprecated `setDaemon()`
        ping_thread.daemon = True
        ping_thread.start()
        self.ping_heartbeat_address = ping_heartbeat_address

        # a thread that reply heartbeat signals from the worker
        worker_heartbeat_socket, worker_heartbeat_address = \
            self._create_heartbeat_server()
        worker_thread = threading.Thread(
            target=self._reply_worker_heartbeat,
            args=(worker_heartbeat_socket, ))
        worker_thread.daemon = True

        # a thread that reply heartbeat signals from the client
        client_heartbeat_socket, client_heartbeat_address = \
            self._create_heartbeat_server()
        self.client_thread = threading.Thread(
            target=self._reply_client_heartbeat,
            args=(client_heartbeat_socket, ))
        self.client_thread.daemon = True

        # sends job information to the worker
        initialized_job = InitializedJob(
            self.job_address, worker_heartbeat_address,
            client_heartbeat_address, self.ping_heartbeat_address, None,
            self.pid)
        self.job_socket.send_multipart(
            [remote_constants.NORMAL_TAG,
             cloudpickle.dumps(initialized_job)])
        message = self.job_socket.recv_multipart()
        # The worker heartbeat is only answered after the worker has replied.
        worker_thread.start()

        tag = message[0]
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if tag != remote_constants.NORMAL_TAG:
            raise AssertionError(
                "Unexpected tag received from the worker: {}".format(tag))

        # create the kill_job_socket
        kill_job_address = to_str(message[1])
        self.kill_job_socket = self.ctx.socket(zmq.REQ)
        # zmq.RCVTIMEO is expressed in milliseconds
        self.kill_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.kill_job_socket.connect("tcp://{}".format(kill_job_address))

    def _check_used_memory(self):
        """Check if the memory used by this job exceeds self.max_memory.

        Returns:
            bool: True when a limit is set and the current usage is above
            ``max_memory + init_memory``; False otherwise (including when
            no limit is set).
        """
        if self.max_memory is None:
            return False
        return self._used_memory_mb() > self.max_memory + self.init_memory

    def _notify_worker_kill_job(self):
        """Ask the worker to kill this job and wait for its reply.

        The caller must hold ``self.lock``: `kill_job_socket` is shared
        between the main thread and the client heartbeat thread. A missing
        reply (receive timeout) is ignored on purpose, because the job is
        about to exit anyway.
        """
        self.kill_job_socket.send_multipart(
            [remote_constants.KILLJOB_TAG,
             to_byte(self.job_address)])
        try:
            self.kill_job_socket.recv_multipart()
        except zmq.error.Again:
            pass

    def _reply_ping(self, socket):
        """Create a socket server that reply the ping signal from client.
        This signal is used to make sure that the job is still alive.
        """
        while self.job_is_alive:
            socket.recv_multipart()
            socket.send_multipart([remote_constants.HEARTBEAT_TAG])
        socket.close(0)

    def _create_heartbeat_server(self, timeout=True):
        """Create a REP socket bound to a random port for heartbeat signals.

        Args:
            timeout(bool): when True, receiving on the socket raises
                ``zmq.error.Again`` after HEARTBEAT_RCVTIMEO_S seconds.

        Returns:
            A tuple ``(heartbeat_socket, heartbeat_address)`` where the
            address is formatted as "ip:port".
        """
        heartbeat_socket = self.ctx.socket(zmq.REP)
        if timeout:
            # zmq.RCVTIMEO is expressed in milliseconds
            heartbeat_socket.setsockopt(
                zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_socket.linger = 0
        heartbeat_port = heartbeat_socket.bind_to_random_port(addr="tcp://*")
        heartbeat_address = "{}:{}".format(self.job_ip, heartbeat_port)
        return heartbeat_socket, heartbeat_address

    def _reply_client_heartbeat(self, socket):
        """Create a socket that replies heartbeat signals from the client.
        If the job loses connection with the client, it will exit too.

        Each reply also carries the result of the memory check, so the
        client is told when the job is about to stop because it used too
        much memory.
        """
        self.client_is_alive = True
        while self.client_is_alive and self.job_is_alive:
            try:
                socket.recv_multipart()
                stop_job = self._check_used_memory()
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(str(stop_job)),
                    to_byte(self.job_address)
                ])
                if stop_job:
                    logger.error(
                        "Memory used by this job exceeds {}. This job will exit."
                        .format(self.max_memory))
                    time.sleep(5)
                    socket.close(0)
                    os._exit(1)
            except zmq.error.Again:
                logger.warning(
                    "[Job] Cannot connect to the client. This job will exit and inform the worker."
                )
                self.client_is_alive = False
        socket.close(0)
        with self.lock:
            self._notify_worker_kill_job()
        logger.warning("[Job]lost connection with the client, will exit")
        os._exit(1)

    def _reply_worker_heartbeat(self, socket):
        """create a socket that replies heartbeat signals from the worker.
        If the worker has exited, the job will exit automatically.
        """

        self.worker_is_alive = True
        # both flags decide when to exit the heartbeat loop
        while self.worker_is_alive and self.job_is_alive:
            try:
                socket.recv_multipart()
                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                logger.warning("[Job] Cannot connect to the worker{}. ".format(
                    self.worker_address) + "Job will quit.")
                self.worker_is_alive = False
                self.job_is_alive = False
        socket.close(0)
        os._exit(1)

    @staticmethod
    def _save_pyfiles(pyfiles):
        """Write the received python files into a new temporary directory.

        Args:
            pyfiles(dict): maps a file name (relative path) to the content
                of the file as bytes.

        Returns:
            The path of the temporary directory containing the files.
        """
        envdir = tempfile.mkdtemp()
        for filename, code in pyfiles.items():
            # NOTE(review): file names come from the remote side and are not
            # validated here; an absolute path or one containing ".." would
            # be written outside of `envdir`.
            path = os.path.join(envdir, filename)
            parent = os.path.dirname(path)
            # Create missing sub-directories so that file names such as
            # "pkg/module.py" can be written.
            if not os.path.isdir(parent):
                os.makedirs(parent)
            with open(path, 'wb') as code_file:
                code_file.write(code)
        return envdir

    def wait_for_files(self):
        """Wait for python files from remote object.

        When a remote object receives the allocated job address, it will send
        the python files to the job. The job saves these files to a
        temporary directory, which `run` then appends to ``sys.path``.

        Returns:
            A temporary directory containing the python files.

        Raises:
            NotImplementedError: if the received tag is not SEND_FILE_TAG.
        """
        message = self.reply_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.SEND_FILE_TAG:
            logger.error("NotImplementedError:{}, received tag:{}".format(
                self.job_address, tag))
            raise NotImplementedError

        # NOTE(security): `pickle.loads` runs on data received from the
        # network; only safe when the sender is trusted.
        pyfiles = pickle.loads(message[1])
        envdir = self._save_pyfiles(pyfiles)
        self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return envdir

    def wait_for_connection(self):
        """Wait for connection from the remote object.

        The remote object will send its class information and initialization
        arguments to the job, these parameters are then used to create a
        local instance in the job process.

        Returns:
            A local instance of the remote class object, or None when the
            constructor of the class raised (the error is sent back to the
            remote object in that case).

        Raises:
            NotImplementedError: if the received tag is not INIT_OBJECT_TAG.
        """

        message = self.reply_socket.recv_multipart()
        tag = message[0]
        if tag != remote_constants.INIT_OBJECT_TAG:
            logger.error("Message from job {}".format(message))
            self.reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                b"[job]Unkonwn tag when tried to receive the class definition"
            ])
            raise NotImplementedError

        # NOTE(security): `cloudpickle.loads` runs on data received from the
        # network; only safe when the sender is trusted.
        cls = cloudpickle.loads(message[1])
        args, kwargs = cloudpickle.loads(message[2])
        # the memory limit is transmitted as text; 'None' means "no limit"
        max_memory = to_str(message[3])
        if max_memory != 'None':
            self.max_memory = float(max_memory)

        try:
            obj = cls(*args, **kwargs)
        except Exception as e:
            traceback_str = str(traceback.format_exc())
            error_str = str(e)
            logger.error("traceback:\n{}".format(traceback_str))
            self.reply_socket.send_multipart([
                remote_constants.EXCEPTION_TAG,
                to_byte(error_str + "\ntraceback:\n" + traceback_str)
            ])
            self.client_is_alive = False
            return None

        self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
        return obj

    def run(self):
        """Receive the files and the class, serve the tasks, then exit.

        Whatever happens while serving the tasks, the worker is finally
        asked to kill this job and the process exits.
        """
        # receive source code from the actor and append them to the module search path.
        envdir = self.wait_for_files()
        sys.path.append(envdir)
        self.client_is_alive = True
        self.client_thread.start()

        try:
            obj = self.wait_for_connection()
            # Explicit check instead of `assert`, which is stripped under `-O`.
            if obj is None:
                raise AssertionError(
                    "Failed to create the instance of the remote class.")
            self.single_task(obj)
        except Exception as e:
            logger.error(
                "Error occurs when running a single task. We will reset this job. Reason:{}"
                .format(e))
            traceback_str = str(traceback.format_exc())
            logger.error("traceback:\n{}".format(traceback_str))
        with self.lock:
            self._notify_worker_kill_job()
            # exit while still holding the lock, so that the heartbeat thread
            # cannot use `kill_job_socket` in between
            os._exit(1)

    def single_task(self, obj):
        """An infinite loop waiting for commands from the remote object.

        Each job will receive two kinds of message from the remote object:

        1. When the remote object calls a function, job will run the
           function on the local instance and return the results to the
           remote object.
        2. When the remote object is deleted, the job will quit and release
           related computation resources.

        Args:
            obj: local instance of the remote class.

        Raises:
            AttributeError, SerializeError, DeserializeError: re-raised
                after the error has been reported to the remote object.
            NotImplementedError: if an unknown tag is received.
        """

        while self.job_is_alive and self.client_is_alive:
            message = self.reply_socket.recv_multipart()

            tag = message[0]

            if tag == remote_constants.CALL_TAG:
                try:
                    function_name = to_str(message[1])
                    data = message[2]
                    args, kwargs = loads_argument(data)
                    ret = getattr(obj, function_name)(*args, **kwargs)
                    ret = dumps_return(ret)

                    self.reply_socket.send_multipart(
                        [remote_constants.NORMAL_TAG, ret])

                except Exception as e:
                    # reset the job
                    self.client_is_alive = False

                    error_str = str(e)
                    logger.error(error_str)

                    # Exceptions of exactly these types are reported with a
                    # dedicated tag and then re-raised.
                    known_error_tags = {
                        AttributeError:
                        remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                        SerializeError:
                        remote_constants.SERIALIZE_EXCEPTION_TAG,
                        DeserializeError:
                        remote_constants.DESERIALIZE_EXCEPTION_TAG,
                    }
                    error_tag = known_error_tags.get(type(e))
                    if error_tag is not None:
                        self.reply_socket.send_multipart(
                            [error_tag, to_byte(error_str)])
                        # bare `raise` keeps the original message and traceback
                        raise

                    traceback_str = str(traceback.format_exc())
                    logger.error("traceback:\n{}".format(traceback_str))
                    self.reply_socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        to_byte(error_str + "\ntraceback:\n" + traceback_str)
                    ])
                    break

            # receive KILLJOB_TAG from actor, and leave the loop
            elif tag == remote_constants.KILLJOB_TAG:
                self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
                self.client_is_alive = False
                logger.warning(
                    "An actor exits and this job {} will exit.".format(
                        self.job_address))
                break
            else:
                logger.error(
                    "The job receives an unknown message: {}".format(message))
                raise NotImplementedError


if __name__ == "__main__":
    # Script entry point: the address of the worker is the only (mandatory)
    # command line option; a job is created for it and run.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--worker_address", type=str, required=True, help="worker_address")
    cli_args = arg_parser.parse_args()
    Job(cli_args.worker_address).run()