未验证 提交 c2c947c3 编写于 作者: Y Yuecheng Liu 提交者: GitHub

check parl and python version (#389)

* check parl and python version

* yapf

* yapf.

* add more comments

* more comments
上级 ec77a4ba
......@@ -19,6 +19,7 @@ import socket
import sys
import threading
import zmq
import parl
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote import remote_constants
import time
......@@ -66,6 +67,7 @@ class Client(object):
self.actor_num = 0
self._create_sockets(master_address)
self.check_version()
self.pyfiles = self.read_local_files(distributed_files)
def get_executable_path(self):
......@@ -179,6 +181,24 @@ class Client(object):
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))
def check_version(self):
'''Verify that the parl & python version in 'client' process matches that of the 'master' process'''
self.submit_job_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.submit_job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
client_parl_version = parl.__version__
client_python_version = str(sys.version_info.major)
assert client_parl_version == to_str(message[1]) and client_python_version == to_str(message[2]),\
'''Version mismatch: the 'master' is of version 'parl={}, python={}'. However,
'parl={}, python={}'is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
client_parl_version, client_python_version
)
else:
raise NotImplementedError
def _reply_heartbeat(self):
"""Reply heartbeat signals to the master node."""
......
......@@ -18,6 +18,8 @@ import threading
import time
import zmq
from collections import deque, defaultdict
import parl
import sys
from parl.utils import to_str, to_byte, logger, get_ip_address
from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
......@@ -208,6 +210,7 @@ class Master(object):
elif tag == remote_constants.CLIENT_CONNECT_TAG:
# `client_heartbeat_address` is the
# `reply_master_heartbeat_address` of the client
client_heartbeat_address = to_str(message[1])
client_hostname = to_str(message[2])
client_id = to_str(message[3])
......@@ -225,6 +228,13 @@ class Master(object):
[remote_constants.NORMAL_TAG,
to_byte(log_monitor_address)])
elif tag == remote_constants.CHECK_VERSION_TAG:
self.client_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(parl.__version__),
to_byte(str(sys.version_info.major))
])
# a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG:
# check available CPU resources
......
......@@ -27,6 +27,7 @@ SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
NEW_JOB_TAG = b'[NEW_JOB]'
CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
GET_ATTRIBUTE = b'[GET_ATTRIBUTE]'
......
......@@ -26,7 +26,7 @@ import threading
import warnings
import zmq
from datetime import datetime
import parl
from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process
from parl.remote import remote_constants
from parl.remote.message import InitializedWorker
......@@ -75,6 +75,7 @@ class Worker(object):
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
self.check_version()
# create log server
self.log_server_proc, self.log_server_address = self._create_log_server(
port=log_server_port)
......@@ -101,6 +102,24 @@ class Worker(object):
else:
self.cpu_num = multiprocessing.cpu_count()
def check_version(self):
'''Verify that the parl & python version in 'worker' process matches that of the 'master' process'''
self.request_master_socket.send_multipart(
[remote_constants.CHECK_VERSION_TAG])
message = self.request_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
worker_parl_version = parl.__version__
worker_python_version = str(sys.version_info.major)
assert worker_parl_version == to_str(message[1]) and worker_python_version == to_str(message[2]),\
'''Version mismatch: the "master" is of version "parl={}, python={}". However,
"parl={}, python={}"is provided in your environment.'''.format(
to_str(message[1]), to_str(message[2]),
worker_parl_version, worker_python_version
)
else:
raise NotImplementedError
def _create_sockets(self):
""" Each worker has three sockets at start:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册