From 47138c06cf473e8e108aa01ff1bf421c37eeff55 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 26 May 2021 13:37:05 +0800 Subject: [PATCH] perf(dist): add fastpath for bcast params GitOrigin-RevId: aa40b3cd72e7665402737605ddf6f59ff8de0c0f --- imperative/python/megengine/distributed/functional.py | 7 +++++++ imperative/python/megengine/distributed/helper.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 453d3ba55..d832ae86c 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -196,6 +196,13 @@ def broadcast( return out +def _bcast_param( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + mode = CollectiveComm.Mode.BROADCAST + return collective_comm(inp, mode, group, device) + + def all_gather( inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" ) -> Tensor: diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 0a67f2dd7..c0743958a 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit from ..functional.tensor import copy from ..tensor import Tensor from ..utils.future import Future -from .functional import all_reduce_sum, broadcast +from .functional import _bcast_param, all_reduce_sum, broadcast from .group import WORLD, Group, group_barrier, is_distributed @@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD): :param group: communication group. """ for inp in inps: - inp._reset(broadcast(inp, group)) + inp._reset(_bcast_param(inp, group)) class AllreduceCallback: -- GitLab