From b9918c329d4b07b7c907d9e6f86aa47c974a5c4e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 18 Nov 2020 14:13:40 +0800 Subject: [PATCH] feat(mge/distributed): support distributed key-value store GitOrigin-RevId: b4abe8001459020a1b371a188eba856830ce86df --- .../python/megengine/distributed/server.py | 22 +++++++++++++ .../test/unit/distributed/test_distributed.py | 32 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index e7875851b..c9ab31774 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -35,6 +35,7 @@ class Methods: self.dict_pack_list = defaultdict(partial(Future, False)) self.dict_barrier_counter = defaultdict(int) self.dict_barrier_event = defaultdict(threading.Event) + self.user_dict = defaultdict(partial(Future, False)) def connect(self): """Method for checking connection success.""" @@ -113,6 +114,19 @@ class Methods: event.wait() return True + def user_set(self, key, val): + """Set user defined key-value pairs across processes.""" + with self.lock: + future = self.user_dict[key] + future.set(val) + return True + + def user_get(self, key): + """Get user defined key-value pairs across processes.""" + with self.lock: + future = self.user_dict[key] + return future.get() + class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass @@ -220,3 +234,11 @@ class Client: :param size: group size. """ self.proxy.group_barrier(key, size) + + def user_set(self, key, val): + """Set user defined key-value pairs across processes.""" + self.proxy.user_set(key, val) + + def user_get(self, key): + """Get user defined key-value pairs across processes.""" + return self.proxy.user_get(key) diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index 26c732b8d..28c7fb635 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -195,6 +195,38 @@ def test_synchronized(): assert p.exitcode == 0 +@pytest.mark.skipif( + platform.system() == "Darwin", reason="do not imp GPU mode at macos now" +) +@pytest.mark.skipif( + platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" +) +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") +@pytest.mark.isolated_distributed +def test_user_set_get(): + world_size = 2 + port = dist.get_free_ports(1)[0] + server = dist.Server(port) + + def worker(rank): + dist.init_process_group("localhost", port, world_size, rank, rank) + # set in race condition + dist.get_client().user_set("foo", 1) + # get in race condition + ret = dist.get_client().user_get("foo") + assert ret == 1 + + procs = [] + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank,)) + p.start() + procs.append(p) + + for p in procs: + p.join(20) + assert p.exitcode == 0 + + def test_oprmm_hashable(): lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) -- GitLab