def prepComm()

in train/comms/pt/comms_utils.py [0:0]


    def prepComm(self, curComm, commsParams):
        """Allocate the input/output tensors for one collective or pt2pt op.

        Args:
            curComm: dict describing the op. Reads "in_msg_size" and
                "out_msg_size" (element counts) and, when present, "comms"
                (op name, overriding commsParams.collective) plus
                "in_split" / "out_split" (used only for all_to_allv).
            commsParams: run parameters; reads dtype, device, dcheck and,
                only when curComm has no "comms" key, collective.

        Returns:
            (ipTensor, opTensor). ``([], [])`` for "wait" and "barrier".
            For "all_gather" and "incast" opTensor is a list of tensors; for
            "reduce_scatter" ipTensor is a list of world_size tensors. For
            ops not handled explicitly (e.g. allreduce, reduce, broadcast)
            opTensor is the same object as ipTensor (in-place).

        Side effects:
            For "all_to_allv", sets self.collectiveArgs.opTensor_split and
            self.collectiveArgs.ipTensor_split (defaulting to []).
        """
        # Only touch commsParams.collective when the op name is not given
        # explicitly in curComm.
        opName = curComm["comms"] if "comms" in curComm else commsParams.collective
        commOp = paramToCommName(
            opName,
            supported_comms=self.backendFuncs.collectiveFunc.keys(),
        )

        if commOp in ("wait", "barrier"):
            return ([], [])

        numElementsIn = curComm["in_msg_size"]
        # numElementsOut is only meaningful for out-of-place collectives and
        # pt2pt, but it also feeds scaleFactor, so it is always read.
        numElementsOut = curComm["out_msg_size"]
        world_size = self.collectiveArgs.world_size
        dtype = commsParams.dtype
        curDevice = commsParams.device
        scaleFactor = numElementsOut * numElementsOut
        dataCheck = commsParams.dcheck == 1

        def allocInput(sizeArr):
            # Data validation needs predictable input values (alloc_ones with
            # scaleFactor=self.initVal); otherwise use random data.
            if dataCheck:
                return self.backendFuncs.alloc_ones(
                    sizeArr, curDevice, dtype, scaleFactor=self.initVal
                )
            return self.backendFuncs.alloc_random(
                sizeArr, curDevice, dtype, scaleFactor
            )

        def allocOutput(sizeArr):
            # Output buffers are always random; they are overwritten by the op.
            return self.backendFuncs.alloc_random(
                sizeArr, curDevice, dtype, scaleFactor
            )

        # reduce_scatter variants build their own input, so allocate it here
        # instead of allocating a numElementsIn tensor and discarding it.
        if commOp == "reduce_scatter":
            # one input chunk per rank, e.g., List[torch.Tensor]
            ipTensor = [allocInput([numElementsOut]) for _ in range(world_size)]
            opTensor = allocOutput([numElementsOut])
            return (ipTensor, opTensor)

        if commOp == "reduce_scatter_base":
            # single flat input holding world_size chunks of numElementsOut
            ipTensor = allocInput([numElementsOut * world_size])
            opTensor = allocOutput([numElementsOut])
            return (ipTensor, opTensor)

        ipTensor = allocInput([numElementsIn])

        if commOp == "all_to_allv":
            # all_to_all(v) requires a separate output tensor
            opTensor = allocOutput([numElementsOut])
            # all_to_allv also needs the split sizes for both tensors
            self.collectiveArgs.opTensor_split = curComm.get("out_split", [])
            self.collectiveArgs.ipTensor_split = curComm.get("in_split", [])
        elif commOp == "all_gather":
            # allgather requires a tensor list, e.g., List[torch.Tensor]
            opTensor = [allocOutput([numElementsIn]) for _ in range(world_size)]
        elif commOp == "all_gather_base":
            # single all gather with a flat output tensor
            opTensor = allocOutput([numElementsIn * world_size])
        elif commOp == "incast":
            # one output tensor per source rank, e.g., List[torch.Tensor]
            opTensor = [
                allocOutput([numElementsOut]) for _ in self.collectiveArgs.src_ranks
            ]
        elif commOp in ("all_to_all", "pt2pt"):
            # pt2pt or out-of-place collectives
            opTensor = allocOutput([numElementsOut])
        else:
            # in-place case for other collectives such as allreduce, reduce,
            # broadcast
            opTensor = ipTensor

        return (ipTensor, opTensor)