def _Run()

in perfkitbenchmarker/linux_benchmarks/nvidia_mlperf_benchmark.py


def _Run(benchmark_spec, additional_params=''):
  """Run mlperf training with additional parameters to override defaults."""
  controller = slurm.GetController(benchmark_spec.vms)
  gpus_per_node = nvidia_driver.QueryNumberOfGpus(controller)
  num_vms = len(benchmark_spec.vms)
  stdout, _ = controller.RemoteCommand(
      f'cd {_GetDir()}; rm slurm*out; '
      'source config_common.sh; source config_fp8.sh; '
      # Debugging Only
      # 'RUN_ONLY_NCCL=1 '
      'NCCL_DEBUG_SUBSYS=INIT,ENV '
      'NCCL_DEBUG=INFO '
      # PKB params
      'DGXSYSTEM=pkb '
      'NEXP=1  '
      'SEED=1 '
      'SLURM_MPI_TYPE=pmi2 '
      'NCCL_LLM_TEST=0 '
      'HANG_MONITOR_TIMEOUT=0 '
      f'DGXNGPU={gpus_per_node} '
      f'DGXNNODES={num_vms} '
      f'WORLD_SIZE={gpus_per_node * num_vms} '
      'TP_COMM_OVERLAP=True '
      'CONT="dockerd://mlperf-nvidia:gpt3" '
      f'LOGDIR={controller.GetScratchDir()}/output '
      f'PREPROC_DATA={controller.GetScratchDir()}/mlperf-llm-public2/c4/preprocessed_c4_spm/ '
      f'SPM={controller.GetScratchDir()}/sentencepiece.model  '
      'NUMBA_CACHE_DIR=/tmp/numba '
      'NPY_INDEX_DIR=/tmp/npy '
      'MPLCONFIGDIR=/tmp/mplconfigdir '
      'TRANSFORMERS_CACHE=/tmp/transformers_cache '
      # Checkpoint flags: point at an empty folder. Since we are not running
      # the original 175B model, checkpoints are not used.
      'INIT_GLOBAL_STEP=1 '
      'LOAD_CHECKPOINT= '
      f'LOAD_CHECKPOINTS_PATH={controller.GetScratchDir()}/checkpoint/ '
      # Tuning params: for 5B (2x8 GPUs)
      'TENSOR_MODEL_PARALLEL=2 '
      'PIPELINE_MODEL_PARALLEL=1 '
      'MICRO_BATCH_SIZE=4 '
      'MINIBS=256 '  # Should be dynamic
      f'MAX_STEPS={STEPS.value} '
      # Default Model parameters: 5B
      'NUM_LAYERS=24 '
      'HIDDEN_SIZE=4096 '
      'NUM_ATTENTION_HEADS=32 '
      # Other params
      'INTERLEAVED_PIPELINE=null '
      'SEQ_PARALLEL=False '
      'BUCKET_CAP_MB=200 '
      f'VAL_CHECK_INTERVAL={STEPS.value} '
      'LIMIT_VAL_BATCHES=0.0 '
      f'LIMIT_TRAIN_BATCHES={STEPS.value} '
      'CHECK_COMPLIANCE=0 '
      # TODO(yuyanting): Set timeout based on steps and model parameters.
      # Difficult to estimate how long a run takes ahead of time, so set to
      # 60 minutes for now.
      f'{_MLPERF_ENV.value.replace(";", " ")} '
      f'{additional_params.replace(";", " ")} '
      f'sbatch -N {num_vms} -t 60 run.sub'
  )
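  # sbatch reports the queued job as 'Submitted batch job <id>'; the id is
  # needed to locate the slurm-<id>.out file that is polled below.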
  job_id = regex_util.ExtractInt(r'Submitted batch job (\d+)', stdout)
  output_file = f'{_GetDir()}/slurm-{job_id}.out'
  results = []
  while True:
    # Check status and backup output every minute.
    time.sleep(60)
    controller.PullFile(vm_util.GetTempDir(), output_file)
    vm_util.IssueCommand([
        'mv',
        os.path.join(vm_util.GetTempDir(), f'slurm-{job_id}.out'),
        os.path.join(vm_util.GetTempDir(), f'slurm-{job_id}.log'),
    ])
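    # Keep the pulled copy under a .log name; each pass through the loop
    # refreshes this local backup while the job is still running.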
    if not slurm.Running(controller):
      break
  metadata = {
      'gpus_per_node': gpus_per_node,
      'num_nodes': num_vms,
      'total_gpus': gpus_per_node * num_vms,
  }
  metadata.update(_GetMetadata(additional_params))
  for metric in _MLPERF_METRICS:
    try:
      lines, _ = controller.RemoteCommand(
          f'cat {output_file} | grep MLLOG | grep {metric}'
      )
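      # Each matching line carries a JSON payload after the literal 'MLLOG';
      # its shape is assumed (not taken from the source) to be roughly
      #   :::MLLOG {..., "value": {"<metric>": 1.23, ...}, ...}
      # so the metric is read out of the parsed payload's 'value' dict and
      # averaged over all matching lines.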
      values = [
          float(json.loads(line.split('MLLOG')[-1])['value'][metric])
          for line in lines.strip().splitlines()
      ]
      results.append(
          sample.Sample(
              metric,
              sum(values) / len(values),
              _MLPERF_METRICS[metric],
              metadata,
          )
      )
    except errors.VirtualMachine.RemoteCommandError:
      # Some runs are expected to fail during a parameter sweep, since certain
      # configurations do not fit in GPU memory.
      logging.error(
          'Failed to parse %s, see slurm-%s.log for more info.', metric, job_id
      )
    logging.info(results)
  return results
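
A minimal call sketch, assuming a prepared benchmark_spec (the override values
below are illustrative, not defaults from the source). additional_params is a
semicolon-separated list of NAME=VALUE overrides; the semicolons are replaced
with spaces, and because the entries land after the defaults in the shell
environment prefix, they take precedence over the earlier assignments.

# Hypothetical parameter-sweep call; TENSOR_MODEL_PARALLEL and
# MICRO_BATCH_SIZE are override names used in the command above, but these
# particular values are examples only.
samples = _Run(
    benchmark_spec,
    additional_params='TENSOR_MODEL_PARALLEL=4;MICRO_BATCH_SIZE=2',
)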