in train/comms/pt/comms_utils.py [0:0]
def dcheck(self, commsParams, curSize, tensor):
    expRes = self.initVal
    if (
        commsParams.collective
        in ("all_reduce", "reduce_scatter", "reduce_scatter_base")
    ) or (
        self.backendFuncs.get_global_rank() == commsParams.srcOrDst
        and commsParams.collective == "reduce"
    ):
        # NOTE: this assumes the sum op and that every rank starts from "self.initVal"
        expRes = self.collectiveArgs.world_size * self.initVal
    if (
        # Check results for incast only on the root rank
        commsParams.collective == "incast"
        and self.backendFuncs.get_global_rank() != commsParams.srcOrDst
    ) or (
        # Check results of multicast/pt2pt only on dst_ranks
        commsParams.collective in ("multicast", "pt2pt")
        and self.backendFuncs.get_global_rank() not in commsParams.dst_ranks
    ):
        return
    if isinstance(tensor, list):
        # For all_gather and incast, the result is a list of tensors:
        for rank, t in enumerate(tensor):
            if not torch.all(torch.eq(t, expRes)):
                for index, val in enumerate(t):
                    if val != expRes:
                        raise ValueError(
                            f"[{curSize}-bytes {commsParams.collective}] Wrong value at [{rank}][{index}] = {t[index]}, expected {expRes}\n {tensor}"
                        )
    else:
        if not torch.all(torch.eq(tensor, expRes)):
            for index, val in enumerate(tensor):
                if val != expRes:
                    raise ValueError(
                        f"[{curSize}-bytes {commsParams.collective}] Wrong value at [{index}] = {tensor[index]}, expected {expRes}\n {tensor}"
                    )
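
A minimal standalone sketch of the same check, outside comms_utils.py: it assumes a sum all_reduce in which every rank filled its tensor with initVal, so each element of the result should equal world_size * initVal. The values (init_val=1.0, world_size=4) and names below are illustrative, not taken from the benchmark.

import torch

# Illustrative settings (assumed, not read from commsParams).
init_val = 1.0
world_size = 4
exp_res = world_size * init_val  # expected per-element value after all_reduce(sum)

# Simulate the all_reduce(sum) output as seen by one rank.
result = torch.full((8,), init_val) * world_size

# Same validation idea as dcheck: report the first mismatching element.
if not torch.all(torch.eq(result, exp_res)):
    for index, val in enumerate(result):
        if val != exp_res:
            raise ValueError(f"Wrong value at [{index}] = {val}, expected {exp_res}")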