def prepComms()

in train/comms/pt/commsTraceReplay.py [0:0]


    def prepComms(self, curComm, commsParams):
        """Prepare the tensors needed to replay one collective from the trace.

        ``wait`` and ``barrier`` entries carry no payload, so ``([], [])`` is
        returned for them without touching ``curComm``.

        When ``self.shrink`` is set, the message sizes recorded in the trace
        are rescaled from the trace's world size to the current world size so
        that a trace captured at large scale can be replayed on fewer ranks
        (sanity-test / debug use only). ``curComm`` is updated in place.

        Args:
            curComm: one trace entry (dict). Keys read here: ``comms``,
                ``dtype``, ``in_msg_size``, ``out_msg_size`` and, optionally,
                ``world_size``, ``in_split``, ``out_split``.
            commsParams: parameter holder; its ``dtype`` attribute is
                overwritten with ``self.strToTorchDtype[curComm["dtype"]]``.

        Returns:
            ``([], [])`` for wait/barrier; otherwise whatever the parent
            class's ``prepComm`` returns for the (possibly shrunk) entry
            (presumably the allocated input/output tensors — confirm against
            the base class).
        """
        commOp = paramToCommName(curComm["comms"])
        if commOp in ("wait", "barrier"):
            return ([], [])

        # for all_to_allv, we can shrink the size if running on smaller scale
        # this is for sanity test or debug purpose only since we don't always get to run very large scale
        if self.shrink:
            cur_world_size = self.collectiveArgs.world_size

            in_split, out_split = [], []
            if commOp == "all_to_allv":
                # Treat missing or None splits as empty, consistently for both
                # world-size inference and truncation below (previously a
                # missing key raised KeyError during inference only).
                in_split = curComm.get("in_split")
                in_split = [] if in_split is None else in_split
                out_split = curComm.get("out_split")
                out_split = [] if out_split is None else out_split

            # Prefer the world size recorded in the trace. If it is absent or
            # unusable (None / 0 would break the division below), fall back to
            # inferring it from a2av split lengths, then to the current size.
            recorded_world_size = curComm.get("world_size")
            if recorded_world_size:
                real_world_size = recorded_world_size
            elif commOp == "all_to_allv":
                real_world_size = len(in_split) or len(out_split) or cur_world_size
            else:
                real_world_size = cur_world_size

            # Proportional estimate: per-rank share in the trace times the
            # number of ranks available now.
            newNumElemsIn = (curComm["in_msg_size"] // real_world_size) * cur_world_size
            newNumElemsOut = (
                curComm["out_msg_size"] // real_world_size
            ) * cur_world_size

            if commOp == "all_to_allv":
                # Keep only the splits for ranks that exist at the current scale.
                out_split = out_split[:cur_world_size]
                in_split = in_split[:cur_world_size]
                curComm["out_split"] = out_split
                curComm["in_split"] = in_split
                # Exact totals from the truncated splits override the estimate.
                if len(in_split) > 0:
                    newNumElemsIn = sum(in_split)
                if len(out_split) > 0:
                    newNumElemsOut = sum(out_split)
            elif commOp == "all_gather":
                # Output holds one (shrunk) input buffer per current rank.
                newNumElemsOut = newNumElemsIn * cur_world_size

            curComm["in_msg_size"] = newNumElemsIn
            curComm["out_msg_size"] = newNumElemsOut

            logger.debug(
                f"shrink message sizes to curInNumElem {curComm['in_msg_size']}, curOutNumElem {curComm['out_msg_size']}"
            )

        commsParams.dtype = self.strToTorchDtype[curComm["dtype"]]
        # allocate and return tensors
        return super().prepComm(curComm, commsParams)