def all_reduce()

in train/comms/pt/pytorch_dist_backend.py


    def all_reduce(self, collectiveArgs, retFlag=False, pair=False):
        # pair=True mode does not support quantization
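        # quantize only when a reduced bit width (more than 4 but fewer
        # than 32 bits) is requested and the input tensor is float32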
        if (
            collectiveArgs.allreduce_qcomm != 32
            and collectiveArgs.allreduce_qcomm > 4
            and collectiveArgs.ipTensor.dtype == torch.float32
            and not pair
        ):
            # Note that quantized is a new tensor, distinct from
            # collectiveArgs.ipTensor, so when all_reduce/reduce finishes,
            # quantized holds the result instead of collectiveArgs.ipTensor.
            # This is intentional: we avoid allocating a fresh output buffer
            # on every all_reduce call, at the cost that the result stays
            # float16 rather than float32. It also means the output of a
            # quantized all_reduce cannot be used for anything other than
            # benchmarking purposes.
            with paramProfile(
                timer=collectiveArgs.quant_time,
                description="# PARAM: Allreduce quantization #",
            ):
                quantized = _downcast(
                    collectiveArgs.ipTensor, collectiveArgs.allreduce_qcomm
                )
        else:
            quantized = (
                collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair
            )
        retObj = dist.all_reduce(
            quantized,
            op=collectiveArgs.op,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )  # synchronicity is maintained in runColl
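        # de-quantize only when _downcast above actually produced a new tensor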
        if (id(quantized) != id(collectiveArgs.ipTensor)) and not pair:
            if collectiveArgs.asyncOp:
                retObj = retObj.get_future().then(_dequantize)
            else:
                with paramProfile(
                    timer=collectiveArgs.dequant_time,
                    description="# PARAM: Allreduce de-quantization #",
                ):
                    retObj = _dequantize(quantized)

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(retObj)

        if retFlag:
            return retObj
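
The _downcast and _dequantize helpers are defined elsewhere in the file and are not shown here. The sketch below is a hypothetical reconstruction, not the repo's actual implementation: it assumes a 16-bit width maps float32 to float16, and that on the async path the callback passed to .then() receives the torch.futures.Future itself, whose value() is the list of output tensors of the collective.

    import torch


    def _downcast(tensor, num_bits):
        # cast a float32 tensor to a narrower representation; only the
        # 16-bit (float16) case is sketched here
        if num_bits == 16:
            return tensor.to(torch.float16)
        raise NotImplementedError(f"unsupported bit width: {num_bits}")


    def _dequantize(obj):
        # sync path: obj is the reduced tensor itself;
        # async path: obj is the Future handed to .then()
        if isinstance(obj, torch.futures.Future):
            obj = obj.value()[0]
        return obj if obj.dtype == torch.float32 else obj.to(torch.float32)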
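
On the async path, retObj.get_future().then(_dequantize) converts the Work handle returned by dist.all_reduce into a torch.futures.Future and chains de-quantization onto its completion; the caller later waits on the handles accumulated in collectiveArgs.waitObj. A minimal standalone illustration of the same future-chaining pattern (hypothetical single-process gloo group, not part of the benchmark):

    import torch
    import torch.distributed as dist

    # hypothetical setup; rank/world_size normally come from the launcher
    dist.init_process_group(
        "gloo", rank=0, world_size=1, init_method="tcp://127.0.0.1:29500"
    )

    t = torch.ones(4, dtype=torch.float16)
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # the callback receives the Future; value() is the list of output tensors
    fut = work.get_future().then(lambda f: f.value()[0].to(torch.float32))
    result = fut.wait()  # float32 tensor holding the reduced values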