perfkitbenchmarker/linux_benchmarks/ai_model_create_benchmark.py

# Copyright 2024 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark to measure the creation time of a managed AI Model."""

import logging
from typing import Any

from absl import flags
from perfkitbenchmarker import benchmark_spec as bm_spec
from perfkitbenchmarker import configs
from perfkitbenchmarker import errors
from perfkitbenchmarker import sample
from perfkitbenchmarker.resources import managed_ai_model

BENCHMARK_NAME = 'ai_model_create'
BENCHMARK_CONFIG = """
ai_model_create:
  description: >
    Times creation of a managed AI model.
  ai_model:
    model_name: 'llama2'
    model_size: '7b'
    cloud: 'GCP'
  vm_groups:
    clients:
      vm_spec: *default_dual_core
      vm_count: 1
  flags:
    gcloud_scopes: cloud-platform
"""

_CREATE_SECOND_MODEL = flags.DEFINE_boolean(
    'create_second_model',
    False,
    'Whether to create & benchmark a second model in addition to the first.',
)

_VALIDATE_EXISTING_MODELS = flags.DEFINE_boolean(
    'validate_existing_models',
    False,
    'Whether to fail the benchmark if there are other models in the region.',
)


def GetConfig(user_config: dict[Any, Any]) -> dict[Any, Any]:
  """Loads and returns the benchmark config.

  Args:
    user_config: User-supplied configuration (flags and config file).

  Returns:
    Loaded benchmark configuration.
  """
  return configs.LoadConfig(BENCHMARK_CONFIG, user_config, BENCHMARK_NAME)


def Prepare(benchmark_spec: bm_spec.BenchmarkSpec):
  """No-op; the AI model resource is provisioned via benchmark_spec."""
  del benchmark_spec


def _ValidateExistingModels(
    ai_model: managed_ai_model.BaseManagedAiModel, expected_count: int
) -> int:
  """Validates that no other models are running in the region."""
  endpoints = ai_model.ListExistingEndpoints()
  # The presence of other models in a region changes startup performance.
  if len(endpoints) != expected_count:
    message = (
        f'Expected {expected_count} model(s) but found all these models:'
        f' {endpoints}.'
    )
    if _VALIDATE_EXISTING_MODELS.value:
      raise errors.Benchmarks.KnownIntermittentError(message)
    else:
      message += ' Continuing benchmark as validate_existing_models is False.'
      logging.warning(message)
  return len(endpoints)


def Run(benchmark_spec: bm_spec.BenchmarkSpec) -> list[sample.Sample]:
  """Runs the benchmark, sending prompts to each created model.

  Args:
    benchmark_spec: The benchmark specification. Contains all data that is
      required to run the benchmark.

  Returns:
    A list of sample.Sample instances.
  """
  logging.info('Running Run phase & gathering response times for model 1')
  model1 = benchmark_spec.ai_model
  assert model1
  num_endpoints = _ValidateExistingModels(model1, 1)
  SendPromptsForModel(model1)
  if not _CREATE_SECOND_MODEL.value:
    logging.info('Only benchmarking one model by flag; returning')
    return []
  if num_endpoints != 1:
    logging.warning(
        'Not creating a second model as there were already other models in'
        ' the region before the first one this benchmark created. Ending'
        ' benchmark with only one set of results.'
    )
    return []
  logging.info('Creating model 2 & gathering response times')
  model2 = model1.InitializeNewModel()
  model2.Create()
  benchmark_spec.resources.append(model2)
  SendPromptsForModel(model2)
  # All resource samples are gathered by benchmark_spec automatically.
  return []


def SendPromptsForModel(
    ai_model: managed_ai_model.BaseManagedAiModel,
):
  """Sends a fixed set of prompts to the model."""
  _SendPrompt(ai_model, 'Why do crabs walk sideways?')
  _SendPrompt(ai_model, 'How can I save more money each month?')


def _SendPrompt(
    ai_model: managed_ai_model.BaseManagedAiModel,
    prompt: str,
):
  """Sends a prompt to the model and logs the response."""
  responses = ai_model.SendPrompt(
      prompt=prompt, max_tokens=512, temperature=0.8
  )
  for response in responses:
    logging.info('Sent request & got response: %s', response)


def Cleanup(benchmark_spec: bm_spec.BenchmarkSpec):
  """Cleans up resources to their original state.

  Args:
    benchmark_spec: The benchmark specification. Contains all data that is
      required to run the benchmark.
  """
  logging.info('Running Cleanup phase of the benchmark')
  del benchmark_spec
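

# Usage sketch (not part of the benchmark module itself): this file is driven
# by the standard PerfKitBenchmarker entry point, and the two flags defined
# above can be toggled on the command line. The exact invocation depends on
# your checkout and cloud credentials; this is an illustrative example only.
#
#   ./pkb.py --benchmarks=ai_model_create \
#       --create_second_model --validate_existing_models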