提交 b9918c32 编写于 作者: M Megvii Engine Team

feat(mge/distributed): support distributed key-value store

GitOrigin-RevId: b4abe8001459020a1b371a188eba856830ce86df
上级 ab82c8da
......@@ -35,6 +35,7 @@ class Methods:
self.dict_pack_list = defaultdict(partial(Future, False))
self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False))
def connect(self):
"""Method for checking connection success."""
......@@ -113,6 +114,19 @@ class Methods:
event.wait()
return True
def user_set(self, key, val):
    """Store a user-defined value under ``key`` for other processes to read.

    :param key: identifier of the entry in the user key-value store.
    :param val: value handed to the entry's future via ``set``.
    :return: always ``True``.
    """
    # The lock guards only the dictionary lookup; ``user_dict`` is a
    # defaultdict, so a missing key gets a fresh future created here.
    with self.lock:
        slot = self.user_dict[key]
    # Publish the value after releasing the lock.
    slot.set(val)
    return True
def user_get(self, key):
    """Fetch the user-defined value stored under ``key``.

    :param key: identifier of the entry in the user key-value store.
    :return: whatever the entry's future yields from ``get``.
    """
    # Only the dictionary lookup happens under the lock; a missing key
    # gets a fresh future from the defaultdict factory.
    with self.lock:
        slot = self.user_dict[key]
    # NOTE(review): ``get`` presumably waits until a value has been set --
    # confirm against Future. It is called outside the lock so a waiting
    # reader never prevents a writer from reaching the same future.
    return slot.get()
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    """XML-RPC server combined with ``ThreadingMixIn``.

    Adds no behaviour of its own; it only composes the two base classes.
    """

    pass
......@@ -220,3 +234,11 @@ class Client:
:param size: group size.
"""
self.proxy.group_barrier(key, size)
def user_set(self, key, val):
    """Set a user-defined key-value pair through the server proxy.

    :param key: identifier of the entry.
    :param val: value to store under ``key``.
    """
    # The proxy's reply is intentionally not returned to the caller.
    proxy = self.proxy
    proxy.user_set(key, val)
def user_get(self, key):
    """Get a user-defined value through the server proxy.

    :param key: identifier of the entry.
    :return: the reply of the proxy's ``user_get`` call for ``key``.
    """
    value = self.proxy.user_get(key)
    return value
......@@ -195,6 +195,38 @@ def test_synchronized():
assert p.exitcode == 0
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_user_set_get():
    """Check that every rank can set and then read back the same user key."""
    world_size = 2
    port = dist.get_free_ports(1)[0]
    # The server object is bound to a name for the duration of the test.
    server = dist.Server(port)

    def worker(rank):
        # Each process joins the group using its rank as the device index too.
        dist.init_process_group("localhost", port, world_size, rank, rank)
        # All ranks write the same key without any ordering between them.
        dist.get_client().user_set("foo", 1)
        # All ranks read the key back, again without ordering guarantees.
        value = dist.get_client().user_get("foo")
        assert value == 1

    workers = []
    for rank in range(world_size):
        proc = mp.Process(target=worker, args=(rank,))
        proc.start()
        workers.append(proc)

    for proc in workers:
        # Wait at most 20 seconds per worker; a process that is still
        # running has exitcode None and fails the assertion below.
        proc.join(20)
        assert proc.exitcode == 0
def test_oprmm_hashable():
lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册