def main()

in benchmarks/dlrm/ubench/dlrm_ubench_comms_driver.py [0:0]


def _resolve_size(size_arg, collective):
    """Translate a validated ``--size`` value into ``(size, name)``.

    Args:
        size_arg: Either a preset name (``"small"``, ``"medium"``, ``"large"``)
            or a string of digits.
        collective: ``"all_to_all"`` or ``"all_reduce"``; selects which preset
            table applies.

    Returns:
        A ``(size, name)`` tuple. ``size`` is the value substituted into the
        ``--b``/``--e`` options (an int for presets, the original digit string
        otherwise). ``name`` is always a string: the preset name when the size
        matches a preset of ``collective``, else the digit string itself.
    """
    presets = {
        "all_reduce": {"small": 2200, "medium": 9944, "large": 22372},
        "all_to_all": {
            "small": 134000000,
            "medium": 244000000,
            "large": 544000000,
        },
    }[collective]

    if not size_arg.isdigit():
        return presets[size_arg], size_arg

    # A numeric size is labelled with a preset name only when it matches a
    # preset of the *selected* collective; otherwise the number is the label.
    for preset_name, preset_size in presets.items():
        if str(preset_size) == size_arg:
            return size_arg, preset_name
    return size_arg, size_arg


def main():
    """Drive ``comms_main`` for one collective and log the result via FB5Logger.

    Parses ``--size``, ``--backend``, ``--collective`` and ``--fb5logger`` from
    the command line, rewrites ``sys.argv`` with the option list handed to
    ``comms_main``, runs it with stdout captured, and on global rank 0 prints
    the captured output and records the parsed result columns as extra
    metadata on the logger.

    Raises:
        SystemExit: If ``--size`` is neither a preset name nor a positive
            integer (or if argparse rejects the command line).
    """
    parser = argparse.ArgumentParser(description="comms.py driver")
    parser.add_argument(
        "--size",
        type=str,
        default="small",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default=("nccl"),
        choices=["nccl", "gloo", "mpi", "ucc", "xla"],
    )
    parser.add_argument(
        "--collective",
        type=str,
        default=("all_to_all"),
        choices=["all_to_all", "all_reduce"],
    )
    parser.add_argument("--fb5logger", type=str, default=None)
    args = parser.parse_args()

    is_preset = args.size in ["small", "medium", "large"]
    is_positive_int = args.size.isdigit() and int(args.size) > 0
    if not (is_preset or is_positive_int):
        sys.exit("The --size argument provided is not a valid positive integer.")

    # `name` is guaranteed to be a str so it can be concatenated into the
    # benchmark label below (a numeric size matching a preset of the other
    # collective used to yield an int here and raise TypeError).
    size, name = _resolve_size(args.size, args.collective)

    master_ip = "localhost"
    num_compute_per_collective = 100
    mm_dim = 1000
    num_iter = 100

    cmd = f"""
        --f 2
        --n {num_iter}
        --master-ip {master_ip}
        --master-port 22565
        --collective {args.collective}
        --b {size}
        --e {size}
        --num-compute {num_compute_per_collective}
        --mm-dim {mm_dim}
        --backend {args.backend}
    """
    # str.split() with no argument splits on any whitespace run, so the token
    # list does not depend on how the literal above is indented.
    # NOTE(review): the first token ("--f") lands in sys.argv[0], the slot
    # argparse-style parsers treat as the program name — confirm comms_main
    # tolerates this before changing it.
    sys.argv = cmd.split()

    fb5logger = FB5Logger(args.fb5logger)
    fb5logger.header(
        "DLRM",
        "UBENCH",
        "train",
        "comms_" + args.collective.replace("_", "") + "_" + name,
        score_metric=loggerconstants.GBPS,
    )

    mpi_env_params = comms_utils.read_comms_env_vars()
    is_rank_zero = mpi_env_params["global_rank"] == 0
    print("This process's MPI global rank: ", mpi_env_params["global_rank"])

    # Capture everything comms_main prints so the result rows can be parsed.
    comms_stdout = io.StringIO()
    with contextlib.redirect_stdout(comms_stdout):
        if is_rank_zero:
            fb5logger.run_start()
        comms_main()

    if is_rank_zero:
        captured = comms_stdout.getvalue()
        print(captured)

        # Presumably the last three lines are: column headers, the matching
        # values, and a trailing empty line — TODO confirm against the
        # comms_main output format.
        rows = [line.split() for line in captured.split("\n")[-3:]]
        headers, values = rows[0], rows[1]

        # Re-align the header and value token lists so they pair up one-to-one
        # (assumes the fixed column layout noted above).
        headers.pop(2)
        values.insert(3, "")
        headers[3] = "Latency(us):"
        headers.insert(4, "p50")

        pairs = list(zip(headers, values))
        extra_metadata = {key: value for key, value in pairs}
        fb5logger.run_stop(
            num_batches=num_iter, batch_size=None, extra_metadata=extra_metadata
        )

        print("-- Pretty Format --")
        for key, value in pairs:
            print("{:<18s}{:>4s}".format(key, value))