perfkitbenchmarker/static_virtual_machine.py (256 lines of code) (raw):
# Copyright 2014 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to represent a Static Virtual Machine object.
All static VMs provided in a given group will be used before any non-static
VMs are provisioned. For example, in a test that uses 4 VMs, if 3 static VMs
are provided, all of them will be used and one additional non-static VM
will be provisioned. The VM's should be set up with passwordless ssh and
passwordless sudo (neither sshing nor running a sudo command should prompt
the user for a password).
All VM specifics are self-contained and the class provides methods to
operate on the VM: boot, shutdown, etc.
"""
import collections
import json
import logging
import threading
from absl import flags
from perfkitbenchmarker import disk
from perfkitbenchmarker import linux_virtual_machine
from perfkitbenchmarker import os_types
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import windows_virtual_machine
FLAGS = flags.FLAGS
flags.DEFINE_list(
'static_vm_tags',
None,
'The tags of static VMs for PKB to run with. Even if other '
"VMs are specified in a config, if they aren't in this list "
'they will be skipped during VM creation.',
)
flags.DEFINE_bool(
'copy_ssh_private_keys_into_static_vms',
False,
'A flag to allow the VM to copy ssh private key to '
'authenticate static VMs.',
)
class StaticVmSpec(virtual_machine.BaseVmSpec):
"""Object containing all info needed to create a Static VM."""
CLOUD = 'Static'
def __init__(
self,
component_full_name,
ip_address=None,
user_name=None,
ssh_private_key=None,
internal_ip=None,
ssh_port=22,
password=None,
disk_specs=None,
os_type=None,
tag=None,
zone=None,
**kwargs
):
"""Initialize the StaticVmSpec object.
Args:
component_full_name: string. Fully qualified name of the configurable
component containing the config options.
ip_address: The public ip address of the VM.
user_name: The username of the VM that the keyfile corresponds to.
ssh_private_key: The absolute path to the private keyfile to use to ssh to
the VM.
internal_ip: The internal ip address of the VM.
ssh_port: The port number to use for SSH and SCP commands.
password: The password used to log into the VM (Windows Only).
disk_specs: None or a list of dictionaries containing kwargs used to
create disk.BaseDiskSpecs.
os_type: The OS type of the VM. See the flag of the same name for more
information.
tag: A string that allows the VM to be included or excluded from a run by
using the 'static_vm_tags' flag.
zone: The VM's zone.
**kwargs: Other args for the superclass.
"""
super().__init__(component_full_name, **kwargs)
self.ip_address = ip_address
self.user_name = user_name
self.ssh_private_key = ssh_private_key
self.internal_ip = internal_ip
self.ssh_port = ssh_port
self.password = password
self.os_type = os_type
self.tag = tag
self.zone = zone
self.disk_specs = [
disk.BaseDiskSpec(
'{}.disk_specs[{}]'.format(component_full_name, i),
flag_values=kwargs.get('flag_values'),
**disk_spec
)
for i, disk_spec in enumerate(disk_specs or ())
]
class StaticDisk(disk.BaseDisk):
"""Object representing a static Disk."""
def _Create(self):
"""StaticDisks don't implement _Create()."""
pass
def _Delete(self):
"""StaticDisks don't implement _Delete()."""
pass
def Attach(self, vm):
"""StaticDisks don't implement Attach()."""
pass
def Detach(self):
"""StaticDisks don't implement Detach()."""
pass
class StaticVirtualMachine(virtual_machine.BaseVirtualMachine):
"""Object representing a Static Virtual Machine."""
CLOUD = 'Static'
is_static = True
vm_pool = collections.deque()
vm_pool_lock = threading.Lock()
def __init__(self, vm_spec):
"""Initialize a static virtual machine.
Args:
vm_spec: A StaticVmSpec object containing arguments.
"""
super().__init__(vm_spec)
self.ip_address = vm_spec.ip_address
self.user_name = vm_spec.user_name
self.ssh_private_key = vm_spec.ssh_private_key
self.internal_ip = vm_spec.internal_ip
self.zone = self.zone or (
'Static - %s@%s' % (self.user_name, self.ip_address)
)
self.ssh_port = vm_spec.ssh_port
self.password = vm_spec.password
self.disk_specs = vm_spec.disk_specs
self.from_pool = False
self.preemptible = False
def _Suspend(self):
"""Suspends VM."""
raise NotImplementedError()
def _Resume(self):
"""Resumes VM."""
raise NotImplementedError()
def _Create(self):
"""StaticVirtualMachines do not implement _Create()."""
pass
# StaticVirtualMachines do not implement _Start or _Stop
def _Start(self):
"""Starts the VM."""
raise NotImplementedError()
def _Stop(self):
"""Stops the VM."""
raise NotImplementedError()
def _Delete(self):
"""Returns the virtual machine to the pool."""
if self.from_pool:
with self.vm_pool_lock:
self.vm_pool.appendleft(self)
def CreateScratchDisk(self, _, disk_spec):
"""Create a VM's scratch disk.
Args:
disk_spec: virtual_machine.BaseDiskSpec object of the disk.
"""
spec = self.disk_specs[len(self.scratch_disks)]
self.scratch_disks.append(StaticDisk(spec))
def DeleteScratchDisks(self):
"""StaticVirtualMachines do not delete scratch disks."""
pass
@classmethod
def ReadStaticVirtualMachineFile(cls, file_obj):
"""Read a file describing the static VMs to use.
This function will read the static VM information from the provided file,
instantiate VMs corresponding to the info, and add the VMs to the static
VM pool. The provided file should contain a single array in JSON-format.
Each element in the array must be an object with required format:
ip_address: string.
user_name: string.
keyfile_path: string.
ssh_port: integer, optional. Default 22
internal_ip: string, optional.
zone: string, optional.
local_disks: array of strings, optional.
scratch_disk_mountpoints: array of strings, optional
os_type: string, optional (see package_managers)
install_packages: bool, optional
Args:
file_obj: An open handle to a file containing the static VM info.
Raises:
ValueError: On missing required keys, or invalid keys.
"""
vm_arr = json.load(file_obj)
if not isinstance(vm_arr, list):
raise ValueError(
'Invalid static VM file. Expected array, got: %s.' % type(vm_arr)
)
required_keys = frozenset(['ip_address', 'user_name'])
linux_required_keys = required_keys | frozenset(['keyfile_path'])
required_keys_by_os = {
os_types.WINDOWS: required_keys | frozenset(['password']),
os_types.DEBIAN: linux_required_keys,
os_types.RHEL: linux_required_keys,
os_types.CLEAR: linux_required_keys,
}
# assume linux_required_keys for unknown os_type
required_keys = required_keys_by_os.get(FLAGS.os_type, linux_required_keys)
optional_keys = frozenset([
'internal_ip',
'zone',
'local_disks',
'scratch_disk_mountpoints',
'os_type',
'ssh_port',
'install_packages',
])
allowed_keys = required_keys | optional_keys
def VerifyItemFormat(item):
"""Verify that the decoded JSON object matches the required schema."""
item_keys = frozenset(item)
extra_keys = sorted(item_keys - allowed_keys)
missing_keys = required_keys - item_keys
if extra_keys:
raise ValueError('Unexpected keys: {}'.format(', '.join(extra_keys)))
elif missing_keys:
raise ValueError(
'Missing required keys: {}'.format(', '.join(missing_keys))
)
for item in vm_arr:
VerifyItemFormat(item)
ip_address = item['ip_address']
user_name = item['user_name']
keyfile_path = item.get('keyfile_path')
internal_ip = item.get('internal_ip')
zone = item.get('zone')
local_disks = item.get('local_disks', [])
password = item.get('password')
if not isinstance(local_disks, list):
raise ValueError(
'Expected a list of local disks, got: {}'.format(local_disks)
)
scratch_disk_mountpoints = item.get('scratch_disk_mountpoints', [])
if not isinstance(scratch_disk_mountpoints, list):
raise ValueError(
'Expected a list of disk mount points, got: {}'.format(
scratch_disk_mountpoints
)
)
ssh_port = item.get('ssh_port', 22)
os_type = item.get('os_type')
install_packages = item.get('install_packages', True)
if (
os_type == os_types.WINDOWS and FLAGS.os_type != os_types.WINDOWS
) or (os_type != os_types.WINDOWS and FLAGS.os_type == os_types.WINDOWS):
raise ValueError(
'Please only use Windows VMs when using '
'--os_type=windows and vice versa.'
)
disk_kwargs_list = []
for path in scratch_disk_mountpoints:
disk_kwargs_list.append({'mount_point': path})
for local_disk in local_disks:
disk_kwargs_list.append({'device_path': local_disk})
vm_spec = StaticVmSpec(
'static_vm_file',
ip_address=ip_address,
user_name=user_name,
ssh_port=ssh_port,
install_packages=install_packages,
ssh_private_key=keyfile_path,
internal_ip=internal_ip,
zone=zone,
disk_specs=disk_kwargs_list,
password=password,
flag_values=flags.FLAGS,
)
vm_class = GetStaticVmClass(os_type)
vm = vm_class(vm_spec) # pytype: disable=not-instantiable
cls.vm_pool.append(vm)
@classmethod
def GetStaticVirtualMachine(cls):
"""Pull a Static VM from the pool of static VMs.
If there are no VMs left in the pool, the method will return None.
Returns:
A static VM from the pool, or None if there are no static VMs left.
"""
with cls.vm_pool_lock:
if cls.vm_pool:
vm = cls.vm_pool.popleft()
vm.Create()
vm.from_pool = True
return vm
else:
return None
def GetStaticVmClass(os_type) -> type[virtual_machine.BaseVirtualMachine]:
"""Returns the static VM class that corresponds to the os_type."""
if not os_type:
os_type = os_types.DEFAULT
logging.warning('Could not find os type for VM. Defaulting to %s.', os_type)
return virtual_machine.GetVmClass(
StaticVirtualMachine.CLOUD,
os_type,
)
class Ubuntu2004BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Ubuntu2004Mixin
):
pass
class Ubuntu2204BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Ubuntu2204Mixin
):
pass
class Ubuntu2404BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Ubuntu2404Mixin
):
pass
class ClearBasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.ClearMixin
):
pass
class Rhel8BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Rhel8Mixin
):
pass
class Rhel9BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Rhel9Mixin
):
pass
class Fedora36BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Fedora36Mixin
):
pass
class Fedora37BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Fedora37Mixin
):
pass
class Debian11BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Debian11Mixin
):
pass
class Debian12BasedStaticVirtualMachine(
StaticVirtualMachine, linux_virtual_machine.Debian12Mixin
):
pass
class Windows2019SQLServer2019StandardStaticVirtualMachine(
StaticVirtualMachine,
windows_virtual_machine.Windows2019SQLServer2019Standard,
):
pass