def dcheck()

in train/comms/pt/comms_utils.py

Data-validation helper: after a collective completes, dcheck() compares the output tensor(s) against the expected value derived from self.initVal and raises a ValueError at the first mismatching element.

    def dcheck(self, commsParams, curSize, tensor):
        # Default expectation: every element still holds the initial fill value
        expRes = self.initVal
        if (
            commsParams.collective
            in ("all_reduce", "reduce_scatter", "reduce_scatter_base")
        ) or (
            self.backendFuncs.get_global_rank() == commsParams.srcOrDst
            and commsParams.collective == "reduce"
        ):
            # NOTE: this assumes the sum op, where each rank contributes "self.initVal"
            expRes = self.collectiveArgs.world_size * self.initVal
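            # e.g., with world_size=8 and initVal=1.0, expRes becomes 8.0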

        if (
            # Check results for incast only on root
            commsParams.collective == "incast"
            and self.backendFuncs.get_global_rank() != commsParams.srcOrDst
        ) or (
            # Check results of multicast only for dst_ranks
            commsParams.collective in ("multicast", "pt2pt")
            and self.backendFuncs.get_global_rank() not in commsParams.dst_ranks
        ):
            return

        if isinstance(tensor, list):
            # for allgather and incast, it's a list of tensors:
            for rank, t in enumerate(tensor):
                if not torch.all(torch.eq(t, expRes)):
                    for index, val in enumerate(t):
                        if val != expRes:
                            raise ValueError(
                                f"[{curSize}-bytes {commsParams.collective}] Wrong value at [{rank}][{index}] = {t[index]}, expected {expRes}\n {tensor}"
                            )
        elif not torch.all(torch.eq(tensor, expRes)):
            for index, val in enumerate(tensor):
                if val != expRes:
                    raise ValueError(
                        f"[{curSize}-bytes {commsParams.collective}] Wrong value at [{index}] = {tensor[index]}, expected {expRes}\n {tensor}"
                    )
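
For context, here is a minimal sketch of a call site, assuming a driver that seeds the input tensor with self.initVal before launching the collective. The backendFuncs methods (all_reduce, complete_accel_ops) and the commsParams.dcheck flag mirror the names referenced above, but this is an illustrative assumption, not the repo's exact API.

    def run_and_validate(self, commsParams, curSize, numElements):
        # Hypothetical driver: every rank fills its tensor with self.initVal,
        # so a sum all_reduce should leave world_size * initVal in every slot.
        tensor = torch.full((numElements,), self.initVal)
        self.collectiveArgs.ipTensor = tensor
        self.backendFuncs.all_reduce(self.collectiveArgs)
        self.backendFuncs.complete_accel_ops(self.collectiveArgs)
        if commsParams.dcheck == 1:
            # Raises ValueError with the offending index on any mismatch
            self.dcheck(commsParams, curSize, tensor)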