Commit bc6814c2 authored by Andrew Gu

Update on "Add overlap with DDP to ZeRO"




[ghstack-poisoned]
@@ -54,77 +54,115 @@ def hook_then_zero_step(
        gradient bucket.
    """
    fut = hook(state, bucket)
    overlap_info = zero._overlap_info

    # Proceed as normal until the DDP buckets have been rebuilt
    if not ddp._has_rebuilt_buckets:
        assert overlap_info.status == _OverlapStatus.UNINITIALIZED
        return fut
    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS

    # Once DDP buckets have been rebuilt but ZeRO has not been
    # properly initialized yet, collect the information needed
    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
        bucket_index = bucket.get_index()
        rank = zero.global_rank
        rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)
        bucket_params = bucket.get_model_params_for_bucket()
        assert len(bucket_params) > 0, "Empty bucket"
        params_per_rank = overlap_info.params_per_rank
        params_per_bucket = overlap_info.params_per_bucket
        if rank_to_update == rank:
            overlap_info.offsets[bucket_index] = len(params_per_rank[rank_to_update])
        params_per_rank[rank_to_update].extend(bucket_params)
        params_per_bucket.append(bucket_params)
        return fut

    def zero_step(fut: torch.futures.Future) -> torch.Tensor:
        r"""
        Performs a partial :class:`ZeroRedundancyOptimizer` :meth:`step`
        using the gradients in the given :class:`DistributedDataParallel`
        gradient bucket.

        Returns:
            A :class:`torch.Tensor` representing the contents of the
            gradient bucket.
        """
        assert overlap_info.status == _OverlapStatus.INITIALIZED
        bucket_index = bucket.get_index()
        overlap_info.bucket_indices_seen.append(bucket_index)
        rank = zero.global_rank
        rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)
        if rank_to_update == rank:
            assert len(zero.optim.param_groups) == 1, \
                "Overlapping DDP with ZeRO only supports a single " \
                "parameter group"
            # Construct the `gradients` input for the local optimizer step,
            # which expects `None` in a list position to indicate that the
            # corresponding parameter should not be updated
            num_local_optim_params = len(zero.optim.param_groups[0]["params"])
            gradients: List[Optional[torch.Tensor]] = \
                [_PARAM_NO_UPDATE for _ in range(num_local_optim_params)]
            assert bucket_index in overlap_info.offsets, \
                f"Bucket index {bucket_index} was not assigned to rank " \
                f"{rank}"
            offset = overlap_info.offsets[bucket_index]
            bucket_gradients = bucket.get_per_parameter_tensors()
            for i, grad in enumerate(bucket_gradients):
                gradients[offset + i] = grad
            assert bucket_index not in overlap_info.bucket_to_gradients, \
                f"Already a gradient list for bucket index {bucket_index}"
            # Save the `gradients` input and the all-reduce future
            overlap_info.bucket_to_gradients[bucket_index] = gradients
            overlap_info.bucket_to_allreduce_future[bucket_index] = fut

        # `bucket_index` does not refer to the argument `bucket` 's index
        # from this point forward
        del bucket_index

        num_buckets = len(overlap_info.params_per_bucket)
        is_last_bucket = len(overlap_info.bucket_indices_seen) == num_buckets
        # Perform partial optimizer step on all buckets
        if is_last_bucket:
            for bucket_index in range(num_buckets):
                rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)
                if rank_to_update == rank:
                    assert bucket_index in overlap_info.bucket_to_gradients, \
                        f"Bucket index {bucket_index} assigned to rank {rank} is not present"
                    gradients = overlap_info.bucket_to_gradients[bucket_index]
                    # Ensure that the all-reduce completes before
                    # performing the parameter update
                    allreduce_future = overlap_info.bucket_to_allreduce_future[bucket_index]
                    allreduce_future.wait()
                    zero._local_step(gradients)
                device = overlap_info.params_per_bucket[bucket_index][0].device
                device_index = zero._device_to_device_index[device]
                assert bucket_index in zero._buckets[device_index][rank_to_update]
                overlap_info.broadcast_handles.append(
                    dist.broadcast(
                        zero._buckets[device_index][rank_to_update][bucket_index],
                        src=rank_to_update,
                        async_op=True
                    )
                )
            # Zero each parameter's gradient if needed
            if zero._zero_grad:
                ZeroRedundancyOptimizer._zero_grad(zero._all_params)
            # Ensure that all parameter updates are finished before the
            # next forward pass
            _ = list(map(lambda x: x.wait(), overlap_info.broadcast_handles))
            overlap_info.broadcast_handles.clear()
            # Reset per-iteration information
            overlap_info.bucket_to_gradients.clear()
            overlap_info.bucket_to_allreduce_future.clear()
            overlap_info.bucket_indices_seen.clear()

        return bucket.get_tensor()

    return fut.then(zero_step)
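For illustration only (not part of the diff), here is a minimal standalone sketch of how ``zero_step`` assembles the ``gradients`` argument for ``_local_step``: one slot per local optimizer parameter, ``None`` placeholders (the diff's ``_PARAM_NO_UPDATE``) for parameters outside the current bucket, and the bucket's per-parameter gradients written starting at the bucket's recorded offset. All numbers below are hypothetical.

    from typing import List, Optional

    import torch

    num_local_optim_params = 5   # parameters owned by this rank's local optimizer (hypothetical)
    offset = 2                   # offset recorded for this bucket in `overlap_info.offsets` (hypothetical)
    bucket_gradients = [torch.ones(3), torch.ones(2)]  # stand-ins for the bucket's per-parameter gradients

    # One slot per local optimizer parameter; `None` means "do not update this parameter"
    gradients: List[Optional[torch.Tensor]] = [None] * num_local_optim_params
    for i, grad in enumerate(bucket_gradients):
        gradients[offset + i] = grad
    # gradients == [None, None, tensor([1., 1., 1.]), tensor([1., 1.]), None]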
@@ -179,14 +179,28 @@ class _OverlapInfo():
            assigned to this rank.
        broadcast_handles (List[Work]): :class:`list` of async work handles for
            the parameter broadcasts.
        bucket_to_allreduce_future (Dict[int, torch.futures.Future]):
            :class:`dict` mapping bucket index to the corresponding all-reduce
            future.
        bucket_to_gradients (Dict[int, List[torch.Tensor]]): :class:`dict`
            mapping bucket index to the bucket's gradients.
        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
            seen so far in the iteration.
    """
    def __init__(self):
        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED

        # Modified per bucket reconstruction
        self.params_per_bucket: List[List[torch.Tensor]] = []
        self.params_per_rank: List[List[torch.Tensor]] = \
            [[] for _ in range(dist.get_world_size())]
        self.offsets: Dict[int, int] = {}

        # Modified per iteration
        self.broadcast_handles: List[Any] = []
        self.bucket_to_allreduce_future: Dict[int, torch.futures.Future] = {}
        self.bucket_to_gradients: Dict[int, List[torch.Tensor]] = {}
        self.bucket_indices_seen: List[int] = []


class ZeroRedundancyOptimizer(Optimizer, _Joinable):
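As a simplified, hypothetical sketch (not part of the diff) of how the per-iteration fields above are used: each bucket's gradients and all-reduce future are cached as buckets finish, the deferred local optimizer steps run only once the last bucket of the iteration has been seen, and the per-iteration state is then cleared. The bucket count and helper names here are assumptions for illustration.

    from typing import Dict, List, Optional

    import torch

    NUM_BUCKETS = 2  # hypothetical bucket count

    bucket_to_allreduce_future: Dict[int, torch.futures.Future] = {}
    bucket_to_gradients: Dict[int, List[Optional[torch.Tensor]]] = {}
    bucket_indices_seen: List[int] = []

    def local_step(gradients: List[Optional[torch.Tensor]]) -> None:
        """Stand-in for ZeroRedundancyOptimizer._local_step(gradients)."""
        pass

    def on_bucket_allreduced(bucket_index: int,
                             gradients: List[Optional[torch.Tensor]],
                             fut: torch.futures.Future) -> None:
        # Cache the per-bucket state as each all-reduce is issued
        bucket_indices_seen.append(bucket_index)
        bucket_to_gradients[bucket_index] = gradients
        bucket_to_allreduce_future[bucket_index] = fut
        if len(bucket_indices_seen) == NUM_BUCKETS:  # last bucket of the iteration
            for idx in range(NUM_BUCKETS):
                bucket_to_allreduce_future[idx].wait()  # reduced gradients must be ready
                local_step(bucket_to_gradients[idx])    # deferred partial optimizer step
            # Reset the per-iteration bookkeeping
            bucket_to_allreduce_future.clear()
            bucket_to_gradients.clear()
            bucket_indices_seen.clear()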
@@ -232,9 +246,9 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
            If ``False``, :meth:`step` runs disjointly after the backward pass
            (per normal).
            (default: ``False``)
        zero_grad (bool, optional): if ``True``, zeroes gradients after
            performing an optimizer step; if ``False``, :meth:`zero_grad()`
            should be called by the user (default: ``False``).
        **defaults: any trailing arguments, which are forwarded to the local
            optimizer.
@@ -283,7 +297,7 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
        process_group: Optional[Any] = None,
        parameters_as_bucket_view: bool = False,
        overlap_with_ddp: bool = False,
        zero_grad: bool = False,
        **defaults: Any,
    ):
        # Perform type and assumption checks on the input parameters
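For context, a hypothetical construction sketch (not taken from the diff; the model setup and hyperparameters are assumptions) showing the options documented above. The DDP communication-hook registration that actually drives the overlapped step is omitted here, and the process group is assumed to be initialized already.

    import torch
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel as DDP

    def build_overlapped_zero(model: torch.nn.Module, lr: float = 0.01):
        # Assumes torch.distributed.init_process_group() has already been called
        ddp_model = DDP(model)
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.SGD,
            overlap_with_ddp=True,  # fuse the optimizer step into DDP's gradient all-reduce
            zero_grad=True,         # this commit: zero gradients after each optimizer step
            lr=lr,
        )
        return ddp_model, zero_optim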
@@ -334,14 +348,6 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
                )
                raise ValueError(error_msg)
        self._overlap_with_ddp = overlap_with_ddp

        # If `overlap_with_ddp=True`, local optimizer initialization is delayed
        # to run time after the necessary information has been collected
@@ -349,6 +355,7 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
            self._init_local_optimizer()
        else:
            self._overlap_info = _OverlapInfo()
        self._zero_grad = zero_grad

        # `self._buckets` is used if `parameters_as_bucket_view=True` or
        # `overlap_with_ddp=True`, in which case parameter data is flattened
@@ -848,13 +855,6 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
                # initialization of the local optimizer and supporting state
                self._init_zero_for_overlap()

            # `step()` does not actually perform any parameter updates and is
            # only used for bookkeeping when `overlap_with_ddp=True`
            return None
@@ -862,6 +862,10 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
        # Perform the local optimizer step
        loss = self._local_step(closure=closure, **kwargs)

        # Zero the gradient before syncing if needed
        if self._zero_grad:
            self.optim.zero_grad()

        # Sync all of the updated parameter shards across the ranks
        self._sync_params()
@@ -1215,6 +1219,23 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
        self._sync_param_groups(self.optim.param_groups, self.param_groups)

    @staticmethod
    def _zero_grad(params: List[torch.Tensor]) -> None:
        r"""
        Zeroes the gradient of each parameter in ``params``.

        Arguments:
            params (List[torch.Tensor]): :class:`list` of parameters whose
                gradients should be zeroed.
        """
        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

    def _init_zero_for_overlap(self) -> None:
        r"""
        Performs a delayed initialization of the local optimizer and the
@@ -1227,11 +1248,26 @@ class ZeroRedundancyOptimizer(Optimizer, _Joinable):
        self._partition_parameters(self._overlap_info.params_per_rank)
        self._build_ddp_param_buckets()
        self._init_local_optimizer()
        if self._zero_grad:
            ZeroRedundancyOptimizer._zero_grad(self._all_params)

    def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int:
        r"""Assigns a rank to a given DDP gradient bucket index."""
        return bucket_index % self.world_size

    def _get_assigned_ddp_bucket_indices(self) -> List[int]:
        r"""
        Returns a list of the DDP gradient bucket indices assigned to this
        rank to update.
        """
        assert self._overlap_info.status == _OverlapStatus.INITIALIZED
        num_buckets = len(self._overlap_info.params_per_bucket)
        assigned_indices = [
            bucket_index for bucket_index in range(num_buckets)
            if self._ddp_bucket_index_to_rank(bucket_index) == self.global_rank
        ]
        return assigned_indices

    def _check_overlap_initialized(self):
        r"""
        Checks that the delayed initialization has occurred (see
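A tiny illustration (with hypothetical values) of the round-robin assignment implemented by ``_ddp_bucket_index_to_rank`` and ``_get_assigned_ddp_bucket_indices``: bucket ``i`` is updated by rank ``i % world_size``.

    world_size = 4    # hypothetical
    num_buckets = 10  # hypothetical
    global_rank = 1   # hypothetical

    bucket_to_rank = {i: i % world_size for i in range(num_buckets)}
    assigned_to_this_rank = [i for i in range(num_buckets)
                             if i % world_size == global_rank]
    print(bucket_to_rank)         # {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3, 8: 0, 9: 1}
    print(assigned_to_this_rank)  # [1, 5, 9]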