From c2c947c377bd3f495f0bda2bf85305ec38ba2acb Mon Sep 17 00:00:00 2001
From: Yuecheng Liu <52879090+liuyuecheng-github@users.noreply.github.com>
Date: Tue, 18 Aug 2020 10:35:34 +0800
Subject: [PATCH] check parl and python version (#389)

* check parl and python version

* yapf

* yapf.

* add more comments

* more comments
---
 parl/remote/client.py           | 36 ++++++++++++++++++++++++++++++++++++
 parl/remote/master.py           | 10 ++++++++++
 parl/remote/remote_constants.py |  1 +
 parl/remote/worker.py           | 37 ++++++++++++++++++++++++++++++++++++-
 4 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/parl/remote/client.py b/parl/remote/client.py
index 9d8f30b..344debe 100644
--- a/parl/remote/client.py
+++ b/parl/remote/client.py
@@ -19,6 +19,7 @@ import socket
 import sys
 import threading
 import zmq
+import parl
 from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
 from parl.remote import remote_constants
 import time
@@ -66,6 +67,7 @@ class Client(object):
         self.actor_num = 0
 
         self._create_sockets(master_address)
+        self.check_version()
         self.pyfiles = self.read_local_files(distributed_files)
 
     def get_executable_path(self):
@@ -179,6 +181,40 @@ class Client(object):
                             "check if master is started and ensure the input "
                             "address {} is correct.".format(master_address))
 
+    def check_version(self):
+        """Check the client's parl/python versions against the master's.
+
+        Sends `CHECK_VERSION_TAG` over `submit_job_socket` and compares the
+        parl version and python major version in the reply with the local
+        ones.
+
+        Raises:
+            AssertionError: if either version differs from the master's.
+            NotImplementedError: if the reply tag is not `NORMAL_TAG`.
+        """
+        self.submit_job_socket.send_multipart(
+            [remote_constants.CHECK_VERSION_TAG])
+        message = self.submit_job_socket.recv_multipart()
+        tag = message[0]
+        if tag != remote_constants.NORMAL_TAG:
+            raise NotImplementedError(
+                "Unexpected reply tag: {!r}".format(tag))
+
+        master_parl_version = to_str(message[1])
+        master_python_version = to_str(message[2])
+        client_parl_version = parl.__version__
+        client_python_version = str(sys.version_info.major)
+        if (client_parl_version != master_parl_version
+                or client_python_version != master_python_version):
+            # Raised explicitly instead of via an `assert` statement so the
+            # check is not stripped when python runs with `-O`.
+            raise AssertionError(
+                "Version mismatch: the 'master' is of version "
+                "'parl={}, python={}'. However, 'parl={}, python={}' "
+                "is provided in your environment.".format(
+                    master_parl_version, master_python_version,
+                    client_parl_version, client_python_version))
+
     def _reply_heartbeat(self):
         """Reply heartbeat signals to the master node."""
 
diff --git a/parl/remote/master.py b/parl/remote/master.py
index 8cca029..7964c56 100644
--- a/parl/remote/master.py
+++ b/parl/remote/master.py
@@ -18,6 +18,8 @@ import threading
 import time
 import zmq
 from collections import deque, defaultdict
+import parl
+import sys
 from parl.utils import to_str, to_byte, logger, get_ip_address
 from parl.remote import remote_constants
 from parl.remote.job_center import JobCenter
@@ -208,6 +210,7 @@ class Master(object):
         elif tag == remote_constants.CLIENT_CONNECT_TAG:
             # `client_heartbeat_address` is the
             # `reply_master_heartbeat_address` of the client
+
             client_heartbeat_address = to_str(message[1])
             client_hostname = to_str(message[2])
             client_id = to_str(message[3])
@@ -225,6 +228,13 @@ class Master(object):
                 [remote_constants.NORMAL_TAG,
                  to_byte(log_monitor_address)])
 
+        elif tag == remote_constants.CHECK_VERSION_TAG:
+            self.client_socket.send_multipart([
+                remote_constants.NORMAL_TAG,
+                to_byte(parl.__version__),
+                to_byte(str(sys.version_info.major))
+            ])
+
         # a client submits a job to the master
         elif tag == remote_constants.CLIENT_SUBMIT_TAG:
             # check available CPU resources
diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py
index db7a86e..7ce2ae1 100644
--- a/parl/remote/remote_constants.py
+++ b/parl/remote/remote_constants.py
@@ -27,6 +27,7 @@ SEND_FILE_TAG = b'[SEND_FILE]'
 SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
 NEW_JOB_TAG = b'[NEW_JOB]'
 
+CHECK_VERSION_TAG = b'[CHECK_VERSION]'
 INIT_OBJECT_TAG = b'[INIT_OBJECT]'
 CALL_TAG = b'[CALL]'
 GET_ATTRIBUTE = b'[GET_ATTRIBUTE]'
diff --git a/parl/remote/worker.py b/parl/remote/worker.py
index d1a0333..9b534ee 100644
--- a/parl/remote/worker.py
+++ b/parl/remote/worker.py
@@ -26,7 +26,7 @@ import threading
 import warnings
 import zmq
 from datetime import datetime
-
+import parl
 from parl.utils import get_ip_address, to_byte, to_str, logger, _IS_WINDOWS, kill_process
 from parl.remote import remote_constants
 from parl.remote.message import InitializedWorker
@@ -75,6 +75,7 @@ class Worker(object):
         self._set_cpu_num(cpu_num)
         self.job_buffer = queue.Queue(maxsize=self.cpu_num)
         self._create_sockets()
+        self.check_version()
         # create log server
         self.log_server_proc, self.log_server_address = self._create_log_server(
             port=log_server_port)
@@ -101,6 +102,40 @@ class Worker(object):
         else:
             self.cpu_num = multiprocessing.cpu_count()
 
+    def check_version(self):
+        """Check the worker's parl/python versions against the master's.
+
+        Sends `CHECK_VERSION_TAG` over `request_master_socket` and compares
+        the parl version and python major version in the reply with the
+        local ones.
+
+        Raises:
+            AssertionError: if either version differs from the master's.
+            NotImplementedError: if the reply tag is not `NORMAL_TAG`.
+        """
+        self.request_master_socket.send_multipart(
+            [remote_constants.CHECK_VERSION_TAG])
+        message = self.request_master_socket.recv_multipart()
+        tag = message[0]
+        if tag != remote_constants.NORMAL_TAG:
+            raise NotImplementedError(
+                "Unexpected reply tag: {!r}".format(tag))
+
+        master_parl_version = to_str(message[1])
+        master_python_version = to_str(message[2])
+        worker_parl_version = parl.__version__
+        worker_python_version = str(sys.version_info.major)
+        if (worker_parl_version != master_parl_version
+                or worker_python_version != master_python_version):
+            # Raised explicitly instead of via an `assert` statement so the
+            # check is not stripped when python runs with `-O`.
+            raise AssertionError(
+                'Version mismatch: the "master" is of version '
+                '"parl={}, python={}". However, "parl={}, python={}" '
+                'is provided in your environment.'.format(
+                    master_parl_version, master_python_version,
+                    worker_parl_version, worker_python_version))
+
     def _create_sockets(self):
         """
         Each worker has three sockets at start:
-- 
GitLab