def smp_lossgather()

in distributed_training/src_dir/dis_util.py [0:0]


import math

import smdistributed.modelparallel.torch as smp  # SageMaker model parallelism (codename "Rubik")


def smp_lossgather(loss, args):
    """Check that the final loss agrees across replicas and is below a sanity bound."""
    if args.use_horovod or args.use_ddp:
        # Rubik: when data parallelism is in use, gather the loss from every
        # model replica in the data-parallel group and verify they all match.
        losses = smp.allgather(loss, smp.DP_GROUP)
        for replica_loss in losses:
            assert math.isclose(replica_loss, losses[0])

        # Sanity threshold for the distributed run.
        assert loss < 0.14
    else:
        # Single-replica run: expect a tighter bound on the loss.
        assert loss < 0.08
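
For context, a minimal sketch of how this check might be invoked at the end of an evaluation pass. This call site is an assumption, not part of dis_util.py: the flag names simply mirror the attributes the function reads, the loss value is a stand-in, and the script is assumed to run under SageMaker's model-parallel launcher with smp.init() already called.

import argparse

# Hypothetical argparse namespace carrying the data-parallelism flags.
args = argparse.Namespace(use_horovod=False, use_ddp=True)
test_loss = 0.12  # stand-in for the averaged test loss from the eval loop

# Passes only if every data-parallel replica computed ~0.12 and 0.12 < 0.14;
# raises AssertionError if the replicas' losses diverge.
smp_lossgather(test_loss, args)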