def _get_reduce_fn()

in fairscale/nn/data_parallel/sharded_ddp.py [0:0]


    def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
        """
        Build the backward hook which takes care of reducing the gradient of ``param``.

        The kind of hook is picked once, when this method is called:

        - bucketed reduction, if buckets are in use and ``self._should_bucket_grad[index]`` is set: the hook
          checks the parameter in with the bucket stored under ``self._buckets[param.device][dst_rank]``,
          and the bucket buffer is reduced once ``bucket.all_checked_in`` holds.
        - direct reduction otherwise: the gradient is scaled and reduced to ``dst_rank`` on its own.

        In both cases the reduction is issued with ``async_op=True``, the resulting handle is appended to
        ``self._work_handles``, and ``self._flush_reduce_calls`` is queued as an autograd engine callback
        if ``self._bucket_flush_callback_set`` was not already set.

        Args:
            index: position of the parameter in ``self._grad_to_be_reduced`` / ``self._should_bucket_grad``
            param: the parameter whose gradient is reduced by the hook
            dst_rank: rank which receives the reduced gradient

        Returns:
            The hook, a callable which ignores its arguments and returns None.
        """

        def grad_reduction_pending() -> bool:
            # Skip gradient reduction when accumulating or when already handled, status flags are not altered
            return not self._should_accumulate_grads and self._grad_to_be_reduced[index]

        def register_reduction() -> None:
            # Queue the flush of the reduce calls on the autograd engine, the flag guards against doing it twice
            if not self._bucket_flush_callback_set:
                Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                self._bucket_flush_callback_set = True

            # Make sure that this is not fired twice
            self._grad_to_be_reduced[index] = False

        if self._use_buckets and self._should_bucket_grad[index]:
            # Bucketed reduction
            @torch.no_grad()
            def reduce(*_: Any) -> None:
                if not grad_reduction_pending():
                    return

                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
                register_reduction()

                # This parameter is now accounted for in its bucket
                bucket = self._buckets[param.device][dst_rank]
                bucket.params_checked_in += 1

                if bucket.all_checked_in:
                    assert bucket.buffer is not None

                    # Normalize the bucket in one go
                    bucket.buffer.mul_(self._world_size_scaling)

                    # Reduce the bucket, no callback is attached to this handle
                    bucket.sent = True
                    bucket_handle = dist.reduce(
                        tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
                    )
                    self._work_handles.append(Workhandle(handle=bucket_handle, callback=None))

                # Opportunistically try to empty the queue
                self._try_consume_work_handle()

        else:
            # Direct reduction
            @torch.no_grad()
            def reduce(*_: Any) -> None:
                if not grad_reduction_pending():
                    return

                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
                register_reduction()

                param.grad.mul_(self._world_size_scaling)

                # Optionally communicate in half precision, cleanup() casts back to the param dtype
                if self._reduce_fp16:
                    param.grad.data = param.grad.data.half()

                # Future work includes clearing up the buffer if possible
                def cleanup() -> None:
                    if dst_rank == self._global_rank:
                        # The gradient is kept here, bring it back to the dtype of the parameter
                        assert param.grad is not None
                        param.grad.data = param.grad.data.to(dtype=param.dtype)
                    else:
                        # dst_rank is not this rank, the gradient is dropped
                        param.grad = None

                # Async reduce for this buffer, log the future
                grad_handle = dist.reduce(
                    tensor=param.grad.data,
                    dst=self._local_to_global_rank[dst_rank],
                    group=self._process_group,
                    async_op=True,
                )
                self._work_handles.append(Workhandle(handle=grad_handle, callback=cleanup))

                # Opportunistically try to empty the queue, free memory
                self._try_consume_work_handle()

        return reduce