未验证 提交 c2c947c3 编写于 作者: Y Yuecheng Liu 提交者: GitHub

check parl and python version (#389)

* check parl and python version

* yapf

* yapf.

* add more comments

* more comments
上级 ec77a4ba
...@@ -19,6 +19,7 @@ import socket ...@@ -19,6 +19,7 @@ import socket
import sys import sys
import threading import threading
import zmq import zmq
import parl
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants from parl.remote import remote_constants
import time import time
...@@ -66,6 +67,7 @@ class Client(object): ...@@ -66,6 +67,7 @@ class Client(object):
self.actor_num = 0 self.actor_num = 0
self._create_sockets(master_address) self._create_sockets(master_address)
self.check_version()
self.pyfiles = self.read_local_files(distributed_files) self.pyfiles = self.read_local_files(distributed_files)
def get_executable_path(self): def get_executable_path(self):
...@@ -179,6 +181,24 @@ class Client(object): ...@@ -179,6 +181,24 @@ class Client(object):
"check if master is started and ensure the input " "check if master is started and ensure the input "
"address {} is correct.".format(master_address)) "address {} is correct.".format(master_address))
def check_version(self):
'''Verify that the parl & python version in 'client' process matches that of the 'master' process'''
self.submit_job_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.submit_job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
client_parl_version = parl.__version__
client_python_version = str(sys.version_info.major)
assert client_parl_version == to_str(message[1]) and client_python_version == to_str(message[2]),\
'''Version mismatch: the 'master' is of version 'parl={}, python={}'. However,
'parl={}, python={}'is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
client_parl_version, client_python_version
)
else:
raise NotImplementedError
def _reply_heartbeat(self): def _reply_heartbeat(self):
"""Reply heartbeat signals to the master node.""" """Reply heartbeat signals to the master node."""
......
...@@ -18,6 +18,8 @@ import threading ...@@ -18,6 +18,8 @@ import threading
import time import time
import zmq import zmq
from collections import deque, defaultdict from collections import deque, defaultdict
import parl
import sys
from parl.utils import to_str, to_byte, logger, get_ip_address from parl.utils import to_str, to_byte, logger, get_ip_address
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.remote.job_center import JobCenter from parl.remote.job_center import JobCenter
...@@ -208,6 +210,7 @@ class Master(object): ...@@ -208,6 +210,7 @@ class Master(object):
elif tag == remote_constants.CLIENT_CONNECT_TAG: elif tag == remote_constants.CLIENT_CONNECT_TAG:
# `client_heartbeat_address` is the # `client_heartbeat_address` is the
# `reply_master_heartbeat_address` of the client # `reply_master_heartbeat_address` of the client
client_heartbeat_address = to_str(message[1]) client_heartbeat_address = to_str(message[1])
client_hostname = to_str(message[2]) client_hostname = to_str(message[2])
client_id = to_str(message[3]) client_id = to_str(message[3])
...@@ -225,6 +228,13 @@ class Master(object): ...@@ -225,6 +228,13 @@ class Master(object):
[remote_constants.NORMAL_TAG, [remote_constants.NORMAL_TAG,
to_byte(log_monitor_address)]) to_byte(log_monitor_address)])
elif tag == remote_constants.CHECK_VERSION_TAG:
self.client_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(parl.__version__),
to_byte(str(sys.version_info.major))
])
# a client submits a job to the master # a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG: elif tag == remote_constants.CLIENT_SUBMIT_TAG:
# check available CPU resources # check available CPU resources
......
...@@ -27,6 +27,7 @@ SEND_FILE_TAG = b'[SEND_FILE]' ...@@ -27,6 +27,7 @@ SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG = b'[SUBMIT_JOB]' SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
NEW_JOB_TAG = b'[NEW_JOB]' NEW_JOB_TAG = b'[NEW_JOB]'
CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]' INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]' CALL_TAG = b'[CALL]'
GET_ATTRIBUTE = b'[GET_ATTRIBUTE]' GET_ATTRIBUTE = b'[GET_ATTRIBUTE]'
......
...@@ -26,7 +26,7 @@ import threading ...@@ -26,7 +26,7 @@ import threading
import warnings import warnings
import zmq import zmq
from datetime import datetime from datetime import datetime
import parl
from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.remote.message import InitializedWorker from parl.remote.message import InitializedWorker
...@@ -75,6 +75,7 @@ class Worker(object): ...@@ -75,6 +75,7 @@ class Worker(object):
self._set_cpu_num(cpu_num) self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num) self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets() self._create_sockets()
self.check_version()
# create log server # create log server
self.log_server_proc, self.log_server_address = self._create_log_server( self.log_server_proc, self.log_server_address = self._create_log_server(
port=log_server_port) port=log_server_port)
...@@ -101,6 +102,24 @@ class Worker(object): ...@@ -101,6 +102,24 @@ class Worker(object):
else: else:
self.cpu_num = multiprocessing.cpu_count() self.cpu_num = multiprocessing.cpu_count()
def check_version(self):
'''Verify that the parl & python version in 'worker' process matches that of the 'master' process'''
self.request_master_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.request_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
worker_parl_version = parl.__version__
worker_python_version = str(sys.version_info.major)
assert worker_parl_version == to_str(message[1]) and worker_python_version == to_str(message[2]),\
'''Version mismatch: the "master" is of version "parl={}, python={}". However,
"parl={}, python={}"is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
worker_parl_version, worker_python_version
)
else:
raise NotImplementedError
def _create_sockets(self): def _create_sockets(self):
""" Each worker has three sockets at start: """ Each worker has three sockets at start:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册