# Copyright 2025 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to represent a Cluster object."""
import typing
from typing import Callable, List, Tuple
from absl import flags
from absl import logging
from perfkitbenchmarker import disk
from perfkitbenchmarker import errors
from perfkitbenchmarker import linux_virtual_machine
from perfkitbenchmarker import resource
from perfkitbenchmarker import static_virtual_machine
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
from perfkitbenchmarker.configs import option_decoders
from perfkitbenchmarker.configs import spec
from perfkitbenchmarker.configs import vm_group_decoders
FLAGS = flags.FLAGS
TEMPLATE_FILE = flags.DEFINE_string(
'cluster_template_file',
None,
    'The template file used to create the cluster. Defaults to None, in '
    'which case each provider falls back to its own default template file.',
)
UNMANAGED = flags.DEFINE_boolean(
'cluster_unmanaged_provision',
False,
    'Instead of creating the cluster with a cluster toolset, rely on the '
    'cloud provider CLI, e.g. gcloud for GCP, awscli for AWS.'
)
TYPE = flags.DEFINE_string(
'cluster_type',
'default',
    'Type of cluster to use. Cluster types can differ substantially, and '
    'each may use its own template.'
)
class BaseClusterSpec(spec.BaseSpec):
"""Storing various data about HPC/ML cluster.
Attributes:
zone: The region / zone the in which to launch the cluster.
machine_type: The provider-specific instance type (e.g. n1-standard-8).
image: The disk image to boot from.
"""
SPEC_TYPE = 'BaseClusterSpec'
SPEC_ATTRS = ['CLOUD']
CLOUD = None
@classmethod
def _ApplyFlags(cls, config_values, flag_values):
"""Overrides config values with flag values.
Can be overridden by derived classes to add support for specific flags.
Args:
config_values: dict mapping config option names to provided values. Is
modified by this function.
flag_values: flags.FlagValues. Runtime flags that may override the
provided config values.
Returns:
dict mapping config option names to values derived from the config
values or flag values.
"""
super()._ApplyFlags(config_values, flag_values)
if flag_values['cloud'].present:
config_values['cloud'] = flag_values.cloud
if flag_values['cluster_template_file'].present:
config_values['template'] = flag_values.cluster_template_file
if flag_values['cluster_unmanaged_provision'].present:
config_values['unmanaged'] = flag_values.cluster_unmanaged_provision
cloud = config_values['cloud']
# only apply to workers
if flag_values['num_vms'].present:
config_values['workers']['vm_count'] = flag_values['num_vms'].value
# flags should be applied to workers and headnode
if flag_values['zone'].present:
config_values['workers']['vm_spec'][cloud]['zone'] = flag_values[
'zone'
].value[0]
config_values['headnode']['vm_spec'][cloud]['zone'] = flag_values[
'zone'
].value[0]
for flag_name in ('os_type', 'cloud'):
if flag_values[flag_name].present:
config_values['workers'][flag_name] = flag_values[flag_name].value
config_values['headnode'][flag_name] = flag_values[flag_name].value
@classmethod
def _GetOptionDecoderConstructions(cls):
"""Gets decoder classes and constructor args for each configurable option.
Can be overridden by derived classes to add options or impose additional
requirements on existing options.
Returns:
dict. Maps option name string to a (ConfigOptionDecoder class, dict) pair.
The pair specifies a decoder class and its __init__() keyword
arguments to construct in order to decode the named option.
"""
result = super()._GetOptionDecoderConstructions()
result.update({
'workers': (vm_group_decoders.VmGroupSpecDecoder, {}),
'headnode': (vm_group_decoders.VmGroupSpecDecoder, {}),
'cloud': (option_decoders.StringDecoder, {'default': None}),
'template': (option_decoders.StringDecoder, {'default': None}),
'unmanaged': (option_decoders.BooleanDecoder, {'default': False}),
})
return result
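  # Illustrative example only: a config shape that these decoders would
  # accept. The vm_spec contents are provider-specific, and the values below
  # (cloud name, machine type, zone, counts) are assumptions, not defaults.
  #
  #   cloud: GCP
  #   template: /path/to/cluster_template.yaml
  #   unmanaged: false
  #   headnode:
  #     vm_spec:
  #       GCP: {machine_type: n1-standard-8, zone: us-central1-a}
  #   workers:
  #     vm_count: 2
  #     vm_spec:
  #       GCP: {machine_type: n1-standard-8, zone: us-central1-a}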
def GetClusterSpecClass(cloud: str):
"""Returns the cluster spec class corresponding to the given service."""
return spec.GetSpecClass(BaseClusterSpec, CLOUD=cloud)
def GetClusterClass(cloud: str):
"""Returns the cluster spec class corresponding to the given service."""
if UNMANAGED.value:
return BaseCluster
return resource.GetResourceClass(BaseCluster, CLOUD=cloud, TYPE=TYPE.value)
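# Illustrative sketch only: a typical lookup flow, assuming a provider has
# registered BaseClusterSpec / BaseCluster subclasses for the cloud 'GCP'.
#
#   spec_class = GetClusterSpecClass('GCP')
#   cluster_class = GetClusterClass('GCP')
#   cluster = cluster_class(spec_class(...))  # ... elides the spec arguments.
#   cluster.Create()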
class BaseCluster(resource.BaseResource):
"""Base class for cluster resources.
This class holds cluster-level methods and attributes.
Attributes:
image: The disk image used to boot.
machine_type: The provider-specific instance type for worker VMs.
    zone: The region / zone the cluster was launched in.
    headnode_vm: The headnode VM.
    worker_vms: The list of worker VMs.
"""
RESOURCE_TYPE = 'BaseCluster'
TYPE = 'default'
REQUIRED_ATTRS = ['CLOUD', 'TYPE']
DEFAULT_TEMPLATE = ''
def __init__(self, cluster_spec: BaseClusterSpec):
"""Initialize BaseCluster class.
Args:
      cluster_spec: cluster.BaseClusterSpec object.
"""
super().__init__()
self.zone: str = cluster_spec.workers.vm_spec.zone
self.machine_type: str = cluster_spec.workers.vm_spec.machine_type
self.template: str = cluster_spec.template or self.DEFAULT_TEMPLATE
self.unmanaged: bool = cluster_spec.unmanaged
self.spec: BaseClusterSpec = cluster_spec
self.worker_machine_type: str = self.machine_type
self.headnode_machine_type: str = cluster_spec.headnode.vm_spec.machine_type
self.headnode_spec: virtual_machine.BaseVmSpec = (
cluster_spec.headnode.vm_spec
)
self.image: str = cluster_spec.workers.vm_spec.image
self.workers_spec: virtual_machine.BaseVmSpec = cluster_spec.workers.vm_spec
self.workers_static_disk_spec: disk.BaseDiskSpec = (
cluster_spec.workers.disk_spec
)
self.workers_static_disk: static_virtual_machine.StaticDisk | None = (
static_virtual_machine.StaticDisk(self.workers_static_disk_spec)
if self.workers_static_disk_spec
else None
)
self.os_type: str = cluster_spec.workers.os_type
self.num_workers: int = cluster_spec.workers.vm_count
self.vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
self.headnode_vm: linux_virtual_machine.BaseLinuxVirtualMachine | None = (
None
)
self.worker_vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
self.name: str = f'pkb{FLAGS.run_uri}'[:10]
    self.nfs_path: str | None = None
def GetResourceMetadata(self):
return {
'zone': self.zone,
'machine_type': self.machine_type,
'worker_machine_type': self.worker_machine_type,
'headnode_machine_type': self.headnode_machine_type,
'image': self.image,
'os_type': self.os_type,
'num_workers': self.num_workers,
'template': self.template,
'unmanaged': self.unmanaged
}
def __repr__(self):
return f'<BaseCluster [name={self.name}]>'
# TODO(yuyanting) Move common logic here after having concrete implementation.
def _RenderClusterConfig(self):
"""Render the config file that will be used to create the cluster."""
pass
def RemoteCommand(
self,
command: str,
ignore_failure: bool = False,
timeout: float | None = None,
**kwargs,
) -> Tuple[str, str]:
"""Runs a command on the VM.
Derived classes may add additional kwargs if necessary, but they should not
be used outside of the class itself since they are non standard.
Args:
command: A valid bash command.
ignore_failure: Ignore any failure if set to true.
timeout: The time to wait in seconds for the command before exiting. None
means no timeout.
**kwargs: Additional command arguments.
Returns:
A tuple of stdout and stderr from running the command.
Raises:
RemoteCommandError: If there was a problem issuing the command.
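    Example (illustrative; assumes `cluster` is a ready BaseCluster):
      # Runs `hostname` on every worker by wrapping it in srun on the head
      # node.
      stdout, _ = cluster.RemoteCommand('hostname')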
"""
return self.headnode_vm.RemoteCommand(
f'srun -N {self.num_workers} {command}',
ignore_failure=ignore_failure,
timeout=timeout,
**kwargs,
)
def RobustRemoteCommand(
self,
command: str,
timeout: float | None = None,
ignore_failure: bool = False,
) -> Tuple[str, str]:
"""Runs a command on the VM in a more robust way than RemoteCommand.
The default should be to call RemoteCommand and log that it is not yet
implemented. This function should be overwritten it is decendents.
Args:
command: The command to run.
timeout: The timeout for the command in seconds.
ignore_failure: Ignore any failure if set to true.
Returns:
A tuple of stdout, stderr from running the command.
Raises:
RemoteCommandError: If there was a problem establishing the connection, or
the command fails.
"""
return self.headnode_vm.RobustRemoteCommand(
command, ignore_failure=ignore_failure, timeout=timeout
)
def TryRemoteCommand(self, command: str, **kwargs):
"""Runs a remote command and returns True iff it succeeded."""
try:
self.RemoteCommand(command, **kwargs)
return True
except errors.VirtualMachine.RemoteCommandError:
return False
def InstantiateVm(self, vm_spec):
"""Creates VM object."""
vm_class = virtual_machine.GetVmClass(vm_spec.CLOUD, self.os_type)
return vm_class(vm_spec)
def BackfillVm(
self,
vm_spec: virtual_machine.BaseVmSpec,
fn: Callable[[virtual_machine.BaseVirtualMachine], None],
):
"""Create and backfill a VM object created using cluster resource.
Args:
vm_spec: VM spec to be used to find corresponding VM class.
fn: The function to be called on the newly created VM.
Returns:
The newly created VM object.
"""
vm = self.InstantiateVm(vm_spec)
fn(vm)
vm.disks = []
vm._PostCreate() # pylint: disable=protected-access
vm.created = True
return vm
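  # Illustrative only: a typical `fn` copies fields the cluster tooling
  # already knows onto the new VM object before _PostCreate runs. The
  # discovered_* names below are hypothetical placeholders.
  #
  #   def _Backfill(vm):
  #     vm.name = discovered_name
  #     vm.ip_address = discovered_ip
  #
  #   worker = cluster.BackfillVm(cluster.workers_spec, _Backfill)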
def AuthenticateVM(self):
"""Authenticate a remote machine to access all vms."""
for vm in self.vms:
vm.AuthenticateVm()
def ExportVmGroupsForUnmanagedProvision(self):
"""Export VmGroups for unmanaged provisioning.
Returns:
      Dictionary of VmGroupSpecs for provisioning in the BenchmarkSpec object.
"""
if not self.unmanaged:
return {}
logging.info('Provisioning cluster resources with unmanaged VM creation.')
return {
'headnode': self.spec.headnode,
'workers': self.spec.workers,
}
def ImportVmGroups(self, headnode, workers):
"""Imports VMGroups from unmanaged provision.
After VMs being created from unmanaged codepath. Add corresponding vm_groups
back to cluster object. So the benchmark hopefully do not care about
how underlying resources being created.
Args:
headnode: VirtualMachine object representing a headnode.
workers: List of VirtualMachine objects representing workers.
"""
self.headnode_vm = headnode
self.worker_vms = workers
self.vms = [self.headnode_vm] + self.worker_vms
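  # Illustrative flow only (the caller-side variable names are assumptions):
  # when `unmanaged` is set, the benchmark-side code is expected to
  #   1) provision the groups returned by ExportVmGroupsForUnmanagedProvision()
  #      through the regular VM codepath, then
  #   2) hand the created VMs back via ImportVmGroups(headnode_vm, worker_vms).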
def _Create(self):
pass
def _Delete(self):
pass
def Create(self):
if self.unmanaged:
return
super().Create()
def Delete(self):
if self.unmanaged:
return
super().Delete()
@vm_util.Retry(
fuzz=0,
timeout=1800,
max_retries=5,
retryable_exceptions=(errors.Resource.RetryableCreationError,),
)
def _WaitUntilReady(self):
if self.unmanaged:
return
if not self.headnode_vm.TryRemoteCommand(
f'srun -N {self.num_workers} hostname'
):
raise errors.Resource.RetryableCreationError('Cluster not ready.')
Cluster = typing.TypeVar('Cluster', bound=BaseCluster)
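# Illustrative use of the TypeVar (the helper below is hypothetical): it lets
# annotations preserve the concrete subclass, e.g.
#   def RecreateCluster(cluster: Cluster) -> Cluster: ...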