diff --git a/parl/remote/job.py b/parl/remote/job.py index aa677ebfb9e18842f60491e24929e626fe2730a6..8e8587afe9c228eecc1f9f20aed59ad3b060eb09 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -16,8 +16,8 @@ import compatible_trick import os -os.environ['CUDA_VISIBLE_DEVICES'] = '' os.environ['XPARL'] = 'True' +os.environ['CUDA_VISIBLE_DEVICES'] = '' import argparse import cloudpickle import pickle @@ -66,21 +66,19 @@ class Job(object): self.log_server_address = log_server_address self.job_ip = get_ip_address() self.pid = os.getpid() - - self.run_job_process = Process( - target=self.run, args=(job_address_sender, job_id_sender)) - self.run_job_process.start() """ NOTE: In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process. """ self.lock = threading.Lock() - self._create_sockets() + th = threading.Thread(target=self._create_sockets) + th.setDaemon(True) + th.start() process = psutil.Process(self.pid) self.init_memory = float(process.memory_info()[0]) / (1024**2) - self.run_job_process.join() + self.run(job_address_sender, job_id_sender) with self.lock: self.kill_job_socket.send_multipart( diff --git a/parl/remote/tests/cluster_test.py b/parl/remote/tests/cluster_test.py index 8eab666747fcb5b4ebf2e570ae4233fd42e429f2..e00351b8b6141f053db4b14a220e01cc50508d83 100644 --- a/parl/remote/tests/cluster_test.py +++ b/parl/remote/tests/cluster_test.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest - import parl from parl.remote.master import Master from parl.remote.worker import Worker diff --git a/parl/remote/tests/paddle_gpu_test.py b/parl/remote/tests/paddle_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d433ab089d93653a152c1a22350922636b7df1 --- /dev/null +++ b/parl/remote/tests/paddle_gpu_test.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import parl +from parl.remote.master import Master +from parl.remote.worker import Worker +import time +import threading +from parl.remote.client import disconnect +from parl.remote import exceptions +import subprocess +from parl.utils import logger +import paddle.fluid as fluid +import os + + +@parl.remote_class +class Actor(object): + def __init__(self, cuda=False): + if cuda: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + +class TestCluster(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_gpu(self): + master = Master(port=8241) + th = threading.Thread(target=master.run) + th.start() + time.sleep(1) + + worker1 = Worker('localhost:8241', 4) + + parl.connect('localhost:8241') + + if parl.utils.is_gpu_available(): + actor = Actor(cuda=True) + else: + actor = Actor(cuda=False) + + del actor + master.exit() + worker1.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/utils.py b/parl/remote/utils.py index 5d6368e21d2f9d19a72f28eea0d214df8a913664..221de528a32bd035f0f5aa92f9de93e45c1e04a8 100644 --- a/parl/remote/utils.py +++ b/parl/remote/utils.py @@ -114,7 +114,7 @@ def locate_remote_file(module_path): module_path: Absolute path of the module. 
Example: - module_path: /home/user/dir/subdir/my_module + module_path: /home/user/dir/subdir/my_module (or) ./dir/main entry_file: /home/user/dir/main.py --------> relative_path: subdir/my_module """ @@ -129,6 +129,9 @@ def locate_remote_file(module_path): if os.path.isfile(to_check_path): entry_path = path break + # convert a relative module path to an absolute path + if not os.path.isabs(module_path): + module_path = os.path.abspath(module_path) if entry_path is None or \ (module_path.startswith(os.sep) and entry_path != module_path[:len(entry_path)]): raise FileNotFoundError("cannot locate the remote file")