def all_to_allv()

in train/comms/pt/pytorch_dist_backend.py

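Issues an all-to-all-v collective on the tensors held in collectiveArgs. When all2all_qcomm is enabled, the input tensor is fp32, either tensor has at least quant_threshold elements, and pair is False, the call is routed through the quantized all_to_allv_internal path; otherwise it falls back to torch.distributed.all_to_all_single, operating on the *_pair tensors and splits when pair=True. Async work handles are appended to collectiveArgs.waitObj, and the work object is returned only when retFlag is set.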

    def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
        # pair=True mode does not support quantization
        if (
            collectiveArgs.all2all_qcomm
            and collectiveArgs.ipTensor.dtype == torch.float32
            and (
                collectiveArgs.opTensor.nelement() >= collectiveArgs.quant_threshold
                or collectiveArgs.ipTensor.nelement() >= collectiveArgs.quant_threshold
            )
            and not pair
        ):
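            # Quantized all-to-all-v path (guarded by all2all_qcomm and quant_threshold above)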
            work = all_to_allv_internal(collectiveArgs)
        else:
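            # Standard path: native all_to_all_single, using the *_pair tensors and splits when pair=True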
            work = dist.all_to_all_single(
                collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
                collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
                collectiveArgs.opTensor_split
                if not pair
                else collectiveArgs.opTensor_split_pair,
                collectiveArgs.ipTensor_split
                if not pair
                else collectiveArgs.ipTensor_split_pair,
                group=collectiveArgs.group,
                async_op=collectiveArgs.asyncOp,
            )

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(work)

        if retFlag:
            return work
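
For context, the non-quantized branch is a direct wrapper around torch.distributed.all_to_all_single. The standalone sketch below shows that underlying collective with uneven splits; the process-group setup, device, and split sizes are illustrative assumptions and are not taken from collectiveArgs.

    import torch
    import torch.distributed as dist

    def all_to_allv_example():
        # Assumes dist.init_process_group(...) has already run (e.g. under torchrun)
        # and that one CUDA device is bound per rank; both are assumptions here.
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Uneven splits make this an all-to-all-v: rank r sends r+1 elements to
        # every peer and receives p+1 elements from each peer p.
        in_splits = [rank + 1] * world_size
        out_splits = [p + 1 for p in range(world_size)]

        ip_tensor = torch.full((sum(in_splits),), float(rank), device="cuda")
        op_tensor = torch.empty(sum(out_splits), device="cuda")

        # Same call the else-branch above issues with the collectiveArgs fields.
        work = dist.all_to_all_single(
            op_tensor,
            ip_tensor,
            output_split_sizes=out_splits,
            input_split_sizes=in_splits,
            group=dist.group.WORLD,
            async_op=True,
        )
        work.wait()  # mirrors appending to waitObj and waiting later when asyncOp is set
        return op_tensor

The quantized branch (all_to_allv_internal) depends on the benchmark's quantization setup and is not sketched here.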