未验证 提交 f5dde36c 编写于 作者: H Haodong Lyu 提交者: GitHub

Fix RuntimeError when using ZeRO Stage3 with mpu: #3564 (#3565)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 3b299997
......@@ -1499,7 +1499,12 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
grad_norms.append(g.to(get_accelerator().device_name(), non_blocking=True).double().norm(2))
# Sum across all model parallel GPUs.
- total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))
+ if len(grad_norms) == 0:
+     # FIX https://github.com/microsoft/DeepSpeed/issues/3564
+     total_norm_cuda = torch.tensor(0,
+                                    dtype=gradients[0].dtype).to(get_accelerator().device_name()).double()
+ else:
+     total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册