提交 6d367454 编写于 作者: M Megvii Engine Team

feat(mge/opr-mm): add param local_grad for collective_comm opr

GitOrigin-RevId: cc120cfb55d67a48dc126d1fd8773fa08a860d32
上级 0ccb965c
......@@ -11,10 +11,13 @@ from .functional import (
all_reduce_max,
all_reduce_min,
all_reduce_sum,
all_to_all,
bcast_param,
broadcast,
gather,
reduce_scatter_sum,
reduce_sum,
scatter,
)
from .util import (
get_backend,
......
......@@ -9,7 +9,7 @@
from typing import Optional, Union
import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from megengine._internal.opr_param_defs import CollectiveComm as Param
from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
......@@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs):
return collective_comm_symvar(*args, **kargs)
def _group_check(*args):
"""Return True when arguments are all None or all not None
"""
l = [val is None for val in args]
return len(set(l)) <= 1
def reduce_sum(
tensor: Tensor,
key: str,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
) -> Tensor:
......@@ -35,14 +42,17 @@ def reduce_sum(
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param is_root: whether this is a root node
"""
assert _group_check(
key, nr_ranks, is_root
), "key, nr_ranks, is_root should be set at the same time"
return _collective_comm(
tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
)
def gather(
tensor: Tensor,
key: str,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
......@@ -55,20 +65,17 @@ def gather(
:param is_root: whether this is a root node
:param rank: rank of this node
"""
assert _group_check(
key, nr_ranks, is_root, rank
), "key, nr_ranks, is_root, rank should be set at the same time"
return _collective_comm(
tensor,
key,
CollParam.Mode.GATHER,
nr_ranks,
is_root,
rank,
device=tensor.device,
tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
)
def broadcast(
tensor: Tensor,
key: str,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
) -> Tensor:
......@@ -79,11 +86,12 @@ def broadcast(
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param is_root: whether this is a root node
"""
if key is None:
key = tensor._symvar.name
assert _group_check(
key, nr_ranks, is_root
), "key, nr_ranks, is_root should be set at the same time"
if is_root is None:
is_root = get_rank() == 0
if is_root:
inp = tensor
else:
......@@ -92,7 +100,7 @@ def broadcast(
return _collective_comm(
inp,
key,
CollParam.Mode.BROADCAST,
Param.Mode.BROADCAST,
nr_ranks,
is_root,
dtype=tensor.dtype,
......@@ -102,7 +110,7 @@ def broadcast(
def scatter(
tensor: Tensor,
key: str,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
......@@ -115,6 +123,9 @@ def scatter(
:param is_root: whether this is a root node
:param rank: rank of this node
"""
assert _group_check(
key, nr_ranks, is_root, rank
), "key, nr_ranks, is_root, rank should be set at the same time"
if key is None:
key = tensor._symvar.name
if is_root is None:
......@@ -128,7 +139,7 @@ def scatter(
return _collective_comm(
inp,
key,
CollParam.Mode.SCATTER,
Param.Mode.SCATTER,
nr_ranks,
is_root,
rank,
......@@ -138,7 +149,11 @@ def scatter(
def all_to_all(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create all_to_all operator for collective communication
......@@ -146,12 +161,22 @@ def all_to_all(
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of this node
:param local_grad: whether use local grad
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)
assert _group_check(
key, nr_ranks, rank
), "key, nr_ranks, rank should be set at the same time"
return _collective_comm(
tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
)
def all_gather(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create all_gather operator for collective communication
......@@ -159,12 +184,22 @@ def all_gather(
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of this node
:param local_grad: whether use local grad
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)
assert _group_check(
key, nr_ranks, rank
), "key, nr_ranks, rank should be set at the same time"
return _collective_comm(
tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
)
def reduce_scatter_sum(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create reduce_scatter_sum operator for collective communication
......@@ -172,45 +207,81 @@ def reduce_scatter_sum(
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of this node
:param local_grad: whether use local grad
"""
assert _group_check(
key, nr_ranks, rank
), "key, nr_ranks, rank should be set at the same time"
return _collective_comm(
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
tensor,
key,
Param.Mode.REDUCE_SCATTER_SUM,
nr_ranks,
rank=rank,
local_grad=local_grad,
)
def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
def all_reduce_sum(
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create all_reduce_sum operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param local_grad: whether use local grad
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)
assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
return _collective_comm(
tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
)
def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
def all_reduce_max(
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create all_reduce_max operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param local_grad: whether use local grad
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)
assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
return _collective_comm(
tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
)
def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
def all_reduce_min(
tensor: Tensor,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
local_grad: Optional[bool] = False,
) -> Tensor:
"""Create all_reduce_min operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param local_grad: whether use local grad
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
return _collective_comm(
tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
)
def bcast_param(
inp: Union[Buffer, Parameter],
key: str,
key: Optional[str] = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
) -> None:
......@@ -223,6 +294,9 @@ def bcast_param(
"""
if not is_distributed():
return
assert _group_check(
key, nr_ranks, is_root
), "key, nr_ranks, is_root should be set at the same time"
assert isinstance(inp, (Buffer, Parameter))
bcast_res = broadcast(inp, key, nr_ranks, is_root)
add_update(inp, bcast_res, alpha=0)
......@@ -11,16 +11,24 @@ from typing import Optional, Union
import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam
from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
from .util import (
get_backend,
get_group_id,
get_master_ip,
get_master_port,
get_rank,
get_world_size,
)
def collective_comm_symvar(
inp: Union[mgb.SymbolVar, mgb.CompGraph],
key: str,
op: CollParam.Mode,
key: Optional[str] = None,
op: CollParam.Mode = None,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
local_grad: Optional[bool] = False,
dtype: Optional[type] = None,
device: Optional[mgb.CompNode] = None,
comp_graph: Optional[mgb.CompGraph] = None,
......@@ -32,16 +40,19 @@ def collective_comm_symvar(
:param op: mode of collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param is_root: whether this node is root node
:param rank: rank of this node
:param local_grad: whether use local grad
:param dtype: output data type, use dtype of inp as default
:param device: output comp node, use comp node of inp as default
:param comp_graph: output comp graph, use comp graph of inp as default
"""
return mgb.opr.collective_comm(
inp,
key=str(key),
key=key if key is not None else ("collective_comm_" + str(get_group_id())),
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
is_root=is_root if is_root is not None else (get_rank() == 0),
rank=rank if rank is not None else -1,
rank=rank if rank is not None else get_rank(),
local_grad=local_grad,
server_addr=get_master_ip(),
port=get_master_port(),
param=CollParam(mode=op),
......
......@@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta):
with opr_priority_scope(cg, -(2 ** 30)):
# always run all_reduce_mean first except add_update
grad = (
all_reduce_sum(grad, "grad_" + str(get_group_id()))
all_reduce_sum(
grad, "grad_" + str(get_group_id()), get_world_size()
)
/ get_world_size()
)
with opr_priority_scope(cg, -(2 ** 31)):
......@@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta):
for group in self.param_groups:
for param in group["params"]:
bcast_param(
param, "bcast_param_" + str(key), is_root=(get_rank() == 0),
param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
)
key += 1
......
......@@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs(1, inpvar);
ComputingGraph* graph = inpvar.node()->owner_graph();
......@@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input(
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
local_grad, group_mgr, dev_buffer_arr, param,
_dtype, backend, config, disable.get_val())[0];
}
SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs;
auto& graph = cg.get();
......@@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input(
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
local_grad, group_mgr, dev_buffer_arr, param,
_dtype, backend, config, disable.get_val())[0];
}
#else
......@@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
}
SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key,
const size_t nr_devices, const bool is_root, const int rank,
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
......@@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input(
}
SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key,
const size_t nr_devices, const bool is_root, const int rank,
CompGraph& cg, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
......
......@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,
static SymbolVar collective_comm_with_input(
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
static SymbolVar collective_comm_without_input(
CompGraph& graph, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);
const bool is_root, const int rank, const bool local_grad,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
// misc
static SymbolVarArray extern_c_opr_placeholder(
......
......@@ -34,7 +34,7 @@ def test_reduce_sum():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.reduce_sum(inp, "x")
output = dist.functional.reduce_sum(inp)
if rank == 0:
assert np.allclose(output.numpy(), expect)
else:
......@@ -70,7 +70,7 @@ def test_gather():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
output = dist.functional.gather(inp)
if rank == 0:
assert np.allclose(output.numpy(), expect)
else:
......@@ -106,7 +106,7 @@ def test_broadcast():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.broadcast(inp, "x")
output = dist.functional.broadcast(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -138,7 +138,7 @@ def test_scatter():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
output = dist.functional.scatter(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -174,7 +174,7 @@ def test_all_to_all():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_to_all(inp, "x", rank=rank)
output = dist.functional.all_to_all(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -208,7 +208,7 @@ def test_all_gather():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_gather(inp, "x", rank=rank)
output = dist.functional.all_gather(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -241,7 +241,7 @@ def test_reduce_scatter_sum():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
output = dist.functional.reduce_scatter_sum(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -278,7 +278,7 @@ def test_all_reduce_sum():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_reduce_sum(inp, "x")
output = dist.functional.all_reduce_sum(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -311,7 +311,7 @@ def test_all_reduce_max():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_reduce_max(inp, "x")
output = dist.functional.all_reduce_max(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -344,7 +344,7 @@ def test_all_reduce_min():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_reduce_min(inp, "x")
output = dist.functional.all_reduce_min(inp)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -377,7 +377,7 @@ def test_bcast_param():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = Parameter(data)
dist.functional.bcast_param(inp, "x")
dist.functional.bcast_param(inp)
assert np.allclose(inp.numpy(), expect)
def check(shape, backend):
......
......@@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
if (!opr->same_type<opr::CollectiveComm>()) return false;
auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false;
if (comm.local_grad()) return false;
if (comm.input().size() != 1) return false;
auto grad = comm.input(0)->owner_opr();
......@@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
std::string key = ssprintf("grad_pack_%zu", pack_id);
auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph,
key, info->nr_devices, info->is_root, info->rank,
key, info->nr_devices, info->is_root, info->rank, false,
info->group_client, param, info->dtype, info->backend)[0];
// split according to recorded partition
......
......@@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) {
auto grad3 = opr::VirtualGrad::make(y1, x1);
auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(),
"grad0", 2, 0, 0, client, mode)[0];
auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(),
"grad1", 2, 0, 0, client, mode)[0];
auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(),
"grad2", 2, 0, 0, client, mode)[0];
auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(),
"grad3", 2, 0, 0, client, mode)[0];
auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2,
false, 0, false, client, mode)[0];
auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2,
false, 0, false, client, mode)[0];
auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2,
false, 0, false, client, mode)[0];
auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2,
false, 0, false, client, mode)[0];
gopt::GraphOptimizer()
.add_pass<gopt::PackAllReduceScanPass>()
......@@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
auto grad = opr::VirtualGrad::make(target, wrt);
auto comm = opr::CollectiveComm::make(
{grad}, graph.get(), "key", 2, 0, 0, client,
opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
.node()->owner_opr();
auto comm =
opr::CollectiveComm::make(
{grad}, graph.get(), "key", 2, false, 0, false, client,
opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
.node()
->owner_opr();
comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
......@@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) {
auto insert_opr = [&] (size_t size) {
auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
auto sd = opr::SharedDeviceTensor::make(*graph, dev);
auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
"key", 2, 0, 0, client, mode)[0];
auto symvar = opr::CollectiveComm::make(
{sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0];
auto opr = symvar.node()->owner_opr();
auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
comm.set_pack_hash(1);
......@@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
size_t nr_devices = 2;
uint32_t rank = 0;
uint32_t root = 0;
using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
......@@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
auto insert_opr = [&] (const TensorShape& shape) {
auto dev = std::make_shared<DeviceTensorND>(cn, shape);
auto sd = opr::SharedDeviceTensor::make(*graph, dev);
auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
"key", nr_devices, rank, root, client, mode)[0];
auto symvar =
opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices,
false, rank, false, client, mode)[0];
auto opr = symvar.node()->owner_opr();
auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
comm.set_pack_hash(1);
......@@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);
std::string key = ssprintf("grad_pack_%zu", pack_id);
auto allreduce = opr::CollectiveComm::make({concat}, graph.get(),
key, nr_devices, rank, root, client, mode)[0];
auto allreduce =
opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices,
false, rank, false, client, mode)[0];
std::vector<size_t> partition;
partition.push_back(shape_x.total_nr_elems());
......@@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) {
using Mode = opr::CollectiveComm::Param::Mode;
bool is_root = (rank == 0);
auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(),
"x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(),
"y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
auto reduced_x = opr::CollectiveComm::make(
{grad_x}, graph.get(), "x", 2, is_root, rank,
false, client, Mode::ALL_REDUCE_SUM)[0] /
2;
auto reduced_y = opr::CollectiveComm::make(
{grad_y}, graph.get(), "y", 2, is_root, rank,
false, client, Mode::ALL_REDUCE_SUM)[0] /
2;
graph->options().allreduce_pack_max_size = 5000;
graph->options().allreduce_pack_ignore_first = 0;
......
......@@ -14,6 +14,8 @@
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/group_manager.h"
......@@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) {
}
} // anonymous namespace
/* ================= ModeTrait ================= */
class CollectiveComm::ModeTrait {
class BROADCAST;
class REDUCE_SUM;
......@@ -132,6 +136,42 @@ public:
return None;
}
VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
SymbolVarArray og_syms;
og_syms.push_back(out_grad);
auto&& cn = opr->output(0)->comp_node();
auto gvar = CollectiveComm::make(
og_syms, opr->owner_graph(), opr->key() + ":grad",
opr->nr_devices(), opr->is_root(), opr->rank(), false,
opr->group_client(), mode, opr->dtype(), opr->backend(), {cn});
return gvar[0].node();
}
virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const {
mgb_throw(MegBrainError,
"only all_reduce all_to_all all_gather reduce_scatter "
"support local_grad");
}
virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const {
if (opr->local_grad()){
return local_grad(out_grad, opr);
} else {
return full_grad(out_grad, opr);
}
}
VarNode* zeros(mgb::cg::ComputingGraph &graph, CompNode node, const SymbolVar& shape,
DType dtype) const {
auto zero = SymbolVar::make_scalar(0, graph, node);
auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape);
return zero_tensor.node();
}
virtual void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) = 0;
......@@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
}
Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }
VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
auto nr_devices = opr->nr_devices();
auto rank = opr->rank();
opr::Subtensor::IndexDesc axis;
auto shape0 = opr::GetVarShape::make(out_grad, 0);
axis.push_back({0, shape0 * rank / (int)nr_devices,
shape0 * (rank + 1) / (int)nr_devices});
auto grad = opr::Subtensor::make(out_grad, axis);
return grad.node();
}
};
class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
......@@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
}
Mode grad_mode() override { return Mode::ALL_GATHER; }
};
/* ================= ModeTrait impls ================= */
VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNodeArray grads;
auto zeros_tensor =
zeros(*out_grad->owner_graph(), out_grad->comp_node(),
opr::GetVarShape::make(out_grad), out_grad->dtype());
for (size_t i = 0;i < opr->nr_devices();i++) {
if (i == opr->rank()) {
grads.push_back(out_grad);
} else {
grads.push_back(zeros_tensor);
}
}
auto grad = opr::Concat::make(grads, 0);
return grad.node();
}
};
class CollectiveComm::ModeTrait::ReducedBasedTrait {
protected:
......@@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
}
Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }
public:
VarNode* local_grad(VarNode* out_grad,
const CollectiveComm* opr) const override {
return out_grad;
}
};
class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
......@@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* grad;
if (opr->local_grad()) {
grad = local_grad(out_grad, opr);
} else {
grad = full_grad(out_grad, opr);
}
grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad},
Elemwise::Mode::COND_LEQ_MOV)
.node();
return grad;
}
};
class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* grad;
if (opr->local_grad()) {
grad = local_grad(out_grad, opr);
} else {
grad = full_grad(out_grad, opr);
}
grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad},
Elemwise::Mode::COND_LEQ_MOV)
.node();
return grad;
}
};
class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
......@@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
}
Mode grad_mode() override { return Mode::ALL_TO_ALL; }
VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNodeArray grads;
auto grad_shape = opr::GetVarShape::make(out_grad);
auto zeros_tensor =
zeros(*out_grad->owner_graph(), out_grad->comp_node(),
grad_shape, out_grad->dtype());
auto nr_devices = opr->nr_devices();
auto rank = opr->rank();
opr::Subtensor::IndexDesc axis;
auto shape0 = opr::GetVarShape::make(out_grad, 0);
axis.push_back({0, shape0 * rank / (int)nr_devices,
shape0 * (rank + 1) / (int)nr_devices});
auto sub_grad = opr::Subtensor::make(out_grad, axis);
return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node();
}
};
CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
......@@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
CollectiveComm::CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const int rank, const bool local_grad,
std::shared_ptr<GroupClient> group_client, const Param& param,
const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable)
......@@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm(
m_nr_devices(nr_devices),
m_is_root(is_root),
m_rank(rank),
m_local_grad(local_grad),
m_key(key),
m_dev_buffers(dev_buffer_arr),
m_disable{disable} {
......@@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm(
SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const int rank, const bool local_grad,
std::shared_ptr<GroupClient> group_client, const Param& param,
const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
nullptr);
return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
dev_buffer_arr, param, dtype, backend, config);
return make(inputs, graph, key, nr_devices, is_root, rank, local_grad,
group_client, dev_buffer_arr, param, dtype, backend, config);
}
SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const int rank, const bool local_grad,
std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
auto inpvars = cg::to_var_node_array(inputs);
auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
param, dtype, backend, dev_buffer_arr, config, disable));
inpvars, graph, key, nr_devices, is_root, rank, local_grad,
std::move(group_client), param, dtype, backend, dev_buffer_arr,
config, disable));
mgb_assert(!opr->output().empty());
return cg::to_symbol_var_array(opr->output());
}
......@@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) {
owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
trait.exec(this);
owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
#if CUDART_VERSION < 9000
#pragma message "legacy CUDA; use sync to avoid blocking"
// nccl hangs occasionally without this sync()
cn.sync();
#endif
};
env.dispatch_on_comp_node(cn, runner);
}
void CollectiveComm::on_output_comp_node_stream_changed() {}
VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
auto mode = ModeTrait::from_mode(m_param.mode).grad_mode();
SymbolVarArray og_syms;
if (m_param.mode == Param::Mode::REDUCE_SUM) {
for (size_t i = 0; i < output().size(); i++) {
if (out_grads[i])
og_syms.push_back(out_grads[i]);
}
mgb_assert(og_syms.size() == 1);
} else {
for (size_t i = 0; i < output().size(); i++) {
if (!out_grads[i]) {
mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM,
"null out grad in CollctiveCommMM currently "
"unsupported when the forward mode is "
"Reduce_Scatter_Sum.");
DTypeScalar dval{output(i)->dtype()};
dval.set_retain_dtype(0);
auto zeros =
SymbolVar::make_scalar(dval, *output(i)->owner_graph(),
output(i)->comp_node())
.broadcast(SymbolVar(output(i)).symshape());
og_syms.push_back(zeros);
} else {
og_syms.push_back(out_grads[i]);
}
}
}
OperatorNodeConfig::CompNodeArray cn_arr;
if (m_param.mode == Param::Mode::REDUCE_SUM) {
for (auto i : input()) {
cn_arr.push_back(i->comp_node());
}
} else if (m_param.mode == Param::Mode::BROADCAST) {
if (!input().empty()) {
cn_arr.push_back(input(0)->comp_node());
}
}
auto gvar = CollectiveComm::make(
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root,
m_rank, m_group_client, mode, m_dtype, m_backend,
OperatorNodeConfig{}.comp_node_arr(cn_arr));
if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
for (size_t i = 0; i < input().size(); ++i) {
gvar[i] = Elemwise::make({output(i), input(i), gvar[i]},
Elemwise::Mode::COND_LEQ_MOV);
}
} else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) {
for (size_t i = 0; i < input().size(); ++i) {
gvar[i] = Elemwise::make({input(i), output(i), gvar[i]},
Elemwise::Mode::COND_LEQ_MOV);
}
} else if (m_param.mode == Param::Mode::BROADCAST) {
if (!input().empty()) {
CompNode&& master_out_cn = input(0)->comp_node();
SymbolVarArray rst;
for (auto i : gvar) {
if (i.node()->comp_node() == master_out_cn) {
mgb_assert(rst.empty());
rst.push_back(i);
}
}
gvar = rst;
}
}
return cg::to_var_node_array(gvar);
}
MGB_IMPL_OPR_GRAD(CollectiveComm) {
return opr.grad(out_grad);
}
void CollectiveComm::init_output_dtype() {
if (m_dtype.valid()) {
for (size_t i = 0; i < input().size(); ++i) {
......@@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() {
}
}
VarNode* CollectiveComm::grad(VarNode* out_grad) const {
return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
}
MGB_IMPL_OPR_GRAD(CollectiveComm) {
mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
return opr.grad(out_grad[0]);
}
/* ===================== shallow copy ===================== */
namespace mgb {
......@@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
auto new_opr = CollectiveComm::make(
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
opr.group_client(), opr.dev_buffers(), opr.param(),
opr.dtype(), opr.backend(), config)[0]
.node()
->owner_opr();
auto new_opr =
CollectiveComm::make(
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
opr.local_grad(), opr.group_client(), opr.dev_buffers(),
opr.param(), opr.dtype(), opr.backend(), config)[0]
.node()
->owner_opr();
new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash());
return new_opr;
}
......
......@@ -8,6 +8,7 @@ decl_raw_opr(
'operation to which this operator belongs.', 'int'),
Doc('is_root', 'whether this node is root node', 'bool'),
Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
Doc('local_grad', 'whether use local grad', 'bool'),
Doc('server_addr', 'rpc server ip address'),
Doc('port', 'server rpc listening port'),
Doc('param', 'The only component of *param* is *mode*, which refers to '
......@@ -28,12 +29,12 @@ decl_raw_opr(
body = [
'if isinstance(input, _mgb.SymbolVar):',
(' output = _mgb._Opr.collective_comm_with_input(input, key, '
'nr_devices, is_root, rank, server_addr, port, '
'nr_devices, is_root, rank, local_grad, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)'),
'else:',
' assert isinstance(input, _mgb.CompGraph)',
(' output = _mgb._Opr.collective_comm_without_input(input, key, '
'nr_devices, is_root, rank, server_addr, port, '
'nr_devices, is_root, rank, local_grad, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)')
],
desc = ('collective communication between multiple CompNodes on multiple '
......
......@@ -29,8 +29,9 @@ public:
CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const int rank, const bool local_grad,
std::shared_ptr<GroupClient> group_client, const Param& param,
const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable);
......@@ -38,7 +39,8 @@ public:
static SymbolVarArray make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const int rank, const bool local_grad,
std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
......@@ -50,6 +52,7 @@ public:
ComputingGraph* const graph,
const std::string& key, const size_t nr_devices,
const bool is_root, const int rank,
const bool local_grad,
std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
......@@ -72,6 +75,7 @@ public:
int rank() const { return m_rank; }
int root() const { return m_root; }
bool is_root() const { return m_is_root; }
bool local_grad() const { return m_local_grad; }
//! The key that identifies an NCCL clique.
//! Operators with same keys belong to the same clique.
......@@ -89,7 +93,7 @@ public:
return m_megray_ctx;
}
VarNodeArray grad(const VarNodeArray& out_grad) const;
VarNode* grad(VarNode* out_grad) const;
private:
Barrier m_exec_barrier;
......@@ -116,6 +120,7 @@ private:
size_t m_nr_devices = 0;
bool m_is_root;
int m_rank;
bool m_local_grad;
std::string m_key;
//! XXHash generated from m_key
size_t m_hash;
......
此差异已折叠。
......@@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CollectiveComm', 'collective communication between multiple computing '
'nodes on localhost')
.add_enum('Mode',
.add_enum(Doc('Mode', 'mode of collective communication'),
Doc('REDUCE_SUM', 'reduce by sum to output computing node'),
Doc('BROADCAST', 'copy input value to each output computing node'),
Doc('ALL_GATHER', 'each output comp node gets the concatenated '
......@@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields(
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
Doc('GATHER', 'concat inputs to one node'),
Doc('SCATTER', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node')))
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'),
name_field='mode'))
(pdef('FakeSerializedDType',
'HACK: The tag of this param def is actually used for another '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册