perfkitbenchmarker/resources/managed_ai_model_spec.py (71 lines of code) (raw):
# Copyright 2024 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spec for a managed AI model resource."""
from perfkitbenchmarker import provider_info
from perfkitbenchmarker.configs import option_decoders
from perfkitbenchmarker.configs import spec
class BaseManagedAiModelSpec(spec.BaseSpec):
  """Spec for a managed AI model resource.

  GetManagedAiModelSpecClass looks up classes derived from this spec through
  spec.GetSpecClass, keyed on the CLOUD, MODEL_NAME and MODEL_SIZE class
  attributes listed in SPEC_ATTRS.

  Attributes:
    cloud: The cloud provider for the model; one of provider_info.VALID_CLOUDS.
      Defaults to provider_info.GCP. The 'cloud' flag overrides the config
      value when the flag is explicitly present, and also fills it in when the
      config omits it (see _ApplyFlags).
    interface: How the model is accessed: 'CLI', 'SDK' or 'MODEL-GARDEN-CLI'.
      Defaults to 'CLI'.
    model_name: The name of the model to use. Defaults to ''; the decoder also
      accepts None even though the attribute is annotated as str.
    model_size: The size of the model to use. Defaults to ''; the decoder also
      accepts None even though the attribute is annotated as str.
    max_scale: Maximum scale for the model. Defaults to 1, which means no
      scaling; the decoder also accepts None even though the attribute is
      annotated as int.
  """

  SPEC_TYPE = 'BaseManagedAiModelSpec'
  # Class-level attributes used as the lookup key by
  # GetManagedAiModelSpecClass; presumably overridden by cloud/model-specific
  # subclasses — confirm against spec.GetSpecClass.
  SPEC_ATTRS = ['CLOUD', 'MODEL_NAME', 'MODEL_SIZE']
  CLOUD = None
  MODEL_NAME = None
  MODEL_SIZE = None

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Declares the decoded attributes and initializes the base spec.

    Args:
      component_full_name: Fully qualified name of this config component.
      flag_values: flags.FlagValues. Runtime flags passed through to the base
        spec.
      **kwargs: Config option values passed through to spec.BaseSpec.
    """
    # Annotation-only declarations for type checkers; the values themselves
    # are presumably assigned by spec.BaseSpec.__init__ from the decoded
    # options returned by _GetOptionDecoderConstructions — verify in BaseSpec.
    self.cloud: str
    self.interface: str
    self.model_name: str
    self.model_size: str
    self.max_scale: int
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)

  @classmethod
  def _ApplyFlags(cls, config_values, flag_values):
    """Modifies config options based on runtime flag values.

    The 'cloud' flag takes precedence when it is explicitly present. Otherwise
    it is used only as a fallback when the config does not set 'cloud'.

    Args:
      config_values: dict mapping config option names to provided values. May be
        modified by this function.
      flag_values: flags.FlagValues. Runtime flags that may override the
        provided config values.
    """
    super()._ApplyFlags(config_values, flag_values)
    if flag_values['cloud'].present or 'cloud' not in config_values:
      config_values['cloud'] = flag_values.cloud

  @classmethod
  def _GetOptionDecoderConstructions(cls):
    """Gets decoder classes and constructor args for each configurable option.

    Returns:
      dict. Maps option name string to a (ConfigOptionDecoder class, dict) pair.
      The pair specifies a decoder class and its __init__() keyword arguments
      to construct in order to decode the named option.
    """
    result = super()._GetOptionDecoderConstructions()
    result.update({
        'cloud': (
            option_decoders.EnumDecoder,
            {
                'valid_values': provider_info.VALID_CLOUDS,
                'default': provider_info.GCP,
            },
        ),
        'interface': (
            option_decoders.EnumDecoder,
            {
                'valid_values': ['CLI', 'SDK', 'MODEL-GARDEN-CLI'],
                'default': 'CLI',
            },
        ),
        # model_name and model_size allow None (none_ok) but default to ''.
        'model_name': (
            option_decoders.StringDecoder,
            {
                'none_ok': True,
                'default': '',
            },
        ),
        'model_size': (
            option_decoders.StringDecoder,
            {
                'none_ok': True,
                'default': '',
            },
        ),
        'max_scale': (
            option_decoders.IntDecoder,
            {
                'none_ok': True,
                'default': 1,  # Default to 1 means no scaling.
            },
        ),
    })
    return result
def GetManagedAiModelSpecClass(
    cloud: str, model_name: str, model_size: str
) -> spec.BaseSpecMetaClass | None:
  """Gets the managed AI model spec class matching the given attributes.

  Args:
    cloud: Cloud provider name, used as the CLOUD lookup key.
    model_name: Model name, used as the MODEL_NAME lookup key.
    model_size: Model size, used as the MODEL_SIZE lookup key.

  Returns:
    The spec class that spec.GetSpecClass resolves for these attributes under
    BaseManagedAiModelSpec. What happens when no class matches (for example,
    falling back to the base class or returning None) is decided by
    spec.GetSpecClass — TODO confirm.
  """
  return spec.GetSpecClass(
      BaseManagedAiModelSpec,
      CLOUD=cloud,
      MODEL_NAME=model_name,
      MODEL_SIZE=model_size,
  )