提交 64aebb6d 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

make job run task in a separate process (#170)

* make job run task in a separate process

* fix typo

* add more debug info in xparl client

* refine control flow of different processes in xparl job

* refine control flow of different processes in xparl job

* remove tsinghua source

* remove tsinghua source

* remove unnecessary logic

* fix typo

* refine comments and some logic

* fix bug, `decay=0` means totally synchronize weights of source model to target model
上级 ee36f15b
...@@ -173,7 +173,7 @@ function main() { ...@@ -173,7 +173,7 @@ function main() {
run_test_with_gpu run_test_with_gpu
# #
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . /root/miniconda3/envs/empty_env/bin/pip install .
run_import_test run_import_test
run_docs_test run_docs_test
;; ;;
......
...@@ -4,5 +4,5 @@ source ~/.bashrc ...@@ -4,5 +4,5 @@ source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH" export PATH="/root/miniconda3/bin:$PATH"
source deactivate source deactivate
source activate docs source activate docs
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple /work/ pip install /work/
make html make html
...@@ -40,7 +40,7 @@ For final submission, we test our model in 500 CPUs, running 10 episodes per CPU ...@@ -40,7 +40,7 @@ For final submission, we test our model in 500 CPUs, running 10 episodes per CPU
2. Download the model file from online storage service, [Baidu Pan](https://pan.baidu.com/s/1NN1auY2eDblGzUiqR8Bfqw) or [Google Drive](https://drive.google.com/open?id=1DQHrwtXzgFbl9dE7jGOe9ZbY0G9-qfq3) 2. Download the model file from online storage service, [Baidu Pan](https://pan.baidu.com/s/1NN1auY2eDblGzUiqR8Bfqw) or [Google Drive](https://drive.google.com/open?id=1DQHrwtXzgFbl9dE7jGOe9ZbY0G9-qfq3)
3. Unpack the file by using: 3. Unpack the file by using:
`tar zxvf saved_model.tar.gz` `tar zxvf saved_model.tar.gz`
4. Launch test scription: 4. Launch the test script:
`python test.py` `python test.py`
## Part2: Curriculum learning ## Part2: Curriculum learning
......
...@@ -59,7 +59,7 @@ class OpenSimAgent(parl.Agent): ...@@ -59,7 +59,7 @@ class OpenSimAgent(parl.Agent):
# Attention: In the beginning, sync target model totally. # Attention: In the beginning, sync target model totally.
self.alg.sync_target( self.alg.sync_target(
model_id=i, model_id=i,
decay=1.0, decay=0,
share_vars_parallel_executor=self.learn_pe[i]) share_vars_parallel_executor=self.learn_pe[i])
# Do cache, will create ParallelExecutor of sync params in advance # Do cache, will create ParallelExecutor of sync params in advance
# If not, there are some issues when ensemble_num > 1 # If not, there are some issues when ensemble_num > 1
......
...@@ -14,5 +14,5 @@ ...@@ -14,5 +14,5 @@
2. Download the model file from online storage service: [Baidu Pan](https://pan.baidu.com/s/12LIPspckCT8-Q5U1QX69Fg) (password: `b5ck`) or [Google Drive](https://drive.google.com/file/d/1jJtOcOVJ6auz3s-TyWgUJvofPXI94yxy/view?usp=sharing) 2. Download the model file from online storage service: [Baidu Pan](https://pan.baidu.com/s/12LIPspckCT8-Q5U1QX69Fg) (password: `b5ck`) or [Google Drive](https://drive.google.com/file/d/1jJtOcOVJ6auz3s-TyWgUJvofPXI94yxy/view?usp=sharing)
3. Unpack the file: 3. Unpack the file:
`tar zxvf saved_models.tar.gz` `tar zxvf saved_models.tar.gz`
4. Launch test scription: 4. Launch the test script:
`python test.py` `python test.py`
...@@ -91,20 +91,27 @@ class Client(object): ...@@ -91,20 +91,27 @@ class Client(object):
working directory. working directory.
""" """
pyfiles = dict() pyfiles = dict()
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
code_files = filter(lambda x: x.endswith('.py'), os.listdir('./')) code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
to_distributed_files = list(code_files) + distributed_files
for file in to_distributed_files: try:
try: for file in code_files:
assert os.path.exists(file) assert os.path.exists(file)
with open(file, 'rb') as code_file: with open(file, 'rb') as code_file:
code = code_file.read() code = code_file.read()
pyfiles[file] = code pyfiles['python_files'][file] = code
except AssertionError as e:
raise Exception( for file in distributed_files:
'Failed to create the client, the file {} does not exist.'. assert os.path.exists(file)
format(file)) with open(file, 'rb') as f:
content = f.read()
pyfiles['other_files'][file] = content
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
format(file))
return cloudpickle.dumps(pyfiles) return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address): def _create_sockets(self, master_address):
...@@ -173,7 +180,7 @@ class Client(object): ...@@ -173,7 +180,7 @@ class Client(object):
logger.warning("Client exit replying heartbeat for master.") logger.warning("Client exit replying heartbeat for master.")
def _check_and_monitor_job(self, job_heartbeat_address, def _check_and_monitor_job(self, job_heartbeat_address,
ping_heartbeat_address): ping_heartbeat_address, max_memory):
""" Sometimes the client may receive a job that is dead, thus """ Sometimes the client may receive a job that is dead, thus
we have to check if this job is still alive before sending it to the actor. we have to check if this job is still alive before sending it to the actor.
""" """
...@@ -184,7 +191,8 @@ class Client(object): ...@@ -184,7 +191,8 @@ class Client(object):
job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address) job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
try: try:
job_heartbeat_socket.send_multipart( job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG]) [remote_constants.HEARTBEAT_TAG,
to_byte(str(max_memory))])
job_heartbeat_socket.recv_multipart() job_heartbeat_socket.recv_multipart()
except zmq.error.Again: except zmq.error.Again:
job_heartbeat_socket.close(0) job_heartbeat_socket.close(0)
...@@ -231,6 +239,9 @@ class Client(object): ...@@ -231,6 +239,9 @@ class Client(object):
job_is_alive = False job_is_alive = False
self.lock.acquire() self.lock.acquire()
self.actor_num -= 1 self.actor_num -= 1
logger.error(
'[xparl] lost connection with a job, current actor num: {}'
.format(self.actor_num))
self.lock.release() self.lock.release()
except zmq.error.ZMQError as e: except zmq.error.ZMQError as e:
...@@ -238,13 +249,18 @@ class Client(object): ...@@ -238,13 +249,18 @@ class Client(object):
job_heartbeat_socket.close(0) job_heartbeat_socket.close(0)
def submit_job(self): def submit_job(self, max_memory):
"""Send a job to the Master node. """Send a job to the Master node.
When a `@parl.remote_class` object is created, the global client When a `@parl.remote_class` object is created, the global client
sends a job to the master node. Then the master node will allocate sends a job to the master node. Then the master node will allocate
a vacant job from its job pool to the remote object. a vacant job from its job pool to the remote object.
Args:
max_memory (float): Maximum memory (in MB) that can be used by each remote
instance; the default value is
None (unlimited).
Returns: Returns:
job_address(str): IP address of the job. None if there is no available CPU in the cluster. job_address(str): IP address of the job. None if there is no available CPU in the cluster.
""" """
...@@ -268,7 +284,8 @@ class Client(object): ...@@ -268,7 +284,8 @@ class Client(object):
ping_heartbeat_address = to_str(message[3]) ping_heartbeat_address = to_str(message[3])
check_result = self._check_and_monitor_job( check_result = self._check_and_monitor_job(
job_heartbeat_address, ping_heartbeat_address) job_heartbeat_address, ping_heartbeat_address,
max_memory)
if check_result: if check_result:
self.lock.acquire() self.lock.acquire()
self.actor_num += 1 self.actor_num += 1
......
...@@ -26,6 +26,7 @@ import threading ...@@ -26,6 +26,7 @@ import threading
import time import time
import traceback import traceback
import zmq import zmq
from multiprocessing import Process, Pipe
from parl.utils import to_str, to_byte, get_ip_address, logger from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\ from parl.utils.communication import loads_argument, loads_return,\
dumps_argument, dumps_return dumps_argument, dumps_return
...@@ -38,8 +39,8 @@ class Job(object): ...@@ -38,8 +39,8 @@ class Job(object):
"""Base class for the job. """Base class for the job.
After establishing connection with the remote object, the job will After establishing connection with the remote object, the job will
create a remote class instance locally and enter an infinite loop, create a remote class instance locally and enter an infinite loop
waiting for commands from the remote object. in a separate process, waiting for commands from the remote object.
""" """
...@@ -52,36 +53,50 @@ class Job(object): ...@@ -52,36 +53,50 @@ class Job(object):
pid (int): Job process ID. pid (int): Job process ID.
max_memory (float): Maximum memory (MB) can be used by each remote instance. max_memory (float): Maximum memory (MB) can be used by each remote instance.
""" """
self.job_is_alive = True self.max_memory = None
self.job_address_receiver, job_address_sender = Pipe()
self.worker_address = worker_address self.worker_address = worker_address
self.job_ip = get_ip_address()
self.pid = os.getpid() self.pid = os.getpid()
self.max_memory = None
self.lock = threading.Lock() self.lock = threading.Lock()
self.run_job_process = Process(
target=self.run, args=(job_address_sender, ))
self.run_job_process.start()
self._create_sockets() self._create_sockets()
process = psutil.Process(self.pid) process = psutil.Process(self.pid)
self.init_memory = float(process.memory_info()[0]) / (1024**2) self.init_memory = float(process.memory_info()[0]) / (1024**2)
self.run_job_process.join()
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def _create_sockets(self): def _create_sockets(self):
"""Create three sockets for each job. """Create five sockets for each job in main process.
(1) reply_socket(main socket): receives the command(i.e, the function name and args) (1) job_socket(functional socket): sends job_address and heartbeat_address to worker.
from the actual class instance, completes the computation, and returns the result of (2) ping_heartbeat_socket: replies ping message of client.
the function. (3) worker_heartbeat_socket: replies heartbeat message of worker.
(2) job_socket(functional socket): sends job_address and heartbeat_address to worker. (4) client_heartbeat_socket: replies heartbeat message of client.
(3) kill_job_socket: sends a command to the corresponding worker to kill the job. (5) kill_job_socket: sends a command to the corresponding worker to kill the job.
""" """
# wait for another process to create reply socket
self.job_address = self.job_address_receiver.recv()
self.ctx = zmq.Context() self.ctx = zmq.Context()
# create the reply_socket
self.reply_socket = self.ctx.socket(zmq.REP)
job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
self.reply_socket.linger = 0
self.job_ip = get_ip_address()
self.job_address = "{}:{}".format(self.job_ip, job_port)
# create the job_socket # create the job_socket
self.job_socket = self.ctx.socket(zmq.REQ) self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address)) self.job_socket.connect("tcp://{}".format(self.worker_address))
...@@ -93,7 +108,6 @@ class Job(object): ...@@ -93,7 +108,6 @@ class Job(object):
target=self._reply_ping, args=(ping_heartbeat_socket, )) target=self._reply_ping, args=(ping_heartbeat_socket, ))
ping_thread.setDaemon(True) ping_thread.setDaemon(True)
ping_thread.start() ping_thread.start()
self.ping_heartbeat_address = ping_heartbeat_address
# a thread that reply heartbeat signals from the worker # a thread that reply heartbeat signals from the worker
worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server( worker_heartbeat_socket, worker_heartbeat_address = self._create_heartbeat_server(
...@@ -114,8 +128,7 @@ class Job(object): ...@@ -114,8 +128,7 @@ class Job(object):
# sends job information to the worker # sends job information to the worker
initialized_job = InitializedJob( initialized_job = InitializedJob(
self.job_address, worker_heartbeat_address, self.job_address, worker_heartbeat_address,
client_heartbeat_address, self.ping_heartbeat_address, None, client_heartbeat_address, ping_heartbeat_address, None, self.pid)
self.pid)
self.job_socket.send_multipart( self.job_socket.send_multipart(
[remote_constants.NORMAL_TAG, [remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)]) cloudpickle.dumps(initialized_job)])
...@@ -145,9 +158,12 @@ class Job(object): ...@@ -145,9 +158,12 @@ class Job(object):
"""Create a socket server that reply the ping signal from client. """Create a socket server that reply the ping signal from client.
This signal is used to make sure that the job is still alive. This signal is used to make sure that the job is still alive.
""" """
while self.job_is_alive: message = socket.recv_multipart()
message = socket.recv_multipart() max_memory = to_str(message[1])
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) if max_memory != 'None':
self.max_memory = float(max_memory)
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
self.client_thread.start()
socket.close(0) socket.close(0)
def _create_heartbeat_server(self, timeout=True): def _create_heartbeat_server(self, timeout=True):
...@@ -166,8 +182,7 @@ class Job(object): ...@@ -166,8 +182,7 @@ class Job(object):
"""Create a socket that replies heartbeat signals from the client. """Create a socket that replies heartbeat signals from the client.
If the job loses connection with the client, it will exit too. If the job loses connection with the client, it will exit too.
""" """
self.client_is_alive = True while True:
while self.client_is_alive and self.job_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
stop_job = self._check_used_memory() stop_job = self._check_used_memory()
...@@ -187,7 +202,7 @@ class Job(object): ...@@ -187,7 +202,7 @@ class Job(object):
logger.warning( logger.warning(
"[Job] Cannot connect to the client. This job will exit and inform the worker." "[Job] Cannot connect to the client. This job will exit and inform the worker."
) )
self.client_is_alive = False break
socket.close(0) socket.close(0)
with self.lock: with self.lock:
self.kill_job_socket.send_multipart( self.kill_job_socket.send_multipart(
...@@ -204,73 +219,77 @@ class Job(object): ...@@ -204,73 +219,77 @@ class Job(object):
"""create a socket that replies heartbeat signals from the worker. """create a socket that replies heartbeat signals from the worker.
If the worker has exited, the job will exit automatically. If the worker has exited, the job will exit automatically.
""" """
while True:
self.worker_is_alive = True
# a flag to decide when to exit heartbeat loop
while self.worker_is_alive and self.job_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e: except zmq.error.Again as e:
logger.warning("[Job] Cannot connect to the worker{}. ".format( logger.warning("[Job] Cannot connect to the worker{}. ".format(
self.worker_address) + "Job will quit.") self.worker_address) + "Job will quit.")
self.worker_is_alive = False break
self.job_is_alive = False
socket.close(0) socket.close(0)
os._exit(1) os._exit(1)
def wait_for_files(self): def wait_for_files(self, reply_socket, job_address):
"""Wait for python files from remote object. """Wait for python files from remote object.
When a remote object receives the allocated job address, it will send When a remote object receives the allocated job address, it will send
the python files to the job. Later, the job will save these files to a the python files to the job. Later, the job will save these files to a
temporary directory and add the temporary directory to Python's working temporary directory and add the temporary directory to Python's working
directory. directory.
Args:
reply_socket (socket): main socket to accept commands of remote object.
job_address (String): address of reply_socket.
Returns: Returns:
A temporary directory containing the python files. A temporary directory containing the python files.
""" """
while True: message = reply_socket.recv_multipart()
message = self.reply_socket.recv_multipart() tag = message[0]
tag = message[0] if tag == remote_constants.SEND_FILE_TAG:
if tag == remote_constants.SEND_FILE_TAG: pyfiles = pickle.loads(message[1])
pyfiles = pickle.loads(message[1]) # save python files to temporary directory
envdir = tempfile.mkdtemp() envdir = tempfile.mkdtemp()
for file in pyfiles: for file, code in pyfiles['python_files'].items():
code = pyfiles[file] file = os.path.join(envdir, file)
with open(file, 'wb') as code_file:
# create directory (i.e. ./rom_files/) code_file.write(code)
if '/' in file:
try: # save other files to current directory
os.makedirs( for file, content in pyfiles['other_files'].items():
os.path.join(envdir, # create directory (i.e. ./rom_files/)
*file.rsplit('/')[:-1])) if '/' in file:
except OSError as e: try:
pass os.makedirs(os.path.join(*file.rsplit('/')[:-1]))
except OSError as e:
file = os.path.join(envdir, file) pass
with open(file, 'wb') as code_file: with open(file, 'wb') as f:
code_file.write(code) f.write(content)
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) logger.info('[job] reply')
return envdir reply_socket.send_multipart([remote_constants.NORMAL_TAG])
else: return envdir
logger.error("NotImplementedError:{}, received tag:{}".format( else:
self.job_address, )) logger.error("NotImplementedError:{}, received tag:{}".format(
raise NotImplementedError job_address, ))
raise NotImplementedError
def wait_for_connection(self): def wait_for_connection(self, reply_socket):
"""Wait for connection from the remote object. """Wait for connection from the remote object.
The remote object will send its class information and initialization The remote object will send its class information and initialization
arguments to the job, these parameters are then used to create a arguments to the job, these parameters are then used to create a
local instance in the job process. local instance in the job process.
Args:
reply_socket (socket): main socket to accept commands of remote object.
Returns: Returns:
A local instance of the remote class object. A local instance of the remote class object.
""" """
message = self.reply_socket.recv_multipart() message = reply_socket.recv_multipart()
tag = message[0] tag = message[0]
obj = None obj = None
...@@ -278,24 +297,20 @@ class Job(object): ...@@ -278,24 +297,20 @@ class Job(object):
try: try:
cls = cloudpickle.loads(message[1]) cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2]) args, kwargs = cloudpickle.loads(message[2])
max_memory = to_str(message[3])
if max_memory != 'None':
self.max_memory = float(max_memory)
obj = cls(*args, **kwargs) obj = cls(*args, **kwargs)
except Exception as e: except Exception as e:
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
error_str = str(e) error_str = str(e)
logger.error("traceback:\n{}".format(traceback_str)) logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG, remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" + traceback_str) to_byte(error_str + "\ntraceback:\n" + traceback_str)
]) ])
self.client_is_alive = False
return None return None
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) reply_socket.send_multipart([remote_constants.NORMAL_TAG])
else: else:
logger.error("Message from job {}".format(message)) logger.error("Message from job {}".format(message))
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG, remote_constants.EXCEPTION_TAG,
b"[job]Unkonwn tag when tried to receive the class definition" b"[job]Unkonwn tag when tried to receive the class definition"
]) ])
...@@ -303,36 +318,39 @@ class Job(object): ...@@ -303,36 +318,39 @@ class Job(object):
return obj return obj
def run(self): def run(self, job_address_sender):
"""An infinite loop waiting for a new task. """An infinite loop waiting for a new task.
Args:
job_address_sender(sending end of multiprocessing.Pipe): send job address of reply_socket to main process.
""" """
# receive source code from the actor and append them to the environment variables. ctx = zmq.Context()
envdir = self.wait_for_files()
sys.path.append(envdir) # create the reply_socket
self.client_is_alive = True reply_socket = ctx.socket(zmq.REP)
self.client_thread.start() job_port = reply_socket.bind_to_random_port(addr="tcp://*")
reply_socket.linger = 0
job_ip = get_ip_address()
job_address = "{}:{}".format(job_ip, job_port)
job_address_sender.send(job_address)
try: try:
obj = self.wait_for_connection() # receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files(reply_socket, job_address)
sys.path.append(envdir)
obj = self.wait_for_connection(reply_socket)
assert obj is not None assert obj is not None
self.single_task(obj) self.single_task(obj, reply_socket, job_address)
except Exception as e: except Exception as e:
logger.error( logger.error(
"Error occurs when running a single task. We will reset this job. Reason:{}" "Error occurs when running a single task. We will reset this job. Reason:{}"
.format(e)) .format(e))
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str)) logger.error("traceback:\n{}".format(traceback_str))
with self.lock:
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def single_task(self, obj): def single_task(self, obj, reply_socket, job_address):
"""An infinite loop waiting for commands from the remote object. """An infinite loop waiting for commands from the remote object.
Each job will receive two kinds of message from the remote object: Each job will receive two kinds of message from the remote object:
...@@ -342,10 +360,14 @@ class Job(object): ...@@ -342,10 +360,14 @@ class Job(object):
remote object. remote object.
2. When the remote object is deleted, the job will quit and release 2. When the remote object is deleted, the job will quit and release
related computation resources. related computation resources.
Args:
reply_socket (socket): main socket to accept commands of remote object.
job_address (String): address of reply_socket.
""" """
while self.job_is_alive and self.client_is_alive: while True:
message = self.reply_socket.recv_multipart() message = reply_socket.recv_multipart()
tag = message[0] tag = message[0]
...@@ -357,32 +379,31 @@ class Job(object): ...@@ -357,32 +379,31 @@ class Job(object):
ret = getattr(obj, function_name)(*args, **kwargs) ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret) ret = dumps_return(ret)
self.reply_socket.send_multipart( reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret]) [remote_constants.NORMAL_TAG, ret])
except Exception as e: except Exception as e:
# reset the job # reset the job
self.client_is_alive = False
error_str = str(e) error_str = str(e)
logger.error(error_str) logger.error(error_str)
if type(e) == AttributeError: if type(e) == AttributeError:
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.ATTRIBUTE_EXCEPTION_TAG, remote_constants.ATTRIBUTE_EXCEPTION_TAG,
to_byte(error_str) to_byte(error_str)
]) ])
raise AttributeError raise AttributeError
elif type(e) == SerializeError: elif type(e) == SerializeError:
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.SERIALIZE_EXCEPTION_TAG, remote_constants.SERIALIZE_EXCEPTION_TAG,
to_byte(error_str) to_byte(error_str)
]) ])
raise SerializeError raise SerializeError
elif type(e) == DeserializeError: elif type(e) == DeserializeError:
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.DESERIALIZE_EXCEPTION_TAG, remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str) to_byte(error_str)
]) ])
...@@ -391,7 +412,7 @@ class Job(object): ...@@ -391,7 +412,7 @@ class Job(object):
else: else:
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str)) logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([ reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG, remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" + to_byte(error_str + "\ntraceback:\n" +
traceback_str) traceback_str)
...@@ -400,11 +421,9 @@ class Job(object): ...@@ -400,11 +421,9 @@ class Job(object):
# receive DELETE_TAG from actor, and stop replying worker heartbeat # receive DELETE_TAG from actor, and stop replying worker heartbeat
elif tag == remote_constants.KILLJOB_TAG: elif tag == remote_constants.KILLJOB_TAG:
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) reply_socket.send_multipart([remote_constants.NORMAL_TAG])
self.client_is_alive = False logger.warning("An actor exits and this job {} will exit.".
logger.warning( format(job_address))
"An actor exits and this job {} will exit.".format(
self.job_address))
break break
else: else:
logger.error( logger.error(
...@@ -418,4 +437,3 @@ if __name__ == "__main__": ...@@ -418,4 +437,3 @@ if __name__ == "__main__":
"--worker_address", required=True, type=str, help="worker_address") "--worker_address", required=True, type=str, help="worker_address")
args = parser.parse_args() args = parser.parse_args()
job = Job(args.worker_address) job = Job(args.worker_address)
job.run()
...@@ -92,7 +92,8 @@ def remote_class(*args, **kwargs): ...@@ -92,7 +92,8 @@ def remote_class(*args, **kwargs):
# GLOBAL_CLIENT will set `master_is_alive` to False when heartbeat # GLOBAL_CLIENT will set `master_is_alive` to False when heartbeat
# finds the master is dead. # finds the master is dead.
if self.GLOBAL_CLIENT.master_is_alive: if self.GLOBAL_CLIENT.master_is_alive:
job_address = self.request_cpu_resource(self.GLOBAL_CLIENT) job_address = self.request_cpu_resource(
self.GLOBAL_CLIENT, max_memory)
else: else:
raise Exception("Can not submit job to the master. " raise Exception("Can not submit job to the master. "
"Please check if master is still alive.") "Please check if master is still alive.")
...@@ -117,7 +118,6 @@ def remote_class(*args, **kwargs): ...@@ -117,7 +118,6 @@ def remote_class(*args, **kwargs):
remote_constants.INIT_OBJECT_TAG, remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls), cloudpickle.dumps(cls),
cloudpickle.dumps([args, kwargs]), cloudpickle.dumps([args, kwargs]),
to_byte(str(max_memory))
]) ])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
tag = message[0] tag = message[0]
...@@ -149,11 +149,11 @@ def remote_class(*args, **kwargs): ...@@ -149,11 +149,11 @@ def remote_class(*args, **kwargs):
except zmq.error.Again as e: except zmq.error.Again as e:
logger.error("Send python files failed.") logger.error("Send python files failed.")
def request_cpu_resource(self, global_client): def request_cpu_resource(self, global_client, max_memory):
"""Try to request cpu resource for 1 second/time for 300 times.""" """Try to request cpu resource for 1 second/time for 300 times."""
cnt = 300 cnt = 300
while cnt > 0: while cnt > 0:
job_address = global_client.submit_job() job_address = global_client.submit_job(max_memory)
if job_address is not None: if job_address is not None:
return job_address return job_address
if cnt % 30 == 0: if cnt % 30 == 0:
......
...@@ -86,8 +86,8 @@ def cli(): ...@@ -86,8 +86,8 @@ def cli():
@click.option("--port", help="The port to bind to.", type=str, required=True) @click.option("--port", help="The port to bind to.", type=str, required=True)
@click.option( @click.option(
"--debug", "--debug",
help="Start parl in debug mode to show all logs.", help="Start parl in the debugging mode to print all running log.",
default=False) is_flag=True)
@click.option( @click.option(
"--cpu_num", "--cpu_num",
type=int, type=int,
......
...@@ -56,10 +56,11 @@ class Worker(object): ...@@ -56,10 +56,11 @@ class Worker(object):
reply_job_socket (zmq.Context.socket): A socket which receives reply_job_socket (zmq.Context.socket): A socket which receives
job_address from the job. job_address from the job.
kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs. kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs.
job_buffer (str): A buffer that stores initialized jobs for providing new jobs in a short time.
Args: Args:
master_address (str): IP address of the master node. master_address (str): IP address of the master node.
cpu_num (int): Number of cpu to be used on the worker. cpu_num (int): Number of cpu to be used on the worker.
job_buffer (str): A buffer that stores initialized jobs for providing new jobs in a short time.
""" """
def __init__(self, master_address, cpu_num=None): def __init__(self, master_address, cpu_num=None):
...@@ -170,9 +171,13 @@ class Worker(object): ...@@ -170,9 +171,13 @@ class Worker(object):
"""An endless loop that adds initialized job into the job buffer""" """An endless loop that adds initialized job into the job buffer"""
while self.worker_is_alive: while self.worker_is_alive:
if self.job_buffer.full() is False: if self.job_buffer.full() is False:
initialized_jobs = self._init_jobs(job_num=self.cpu_num) job_num = self.cpu_num - self.job_buffer.qsize()
for job in initialized_jobs: if job_num > 0:
self.job_buffer.put(job) initialized_jobs = self._init_jobs(job_num=job_num)
for job in initialized_jobs:
self.job_buffer.put(job)
time.sleep(0.02)
# release jobs if the worker is not alive # release jobs if the worker is not alive
for job in initialized_jobs: for job in initialized_jobs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册