in fairscale/nn/data_parallel/sharded_ddp.py [0:0]
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
    """
    Build the backward hook used to reduce this parameter's gradient to its owning rank.

    Two possible backward hooks for a given parameter: either directly reduce to the
    appropriate rank, or contribute to a bucket and reduce when the bucket is full.
    Either way a delayed action is necessary and is passed as a callback.

    Args:
        index: position of this parameter in the per-parameter bookkeeping lists
            (``_grad_to_be_reduced``, ``_should_bucket_grad``).
        param: the parameter whose ``.grad`` this hook will reduce.
        dst_rank: rank owning this parameter's optimizer state — NOTE(review): the
            direct path translates it via ``_local_to_global_rank`` before calling
            ``dist.reduce``, so it is presumably a local/process-group rank; confirm
            against the caller.

    Returns:
        A ``Callable`` suitable for registration as an autograd hook; it ignores its
        arguments (``*_``) and runs under ``torch.no_grad()``.
    """
    if not self._use_buckets or not self._should_bucket_grad[index]:
        # Direct reduction: this gradient is reduced on its own, not via a bucket.
        @torch.no_grad()
        def reduce(*_: Any) -> None:
            # Skip gradient reduction, do not alter status flags
            # (e.g. while accumulating grads, or if this hook already fired).
            if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Queue a single end-of-backward callback on the autograd engine
                # to flush any pending reduce work; only register it once per pass.
                if not self._bucket_flush_callback_set:
                    Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                    self._bucket_flush_callback_set = True

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False
                # Pre-scale so the reduce produces an average across the world.
                param.grad.mul_(self._world_size_scaling)

                # Optionally halve the payload on the wire; restored in cleanup below.
                if self._reduce_fp16:
                    param.grad.data = param.grad.data.half()

                # Future work includes clearing up the buffer if possible
                def cleanup() -> None:
                    # Runs once the async reduce has completed (see Workhandle callback).
                    if dst_rank != self._global_rank:
                        # Not the owner of this shard: the reduced grad is not needed
                        # here, free the memory.
                        param.grad = None
                    else:
                        # Owner rank: cast back to the parameter dtype (undoes the
                        # fp16 cast above when _reduce_fp16 is set).
                        assert param.grad is not None
                        param.grad.data = param.grad.data.to(dtype=param.dtype)

                # Async reduce for this buffer, log the future
                self._work_handles.append(
                    Workhandle(
                        handle=dist.reduce(
                            tensor=param.grad.data,
                            dst=self._local_to_global_rank[dst_rank],
                            group=self._process_group,
                            async_op=True,
                        ),
                        callback=cleanup,
                    )
                )

                # Opportunistically try to empty the queue, free memory
                self._try_consume_work_handle()

    else:
        # Bucketed reduction: wait until every grad assigned to this bucket has
        # checked in, then reduce the whole flat buffer in one collective.
        @torch.no_grad()
        def reduce(*_: Any) -> None:
            # Skip gradient reduction, do not alter status flags
            if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Queue a single end-of-backward flush callback (same as direct path).
                if not self._bucket_flush_callback_set:
                    Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                    self._bucket_flush_callback_set = True

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False

                # NOTE(review): grads are presumably copied into the flat bucket
                # buffer elsewhere; this hook only tracks the check-in count.
                bucket = self._buckets[param.device][dst_rank]
                bucket.params_checked_in += 1

                if bucket.all_checked_in:
                    assert bucket.buffer is not None

                    # Normalize the bucket in one go
                    bucket.buffer.mul_(self._world_size_scaling)

                    # Reduce the bucket
                    bucket.sent = True
                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=bucket.buffer,
                                dst=bucket.destination,
                                group=self._process_group,
                                async_op=True,
                            ),
                            # No per-bucket cleanup callback: unscattering/freeing is
                            # presumably handled elsewhere — TODO confirm.
                            callback=None,
                        )
                    )

                # Opportunistically try to empty the queue
                self._try_consume_work_handle()

    return reduce