cli/healthscan.py (133 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLI for running healthscan on a cluster.
This is part of the larger cluster_diag CLI. To get the full helpstring, run
`cluster_diag healthscan --help`.
"""
import click
import common
import gpu_check
import nccl_check
import neper_check
import status
import straggler_check
import tinymax_check
_SUPPORTED_MACHINE_TYPES = list(common.SUPPORTED_MACHINE_TYPES)
_SUPPORTED_HEALTHCHECKS = [
status.NAME,
nccl_check.NAME,
gpu_check.NAME,
straggler_check.NAME,
neper_check.NAME,
tinymax_check.NAME,
]
def _get_partition_for_machine(machine_type: str) -> str | None:
"""Returns the partition for the given orchestrator."""
match machine_type:
case 'a3-highgpu-8g':
return 'a3'
case 'a3-megagpu-8g':
return 'a3mega'
case 'a3-ultragpu-8g':
return 'a3ultra'
case 'a4-highgpu-8g':
return 'a4'
case _:
raise ValueError(f'Unsupported machine type: {machine_type}')
@click.command(name='healthscan')
@click.argument(
'machine_type',
type=click.Choice(_SUPPORTED_MACHINE_TYPES, case_sensitive=False),
)
@click.option(
'-c',
'--check',
type=click.Choice(_SUPPORTED_HEALTHCHECKS),
default=_SUPPORTED_HEALTHCHECKS[0],
help="""
Check to run. Available checks:
\b
- status: (Default) Checks the current healthscan status of the cluster.
- nccl: Runs a pairwise NCCL bandwidth test on the cluster.
- gpu: Runs a GPU check on the cluster.
- straggler: Instruments a straggler check on the cluster.
- neper: Runs a Network Performand eval on the cluster.
- tinymax: Runs a ml-framework TinyMax test on the cluster.
""",
)
@click.option(
'-n',
'--nodes',
multiple=True,
default=[],
help=(
'Nodes to run checks on. Defaults to running on all nodes. When using'
' slurm, a shortened node format can be used. For example, "node-[0-1]"'
),
)
@click.option(
'--run_only_on_available_nodes',
default=False,
is_flag=True,
help="""
Force running the healthcheck only on available nodes.
Unavailable nodes will be skipped.""",
)
@click.option(
'--dry_run',
default=False,
is_flag=True,
help="""
Run the healthcheck in dry run mode.
This will print the commands that would be run, but not run them.""",
)
@click.pass_context
def cli(
ctx: click.Context,
machine_type: str,
check: str,
nodes: list[str],
run_only_on_available_nodes: bool,
dry_run: bool,
):
"""Run a healthscan on a cluster."""
orchestrator = ctx.obj['orchestrator']
check_runner = None
partition = None
if orchestrator == 'slurm':
partition = _get_partition_for_machine(machine_type)
match check:
case nccl_check.NAME:
check_runner = nccl_check.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
partition=partition,
nodes=nodes,
run_only_on_available_nodes=run_only_on_available_nodes,
dry_run=dry_run,
)
case gpu_check.NAME:
check_runner = gpu_check.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
partition=partition,
nodes=nodes,
run_only_on_available_nodes=run_only_on_available_nodes,
dry_run=dry_run,
)
case straggler_check.NAME:
check_runner = straggler_check.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
nodes=nodes,
run_only_on_available_nodes=run_only_on_available_nodes,
dry_run=dry_run,
)
case neper_check.NAME:
check_runner = neper_check.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
nodes=nodes,
run_only_on_available_nodes=run_only_on_available_nodes,
dry_run=dry_run,
)
case tinymax_check.NAME:
check_runner = tinymax_check.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
nodes=nodes,
run_only_on_available_nodes=run_only_on_available_nodes,
dry_run=dry_run,
)
case status.NAME:
check_runner = status.get_check_for_orchestrator(
orchestrator=orchestrator,
machine_type=machine_type,
nodes=nodes,
)
if check_runner:
check_runner.set_up()
check_runner.run()
check_runner.clean_up()