def main()

in src/straggler_healthcheck/pp_benchmark_runner.py [0:0]


def main(_) -> None:
  """Runs the pipeline-parallel benchmark once per requested message size.

  Restricts this process to the GPU given by the OMPI local-rank flag, joins an
  NCCL process group that rendezvouses over TCP at `_MAIN_ADDRESS`, prints this
  rank's identity, then calls `pp_benchmark.run_pp_benchmark` for each entry of
  `_MESSAGE_SIZES_MB`. The process group is destroyed after all runs complete.

  Args:
    _: Unused positional argument (presumably the absl `argv` remainder —
      TODO confirm against the `app.run` call site).
  """
  message_sizes_mb = [int(size_str) for size_str in _MESSAGE_SIZES_MB.value]

  # Make only this rank's GPU visible to the process. Environment values must
  # be strings, so coerce in case the flag is integer-typed.
  # NOTE(review): this only takes effect if CUDA has not already been
  # initialized in this process — confirm nothing imported above touches CUDA.
  device_id = _OMPI_COMM_WORLD_LOCAL_RANK.value
  os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

  # All ranks rendezvous on a fixed TCP port (2379) on the main host.
  init_method = f"tcp://{_MAIN_ADDRESS.value}:2379"

  n_nodes = _N_NODES.value
  n_gpus_per_node = _N_GPUS_PER_NODE.value
  # Assumes every node launches exactly n_gpus_per_node ranks.
  world_size = n_nodes * n_gpus_per_node
  world_rank = _OMPI_COMM_WORLD_RANK.value

  print("initializing process group")
  dist.init_process_group(
      backend="nccl",
      rank=world_rank,
      world_size=world_size,
      init_method=init_method,
  )

  rank = dist.get_rank()

  # Address info is purely diagnostic; a resolver failure must not abort the
  # benchmark run.
  try:
    machine = socket.getaddrinfo(socket.gethostname(), None)
  except socket.gaierror as err:
    machine = f"<unresolved: {err}>"
  # An explicitly provided hostname flag overrides the local system hostname.
  hostname = _HOSTNAME.value or socket.gethostname()
  print(
      f"rank: {rank}, local_rank: {device_id} server: {machine}, hostname:"
      f" {hostname}"
  )

  # The same process group is reused across every message size.
  for message_size_mb in message_sizes_mb:
    pp_benchmark.run_pp_benchmark(
        hostname=hostname,
        output_dir=_OUTPUT_DIR.value,
        message_size_mb=message_size_mb,
        n_gpus_per_node=n_gpus_per_node,
        n_nodes=n_nodes,
        n_batch=_N_BATCH.value,
        n_microbatch=_N_MICROBATCH.value,
        bidirectional=_BIDIRECTIONAL.value,
    )

  # Release the NCCL communicators explicitly instead of relying on process
  # exit. Guarded in case a callee already tore the group down.
  if dist.is_initialized():
    dist.destroy_process_group()