Unverified commit 95768115 — authored by Yuang Liu, committed via GitHub

Multi groups for broadcast of sharding stage 2 (#46894)

Parent commit: a9cc5482
......@@ -184,7 +184,10 @@ class GroupShardedOptimizerStage2(Optimizer):
# Enable gradients' reduces overlap with backward calculation.
self._reduce_overlap = reduce_overlap
def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
def _set_broadcast_overlap(self,
broadcast_overlap,
layers=None,
num_groups=None):
# Enable post optimizer broadcasts overlap with the forward calculation of next batch.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
......@@ -202,6 +205,27 @@ class GroupShardedOptimizerStage2(Optimizer):
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params
if num_groups is None or num_groups > len(self._broadcast_order_params):
warnings.warn(
"The num_groups for broadcast is larger than the number of params to be broadcast. "
"It will set to default value: 1 (use the default sharding group)."
)
num_groups = 1
assert isinstance(
num_groups,
int) and num_groups > 0, "num_groups should be a positive integer"
self._number_of_broadcast_groups = num_groups
self._broadcast_groups = [
None for _ in range(self._number_of_broadcast_groups)
]
self._broadcast_groups[0] = self._group
ranks = self._group.ranks
for i in range(1, self._number_of_broadcast_groups):
self._broadcast_groups[i] = new_group(ranks)
def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
......@@ -484,14 +508,17 @@ class GroupShardedOptimizerStage2(Optimizer):
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcast with next batch's calculation.
group_idx = 0
param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
task = broadcast(
tensor=x,
src=self._group.ranks[self._param2rank[x.name]],
group=self._group,
sync_op=False)
group = self._broadcast_groups[group_idx]
group_idx = (group_idx + 1) % self._number_of_broadcast_groups
task = broadcast(tensor=x,
src=group.ranks[self._param2rank[x.name]],
group=group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.