提交 47138c06 编写于 作者: M Megvii Engine Team

perf(dist): add fastpath for bcast params

GitOrigin-RevId: aa40b3cd72e7665402737605ddf6f59ff8de0c0f
上级 1a21dfdf
......@@ -196,6 +196,13 @@ def broadcast(
return out
def _bcast_param(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)
def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
......
......@@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed
......@@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD):
:param group: communication group.
"""
for inp in inps:
inp._reset(broadcast(inp, group))
inp._reset(_bcast_param(inp, group))
class AllreduceCallback:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册