in distributed_training/src_dir/dis_util.py
import math

import smdistributed.modelparallel.torch as smp


def smp_lossgather(loss, args):
    if args.use_horovod or args.use_ddp:
        # Rubik: if using data parallelism, gather the loss from every model
        # replica in the data-parallel group and check that the replicas agree.
        losses = smp.allgather(loss, smp.DP_GROUP)
        for replica_loss in losses:
            assert math.isclose(replica_loss, losses[0])
        assert loss < 0.14
    else:
        assert loss < 0.08
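
A minimal usage sketch (an assumption, not from the source): how the helper
might be invoked after training. The `args` namespace, its flag values, and
`final_loss` are illustrative names only; smp.allgather() also requires
smp.init() to have been called earlier in the program.

import argparse

args = argparse.Namespace(use_horovod=False, use_ddp=True)
final_loss = 0.12  # e.g. epoch-average training loss (illustrative value)
smp_lossgather(final_loss, args)  # raises AssertionError if any replica
                                  # disagrees or the loss has not converged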