perfkitbenchmarker/providers/gcp/vertex_ai.py (718 lines of code) (raw):

"""Implementation of a model & endpoint in Vertex AI.

Uses gcloud python libraries to manage those resources.

One time setup of service account:
- We assume the existence of a
"{PROJECT_NUMBER}-compute@developer.gserviceaccount.com" service account with
the required permissions.
- Follow instructions from
https://cloud.google.com/vertex-ai/docs/general/custom-service-account
to create it & give permissions if one doesn't exist.
"""

import json
import logging
import os
import re
import time
from typing import Any

from absl import flags
# pylint: disable=g-import-not-at-top, g-statement-before-imports
# External needs from google.cloud.
# pytype: disable=module-attr
try:
  from google.cloud.aiplatform import aiplatform
except ImportError:
  from google.cloud import aiplatform
from google.api_core import exceptions as google_exceptions
from perfkitbenchmarker import errors
from perfkitbenchmarker import resource
from perfkitbenchmarker import sample
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
from perfkitbenchmarker.providers.gcp import flags as gcp_flags
from perfkitbenchmarker.providers.gcp import gcs
from perfkitbenchmarker.providers.gcp import util
from perfkitbenchmarker.resources import managed_ai_model
from perfkitbenchmarker.resources import managed_ai_model_spec

FLAGS = flags.FLAGS

# Identifiers for the interfaces a model can be managed through.
CLI = 'CLI'
MODEL_GARDEN_CLI = 'MODEL-GARDEN-CLI'
SDK = 'SDK'

# Template for the service account name; formatted with the project number.
SERVICE_ACCOUNT_BASE = '{}-compute@developer.gserviceaccount.com'
_MODEL_DEPLOY_TIMEOUT = 60 * 60  # 1 hour


class BaseVertexAiModel(managed_ai_model.BaseManagedAiModel):
  """Represents a Vertex AI model in the model registry.

  Attributes:
    model_name: The official name of the model in Model Garden, e.g. Llama2.
    name: The name of the created model in private model registry.
    model_resource_name: The full resource name of the created model, e.g.
      projects/123/locations/us-east1/models/1234.
    region: The region, derived from the zone.
    project: The project.
    endpoint: The PKB resource endpoint the model is deployed to.
    service_account: Name of the service account used by the model.
    model_deploy_time: Time it took to deploy the model.
    model_upload_time: Time it took to upload the model.
    vm: A way to run commands on the machine.
    json_write_times: List of times it took to write the json request to disk.
    json_cache: Cache from request JSON -> JSON request file.
    gcs_bucket_copy_time: Time it took to copy the model to the GCS bucket.
    gcs_client: The GCS client used to copy the model to the GCS bucket. Only
      instantiated when no bucket is supplied through the bucket_uri argument
      or the AI_BUCKET_URI flag.
    bucket_uri: The GCS bucket where the model is stored.
    model_bucket_path: Where the model bucket is located.
    staging_bucket: The staging bucket used by the model.
  """

  CLOUD: str = 'GCP'
  INTERFACE: list[str] | str = [CLI, MODEL_GARDEN_CLI]

  endpoint: 'BaseVertexAiEndpoint'
  model_spec: 'VertexAiModelSpec'
  model_name: str
  name: str
  region: str
  project: str
  service_account: str
  model_resource_name: str | None
  model_deploy_time: float | None
  model_upload_time: float | None
  json_write_times: list[float]
  json_cache: dict[str, str]
  gcs_bucket_copy_time: float | None
  gcs_client: gcs.GoogleCloudStorageService | None
  bucket_uri: str
  model_bucket_path: str
  staging_bucket: str

  def __init__(
      self,
      vm: virtual_machine.BaseVirtualMachine,
      model_spec: managed_ai_model_spec.BaseManagedAiModelSpec,
      name: str | None = None,
      bucket_uri: str | None = None,
      **kwargs,
  ):
    """Initializes the model and its (not yet created) endpoint.

    Args:
      vm: The machine used to run gcloud commands.
      model_spec: The spec of the model; must be a VertexAiModelSpec.
      name: Name for the model. Defaults to 'pkb' + the run uri.
      bucket_uri: An existing bucket to use. When None, the AI_BUCKET_URI flag
        is used, and if that is also unset a GCS client is prepared and a
        temporary bucket name is generated.
      **kwargs: Forwarded to the parent constructor.

    Raises:
      errors.Config.InvalidValue: If model_spec is not a VertexAiModelSpec.
      errors.Setup.InvalidConfigurationError: If no project is set.
    """
    super().__init__(model_spec, vm, **kwargs)
    if not isinstance(model_spec, VertexAiModelSpec):
      raise errors.Config.InvalidValue(
          f'Invalid model spec class: "{model_spec.__class__.__name__}". '
          'Must be a VertexAiModelSpec. It had config values of '
          f'{model_spec.model_name} & {model_spec.cloud}'
      )
    self.model_spec = model_spec
    self.model_name = model_spec.model_name
    self.model_resource_name = None
    if name:
      self.name = name
    else:
      self.name = 'pkb' + FLAGS.run_uri
    self.project = FLAGS.project
    # Validate the project before anything is built from it, so that a
    # misconfigured run fails before the endpoint object is constructed.
    if not self.project:
      raise errors.Setup.InvalidConfigurationError(
          'Project is required for Vertex AI but was not set.'
      )
    self.endpoint = self._CreateEndpoint()
    self.metadata.update({
        'name': self.name,
        'model_name': self.model_name,
        'model_size': self.model_spec.model_size,
        'machine_type': self.model_spec.machine_type,
        'accelerator_type': self.model_spec.accelerator_type,
        'accelerator_count': self.model_spec.accelerator_count,
    })
    project_number = util.GetProjectNumber(self.project)
    self.service_account = SERVICE_ACCOUNT_BASE.format(project_number)
    self.model_upload_time = None
    self.model_deploy_time = None
    self.json_write_times = []
    self.json_cache = {}
    self.gcs_client = None
    # Bucket precedence: explicit argument, then flag, then a fresh bucket
    # that this resource creates (and later deletes) through gcs_client.
    if bucket_uri is not None:
      logging.warning('bucket_uri %s ', bucket_uri)
      self.bucket_uri = bucket_uri
    elif gcp_flags.AI_BUCKET_URI.value is not None:
      self.bucket_uri = gcp_flags.AI_BUCKET_URI.value
    else:
      self.gcs_client = gcs.GoogleCloudStorageService()
      self.gcs_client.PrepareService(self.region)
      self.bucket_uri = f'{self.project}-{self.region}-tmp-{self.name}'
    self.model_bucket_path = 'gs://' + os.path.join(
        self.bucket_uri, self.model_spec.model_bucket_suffix
    )
    self.staging_bucket = 'gs://' + os.path.join(self.bucket_uri, 'temporal')
    self.gcs_bucket_copy_time = None

  def _InitializeNewModel(self) -> 'BaseVertexAiModel':
    """Returns a new instance of the same class."""
    return self.__class__(
        vm=self.vm,
        model_spec=self.model_spec,
        name=self.name + '2',
        # Reuse the same bucket for the next model.
        bucket_uri=self.bucket_uri,
    )  # pytype: disable=not-instantiable

  def GetRegionFromZone(self, zone: str) -> str:
    """Returns the region for the given zone, as computed by the GCP util."""
    return util.GetRegionFromZone(zone)

  def ListExistingEndpoints(self, region: str | None = None) -> list[str]:
    """Returns a list of existing model endpoint ids in the same region.

    Args:
      region: The region to list endpoints in. Defaults to the model's region.

    Returns:
      The first column of every non-blank output row after the header row;
      empty when the command prints nothing.
    """
    if region is None:
      region = self.region
    # Expected output example:
    # ENDPOINT_ID DISPLAY_NAME
    # 12345 some_endpoint_name
    out, _, _ = self.vm.RunCommand(
        f'gcloud ai endpoints list --region={region} --project={self.project}'
    )
    # Blank lines are skipped so that splitting them cannot raise IndexError.
    rows = [line.split() for line in out.splitlines() if line.strip()]
    # The first remaining row just has titles.
    return [row[0] for row in rows[1:]]

  def GetSamples(self) -> list[sample.Sample]:
    """Gets samples relating to the provisioning of the resource.

    Returns:
      The parent's samples plus one sample for each timing that was recorded
      (model upload, model deploy, max JSON write, GCS bucket copy).
    """
    samples = super().GetSamples()
    metadata = self.GetResourceMetadata()
    if self.model_upload_time:
      samples.append(
          sample.Sample(
              'Model Upload Time',
              self.model_upload_time,
              'seconds',
              metadata,
          )
      )
    if self.model_deploy_time:
      samples.append(
          sample.Sample(
              'Model Deploy Time',
              self.model_deploy_time,
              'seconds',
              metadata,
          )
      )
    if self.json_write_times:
      samples.append(
          sample.Sample(
              'Max JSON Write Time',
              max(self.json_write_times),
              'seconds',
              metadata,
          )
      )
    if self.gcs_bucket_copy_time:
      samples.append(
          sample.Sample(
              'GCS Bucket Copy Time',
              self.gcs_bucket_copy_time,
              'seconds',
              metadata,
          )
      )
    return samples

  def _CreateEndpoint(self) -> 'BaseVertexAiEndpoint':
    """Creates the correct endpoint."""
    raise NotImplementedError(
        '_CreateEndpoint is not implemented for this model type.'
    )

  def _Create(self) -> None:
    """Creates the underlying resource.

    Uploads the model, then deploys it, recording how long each step took.
    """
    start_model_upload = time.time()
    self._UploadModel()
    end_model_upload = time.time()
    self.model_upload_time = end_model_upload - start_model_upload
    logging.info(
        'Model resource uploaded with name: %s in %s seconds',
        self.model_resource_name,
        self.model_upload_time,
    )
    start_model_deploy = time.time()
    self._DeployModel()
    end_model_deploy = time.time()
    self.model_deploy_time = end_model_deploy - start_model_deploy
    logging.info(
        'Successfully deployed model in %s seconds', self.model_deploy_time
    )

  def _UploadModel(self) -> None:
    """Uploads the model to the model registry."""
    raise NotImplementedError(
        '_UploadModel is not implemented for this model type.'
    )

  def _DeployModel(self) -> None:
    """Deploys the model to the endpoint."""
    raise NotImplementedError(
        '_DeployModel is not implemented for this model type.'
    )

  def _CreateDependencies(self):
    """Creates the endpoint & copies the model to a bucket.

    The bucket is only made and populated when this resource owns a GCS
    client, i.e. when no pre-existing bucket was supplied.
    """
    super()._CreateDependencies()
    if self.gcs_client:
      gcs_bucket_copy_start_time = time.time()
      self.gcs_client.MakeBucket(
          self.bucket_uri
      )  # pytype: disable=attribute-error
      self.gcs_client.Copy(
          self.model_spec.model_garden_bucket,
          self.model_bucket_path,
          recursive=True,
          timeout=60 * 40,
      )  # pytype: disable=attribute-error
      self.gcs_bucket_copy_time = time.time() - gcs_bucket_copy_start_time
    self.endpoint.Create()

  def Delete(self, freeze: bool = False) -> None:
    """Deletes the underlying resource & its dependencies."""
    # Normally _DeleteDependencies is called by parent after _Delete, but we
    # need to call it before.
    self._DeleteDependencies()
    super().Delete(freeze)

  def _DeleteDependencies(self):
    """Deletes the endpoint and, if this resource owns one, the bucket."""
    super()._DeleteDependencies()
    self.endpoint.Delete()
    if self.gcs_client:
      self.gcs_client.DeleteBucket(
          self.bucket_uri
      )  # pytype: disable=attribute-error


class BaseCliVertexAiModel(BaseVertexAiModel):
  """Vertex AI model with shared code between CLI & Model Garden CLI."""

  def _CreateEndpoint(self) -> 'BaseVertexAiEndpoint':
    """Returns a gcloud-CLI-managed endpoint named after this model."""
    return VertexAiCliEndpoint(
        name=self.name,
        region=self.region,
        project=self.project,
        vm=self.vm,
    )

  def _SendPrompt(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> list[str]:
    """Sends a prompt to the model and returns the response.

    The command's stdout is stripped of surrounding square brackets and split
    on commas; no further parsing is done.
    """
    out, _, _ = self.vm.RunCommand(
        self.GetPromptCommand(prompt, max_tokens, temperature, **kwargs),
    )
    responses = out.strip('[]').split(',')
    return responses

  def GetPromptCommand(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> str:
    """Returns the command to send a prompt to the model.

    The JSON request is written to a temporary file via the vm; identical
    requests reuse the previously written file. The time spent is appended to
    json_write_times on every call, including cache hits.
    """
    instances = self.model_spec.ConvertToInstances(
        prompt, max_tokens, temperature, **kwargs
    )
    instances_dict = {'instances': instances, 'parameters': {}}
    start_write_time = time.time()
    json_dump = json.dumps(instances_dict)
    if json_dump in self.json_cache:
      name = self.json_cache[json_dump]
    else:
      name = self.vm.WriteTemporaryFile(json_dump)
      self.json_cache[json_dump] = name
    end_write_time = time.time()
    write_time = end_write_time - start_write_time
    self.json_write_times.append(write_time)
    return (
        'gcloud ai endpoints predict'
        f' {self.endpoint.endpoint_name} --json-request={name}'
    )

  def _Delete(self) -> None:
    """Deletes the underlying resource."""
    logging.info('Deleting the resource: %s.', self.model_name)
    self.vm.RunCommand(
        'gcloud ai models delete'
        f' {self.model_resource_name} --region={self.region} --project={self.project}'
    )


class CliVertexAiModel(BaseCliVertexAiModel):
  """Vertex AI model using CLI ai & endpoint interface."""

  INTERFACE: str = CLI

  def _UploadModel(self) -> None:
    """Uploads the model via gcloud command.

    After uploading, lists the models and looks for a two-column row whose
    second column equals this model's name; the first column of that row is
    stored as model_resource_name.

    Raises:
      errors.Resource.CreationError: If no matching row is found.
    """
    upload_cmd = (
        f'gcloud ai models upload --display-name={self.name}'
        f' --project={self.project} --region={self.region}'
        f' --artifact-uri={self.model_bucket_path}'
    )
    if util.GetDefaultTags():
      upload_cmd += f' --labels={util.MakeFormattedDefaultTags()}'
    upload_cmd += self.model_spec.GetModelUploadCliArgs(
        model_bucket_path=self.model_bucket_path
    )
    self.vm.RunCommand(upload_cmd)
    out, _, _ = self.vm.RunCommand(
        f'gcloud ai models list --project={self.project} --region={self.region}'
    )
    lines = out.splitlines()
    for line in lines:
      pieces = line.split()
      if len(pieces) != 2:
        continue
      if pieces[1] == self.name:
        self.model_resource_name = pieces[0]
        logging.info(
            'Model resource with name %s uploaded & found with model id %s',
            self.name,
            self.model_resource_name,
        )
        return
    if not self.model_resource_name:
      raise errors.Resource.CreationError(
          'Could not find model resource with name %s' % self.name
      )

  def _DeployModel(self):
    """Deploys the model to the endpoint via gcloud command.

    Raises:
      errors.Benchmarks.QuotaFailure: If stderr reports the machine type as
        temporarily unavailable.
      errors.VmUtil.IssueCommandError: On any other command failure that is
        not the "may still be underway remotely" case.
    """
    # The CLI expects lower-case, dash-separated accelerator names.
    accelerator_type = self.model_spec.accelerator_type.lower()
    accelerator_type = accelerator_type.replace('_', '-')
    _, err, code = self.vm.RunCommand(
        f'gcloud ai endpoints deploy-model {self.endpoint.endpoint_name}'
        f' --model={self.model_resource_name} --region={self.region}'
        f' --project={self.project} --display-name={self.name}'
        f' --machine-type={self.model_spec.machine_type}'
        f' --accelerator=type={accelerator_type},count={self.model_spec.accelerator_count}'
        f' --service-account={self.service_account}'
        f' --max-replica-count={self.max_scaling}',
        ignore_failure=True,
        timeout=_MODEL_DEPLOY_TIMEOUT,
    )
    if code:
      if (
          'The operations may still be underway remotely and may still succeed'
          in err
      ):
        # The command gave up waiting but the deployment may still finish, so
        # poll readiness instead of failing.
        @vm_util.Retry(
            poll_interval=self.POLL_INTERVAL,
            fuzz=0,
            timeout=self.READY_TIMEOUT,
            retryable_exceptions=(errors.Resource.RetryableCreationError,),
        )
        def WaitUntilReady():
          if not self._IsReady():
            raise errors.Resource.RetryableCreationError('Not yet ready')

        WaitUntilReady()
      elif 'Machine type temporarily unavailable' in err:
        raise errors.Benchmarks.QuotaFailure(err)
      else:
        raise errors.VmUtil.IssueCommandError(err)


class ModelGardenCliVertexAiModel(BaseCliVertexAiModel):
  """Vertex AI model created via Model Garden CLI."""

  INTERFACE: str = MODEL_GARDEN_CLI

  def __init__(
      self,
      vm: virtual_machine.BaseVirtualMachine,
      model_spec: managed_ai_model_spec.BaseManagedAiModelSpec,
      name: str | None = None,
      bucket_uri: str | None = None,
      **kwargs,
  ):
    """Initializes the model exactly like the parent, minus the GCS client."""
    super().__init__(vm, model_spec, name, bucket_uri, **kwargs)
    # GCS client is not needed by Model Garden CLI.
    self.gcs_client = None

  def _Create(self) -> None:
    """Creates the underlying resource.

    Runs the Model Garden deploy command, extracts the operation id from its
    stderr, then describes that operation to find the model id and the
    endpoint name.

    Raises:
      errors.Resource.CreationError: If the operation id, model id or endpoint
        cannot be found in the command output.
    """
    deploy_start_time = time.time()
    deploy_cmd = (
        'gcloud beta ai model-garden models deploy'
        f' --model={self.model_spec.GetModelGardenName()}'
        f' --endpoint-display-name={self.name}'
        f' --project={self.project} --region={self.region}'
        f' --machine-type={self.model_spec.machine_type}'
    )
    _, err_out, _ = self.vm.RunCommand(
        deploy_cmd, timeout=_MODEL_DEPLOY_TIMEOUT
    )
    deploy_end_time = time.time()
    self.model_deploy_time = deploy_end_time - deploy_start_time
    operation_id = _FindRegexInOutput(
        err_out,
        r'gcloud ai operations describe (.*) --region',
        errors.Resource.CreationError,
    )
    out, _, _ = self.vm.RunCommand(
        'gcloud ai operations describe'
        f' {operation_id} --project={self.project} --region={self.region}'
    )
    # Only get the model id, not the full resource name.
    self.model_resource_name = _FindRegexInOutput(
        out,
        r'model:'
        rf' projects/(.*)/locations/{self.region}/models/([^(@|\n)]*)(@|\n)',
        exception_type=errors.Resource.CreationError,
        group_index=2,
    )
    self.endpoint.endpoint_name = _FindRegexInOutput(
        out,
        r'endpoint: (.*)\n',
        exception_type=errors.Resource.CreationError,
    )
    logging.info(
        'Model resource with name %s deployed & found with model id %s &'
        ' endpoint id %s',
        self.name,
        self.model_resource_name,
        self.endpoint.endpoint_name,
    )

  def _PostCreate(self):
    """Runs the parent's post-create step, then labels the endpoint."""
    super()._PostCreate()
    self.endpoint.UpdateLabels()

  def _CreateDependencies(self):
    """Does not create any dependencies.

    Skips the direct parent call (and its creating of the endpoint), but still
    calls the grandparent managed_ai_model._CreateDependencies to add metadata.
    """
    super(BaseVertexAiModel, self)._CreateDependencies()


class VertexAiPythonSdkModel(BaseVertexAiModel):
  """Vertex AI model managed via python SDK.

  Attributes:
    gcloud_model: Representation of the model in gcloud python library.
  """

  INTERFACE: str = SDK

  endpoint: 'VertexAiSdkEndpoint'
  gcloud_model: aiplatform.Model

  def _CreateEndpoint(self) -> 'VertexAiSdkEndpoint':
    """Returns an SDK-managed endpoint named after this model."""
    return VertexAiSdkEndpoint(
        name=self.name,
        region=self.region,
        project=self.project,
        vm=self.vm,
    )

  def _SendPrompt(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> list[str]:
    """Sends a prompt to the model and returns the response.

    Returns:
      Each prediction of the SDK response, converted with str().
    """
    instances = self.model_spec.ConvertToInstances(
        prompt, max_tokens, temperature, **kwargs
    )
    assert self.endpoint.ai_endpoint
    response = self.endpoint.ai_endpoint.predict(instances=instances)  # pytype: disable=attribute-error
    str_responses = [str(prediction) for prediction in response.predictions]
    return str_responses

  def _UploadModel(self):
    """Uploads the model through the SDK and records its resource name."""
    env_vars = self.model_spec.GetEnvironmentVariables(
        model_bucket_path=self.model_bucket_path
    )
    logging.info(
        'Uploading ai model %s with env vars %s', self.model_name, env_vars
    )
    self.gcloud_model = aiplatform.Model.upload(
        display_name=self.name,
        serving_container_image_uri=self.model_spec.container_image_uri,
        serving_container_command=self.model_spec.serving_container_command,
        serving_container_args=self.model_spec.serving_container_args,
        serving_container_ports=self.model_spec.serving_container_ports,
        serving_container_predict_route=self.model_spec.serving_container_predict_route,
        serving_container_health_route=self.model_spec.serving_container_health_route,
        serving_container_environment_variables=env_vars,
        artifact_uri=self.model_bucket_path,
        labels=util.GetDefaultTags(),
    )
    self.model_resource_name = self.gcloud_model.resource_name

  def _DeployModel(self):
    """Deploys the uploaded model to the endpoint through the SDK.

    Raises:
      errors.Benchmarks.QuotaFailure: If the SDK reports ServiceUnavailable;
        the original exception is chained as the cause.
    """
    try:
      assert self.gcloud_model
      self.gcloud_model.deploy(
          endpoint=self.endpoint.ai_endpoint,
          machine_type=self.model_spec.machine_type,
          accelerator_type=self.model_spec.accelerator_type,
          accelerator_count=self.model_spec.accelerator_count,
          deploy_request_timeout=1800,
          max_replica_count=self.max_scaling,
      )
    except google_exceptions.ServiceUnavailable as ex:
      logging.info('Tried to deploy model but got unavailable error %s', ex)
      raise errors.Benchmarks.QuotaFailure(ex) from ex

  def _CreateDependencies(self):
    """Creates the endpoint & copies the model to a bucket."""
    # The SDK must be initialized before the endpoint is created through it.
    aiplatform.init(
        project=self.project,
        location=self.region,
        staging_bucket=self.staging_bucket,
        service_account=self.service_account,
    )
    super()._CreateDependencies()

  def Delete(self, freeze: bool = False) -> None:
    """Deletes the underlying resource & its dependencies."""
    # Normally _DeleteDependencies is called by parent after _Delete, but we
    # need to call it before.
    self._DeleteDependencies()
    super().Delete(freeze)

  def _Delete(self) -> None:
    """Deletes the underlying resource."""
    logging.info('Deleting the resource: %s.', self.model_name)
    assert self.gcloud_model
    self.gcloud_model.delete()

  def __getstate__(self):
    """Override pickling as the AI platform objects are not picklable."""
    to_pickle_dict = {
        'name': self.name,
        'model_name': self.model_name,
        'model_bucket_path': self.model_bucket_path,
        'region': self.region,
        'project': self.project,
        'service_account': self.service_account,
        'model_upload_time': self.model_upload_time,
        'model_deploy_time': self.model_deploy_time,
        'model_spec': self.model_spec,
    }
    return to_pickle_dict

  def __setstate__(self, pickled_dict):
    """Override pickling as the AI platform objects are not picklable."""
    self.name = pickled_dict['name']
    self.model_name = pickled_dict['model_name']
    self.model_bucket_path = pickled_dict['model_bucket_path']
    self.region = pickled_dict['region']
    self.project = pickled_dict['project']
    self.service_account = pickled_dict['service_account']
    self.model_upload_time = pickled_dict['model_upload_time']
    self.model_deploy_time = pickled_dict['model_deploy_time']
    self.model_spec = pickled_dict['model_spec']


class BaseVertexAiEndpoint(resource.BaseResource):
  """Represents a Vertex AI endpoint independent of interface.

  Attributes:
    name: The name of the endpoint.
    project: The project.
    region: The region, derived from the zone.
    endpoint_name: The full resource name of the created endpoint, e.g.
      projects/123/locations/us-east1/endpoints/1234.
  """

  def __init__(
      self,
      name: str,
      project: str,
      region: str,
      vm: virtual_machine.BaseVirtualMachine,
      **kwargs,
  ):
    """Stores the endpoint's identifying fields; endpoint_name starts unset."""
    super().__init__(**kwargs)
    self.name = name
    self.project = project
    self.region = region
    self.vm = vm
    self.endpoint_name = None

  def UpdateLabels(self) -> None:
    """Updates the labels of the endpoint."""
    pass


class VertexAiCliEndpoint(BaseVertexAiEndpoint):
  """Vertex AI endpoint managed via gcloud CLI."""

  def _Create(self) -> None:
    """Creates the underlying resource.

    Raises:
      errors.VmUtil.IssueCommandError: If the created endpoint's name cannot
        be found in the command's stderr.
    """
    logging.info('Creating the endpoint: %s.', self.name)
    _, err, _ = self.vm.RunCommand(
        f'gcloud ai endpoints create --display-name={self.name}-endpoint'
        f' --project={self.project} --region={self.region}'
        f' --labels={util.MakeFormattedDefaultTags()}',
        ignore_failure=True,
    )
    # The endpoint name is parsed from stderr.
    self.endpoint_name = _FindRegexInOutput(
        err, r'Created Vertex AI endpoint: (.+)\.'
    )
    if not self.endpoint_name:
      raise errors.VmUtil.IssueCommandError(
          f'Could not find endpoint name in output {err}.'
      )
    logging.info('Successfully created endpoint %s', self.endpoint_name)

  def _Delete(self) -> None:
    """Deletes the underlying resource.

    Undeploys the deployed model found in the endpoint description (if any)
    before deleting the endpoint itself.

    Raises:
      errors.VmUtil.IssueCommandError: If the description lists deployed
        models but no model id can be parsed from it.
    """
    logging.info('Deleting the endpoint: %s.', self.name)
    out, _, _ = self.vm.RunCommand(
        f'gcloud ai endpoints describe {self.endpoint_name}',
    )
    model_id = _FindRegexInOutput(out, r' id: \'(.+)\'')
    if model_id:
      self.vm.RunCommand(
          'gcloud ai endpoints undeploy-model'
          f' {self.endpoint_name} --deployed-model-id={model_id} --quiet',
      )
    else:
      if 'deployedModels:' not in out:
        logging.info(
            'No deployed models found; perhaps they failed to deploy or were'
            ' already deleted?'
        )
      else:
        raise errors.VmUtil.IssueCommandError(
            'Found deployed models but Could not find model id in'
            f' output.\n{out}'
        )
    self.vm.RunCommand(
        f'gcloud ai endpoints delete {self.endpoint_name} --quiet'
    )

  def UpdateLabels(self) -> None:
    """Updates the labels of the endpoint."""
    self.vm.RunCommand(
        f'gcloud ai endpoints update {self.endpoint_name} '
        f' --project={self.project} --region={self.region}'
        f' --update-labels={util.MakeFormattedDefaultTags()}',
    )


class VertexAiSdkEndpoint(BaseVertexAiEndpoint):
  """Represents a Vertex AI endpoint managed by the SDK.

  Attributes:
    ai_endpoint: The AIPlatform object representing the endpoint.
  """

  INTERFACE = SDK

  ai_endpoint: aiplatform.Endpoint | None

  def _Create(self) -> None:
    """Creates the underlying resource."""
    logging.info('Creating the endpoint: %s.', self.name)
    self.ai_endpoint = aiplatform.Endpoint.create(
        display_name=f'{self.name}-endpoint'
    )

  def _Delete(self) -> None:
    """Deletes the underlying resource."""
    logging.info('Deleting the endpoint: %s.', self.name)
    assert self.ai_endpoint
    self.ai_endpoint.delete(force=True)
    self.ai_endpoint = None  # Object is not picklable - none it out


def _FindRegexInOutput(
    output: str,
    regex: str,
    exception_type: type[errors.Error] | None = None,
    group_index: int = 1,
) -> str | None:
  """Returns the 1st match of the regex in the output.

  Args:
    output: The output to search.
    regex: The regex to search for.
    exception_type: The exception type to raise if no match is found.
    group_index: If there are multiple groups in the regex, which one to
      return.

  Returns:
    The requested group of the first match, or None when nothing matches and
    no exception_type was given.

  Raises:
    exception_type: If given and the regex does not match.
  """
  matches = re.search(regex, output)
  if not matches:
    if exception_type:
      raise exception_type(
          f'Could not find match for regex {regex} in output {output}.'
      )
    return None
  return matches.group(group_index)


class VertexAiModelSpec(managed_ai_model_spec.BaseManagedAiModelSpec):
  """Spec for a Vertex AI model.

  Attributes:
    env_vars: Environment variables set on the node.
    serving_container_command: Command run on container to start the model.
    serving_container_args: The arguments passed to container create.
    serving_container_ports: The ports to expose for the model.
    serving_container_predict_route: The route to use for prediction requests.
    serving_container_health_route: The route to use for health checks.
    machine_type: The machine type for model's cluster.
    accelerator_type: The type of the GPU/TPU.
    model_bucket_suffix: Suffix with the particular version of the model (eg
      7b)
    model_garden_bucket: The bucket in Model Garden to copy from.
  """

  CLOUD = 'GCP'

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Declares the spec's fields; subclasses assign the actual values."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    # The pre-built serving docker images.
    self.container_image_uri: str
    self.model_bucket_suffix: str
    self.model_garden_bucket: str
    self.serving_container_command: list[str]
    self.serving_container_args: list[str]
    self.serving_container_ports: list[int]
    self.serving_container_predict_route: str
    self.serving_container_health_route: str
    self.machine_type: str
    self.accelerator_count: int
    self.accelerator_type: str

  def GetModelUploadCliArgs(self, **input_args) -> str:
    """Returns the kwargs needed to upload the model.

    Args:
      **input_args: Forwarded to GetEnvironmentVariables.

    Returns:
      A string of container-related CLI flags, starting with a space so it
      can be appended directly to an upload command.
    """
    env_vars = self.GetEnvironmentVariables(**input_args)
    env_vars_str = ','.join(f'{key}={value}' for key, value in env_vars.items())
    ports_str = ','.join(str(port) for port in self.serving_container_ports)
    return (
        f' --container-image-uri={self.container_image_uri}'
        f' --container-command={",".join(self.serving_container_command)}'
        f' --container-args={",".join(self.serving_container_args)}'
        f' --container-ports={ports_str}'
        f' --container-predict-route={self.serving_container_predict_route}'
        f' --container-health-route={self.serving_container_health_route}'
        f' --container-env-vars={env_vars_str}'
    )

  def GetModelDeployKwargs(self) -> dict[str, Any]:
    """Returns the kwargs needed to deploy the model."""
    return {
        'machine_type': self.machine_type,
        'accelerator_type': self.accelerator_type,
        'accelerator_count': self.accelerator_count,
    }

  def GetEnvironmentVariables(self, **kwargs) -> dict[str, str]:
    """Returns container's environment variables needed by Llama2.

    Args:
      **kwargs: Must contain 'model_bucket_path'.

    Raises:
      KeyError: If 'model_bucket_path' is not supplied.
    """
    return {
        'MODEL_ID': kwargs['model_bucket_path'],
        'DEPLOY_SOURCE': 'pkb',
    }

  def ConvertToInstances(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> list[dict[str, Any]]:
    """Converts input to the form expected by the model.

    Returns:
      A single-element list holding a dict with the prompt, max_tokens and
      temperature, plus any of top_p, top_k and raw_response found in kwargs.
    """
    instances = {
        'prompt': prompt,
        'max_tokens': max_tokens,
        'temperature': temperature,
    }
    for params in ['top_p', 'top_k', 'raw_response']:
      if params in kwargs:
        instances[params] = kwargs[params]
    return [instances]

  def GetModelGardenName(self) -> str:
    """Returns the name of the model in Model Garden."""
    return ''


class VertexAiLlama2Spec(VertexAiModelSpec):
  """Spec for running the Llama2 7b & 70b models."""

  MODEL_NAME: str = 'llama2'
  MODEL_SIZE: list[str] = ['7b', '70b']
  VLLM_ARGS = [
      '--host=0.0.0.0',
      '--port=7080',
      '--swap-space=16',
      '--gpu-memory-utilization=0.9',
      '--max-model-len=1024',
      '--max-num-batched-tokens=4096',
  ]

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Fills in the container, bucket and machine settings for Llama2."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    # The pre-built serving docker images.
    self.container_image_uri = 'us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240222_0916_RC00'
    self.serving_container_command = [
        'python',
        '-m',
        'vllm.entrypoints.api_server',
    ]
    size_suffix = os.path.join('llama2', f'llama2-{self.model_size}-hf')
    self.model_garden_bucket = os.path.join(
        'gs://vertex-model-garden-public-us-central1', size_suffix
    )
    self.model_bucket_suffix = size_suffix
    self.serving_container_ports = [7080]
    self.serving_container_predict_route = '/generate'
    self.serving_container_health_route = '/ping'
    # Machine type from deployment notebook:
    # https://pantheon.corp.google.com/vertex-ai/colab/notebooks?e=13802955
    if self.model_size == '7b':
      self.machine_type = 'g2-standard-12'
      self.accelerator_count = 1
    else:
      self.machine_type = 'g2-standard-96'
      self.accelerator_count = 8
    self.accelerator_type = 'NVIDIA_L4'
    # Copy so the shared class-level list is never mutated.
    self.serving_container_args = self.VLLM_ARGS.copy()
    self.serving_container_args.append(
        f'--tensor-parallel-size={self.accelerator_count}'
    )

  def GetModelGardenName(self) -> str:
    """Returns the name of the model in Model Garden."""
    return f'meta/llama2@llama-2-{self.model_size}'


class VertexAiLlama3Spec(VertexAiModelSpec):
  """Spec for running the Llama3 8b & 70b model."""

  MODEL_NAME: str = 'llama3'
  MODEL_SIZE: list[str] = ['8b', '70b']
  # NOTE(review): --port=8080 here differs from serving_container_ports
  # ([7080]) set in __init__ - confirm which port the container listens on.
  VLLM_ARGS = [
      '--host=0.0.0.0',
      '--port=8080',
      '--swap-space=16',
      '--gpu-memory-utilization=0.9',
      '--max-model-len=1024',
      '--dtype=auto',
      '--max-loras=1',
      '--max-cpu-loras=8',
      '--max-num-seqs=256',
      '--disable-log-stats',
  ]

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Fills in the container, bucket and machine settings for Llama3."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    # The pre-built serving docker images.
    self.container_image_uri = 'us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241001_0916_RC00'
    self.serving_container_command = [
        'python',
        '-m',
        'vllm.entrypoints.api_server',
    ]
    # Derive the folder from the configured size (as the Llama2 spec does) so
    # that a 70b run does not silently point at the 8b weights.
    size_suffix = os.path.join('llama3', f'llama3-{self.model_size}-hf')
    self.model_garden_bucket = os.path.join(
        'gs://vertex-model-garden-public-us', size_suffix
    )
    self.model_bucket_suffix = size_suffix
    self.serving_container_ports = [7080]
    self.serving_container_predict_route = '/generate'
    self.serving_container_health_route = '/ping'
    # Machine type from deployment notebook:
    # https://pantheon.corp.google.com/vertex-ai/publishers/meta/model-garden/llama3
    if self.model_size == '8b':
      self.machine_type = 'g2-standard-12'
      self.accelerator_count = 1
    else:
      self.machine_type = 'g2-standard-96'
      self.accelerator_count = 8
    self.accelerator_type = 'NVIDIA_L4'
    # Copy so the shared class-level list is never mutated.
    self.serving_container_args = self.VLLM_ARGS.copy()
    self.serving_container_args.append(
        f'--tensor-parallel-size={self.accelerator_count}'
    )

  def GetModelUploadCliArgs(self, **input_args) -> str:
    """Returns the kwargs needed to upload the model.

    Extends the parent's flags with a shared memory size and a deployment
    timeout for the container.
    """
    upload_args = super().GetModelUploadCliArgs(**input_args)
    upload_args += (
        f' --container-shared-memory-size-mb={16 * 1024}'  # 16GB
        f' --container-deployment-timeout-seconds={60 * 40}'
    )
    return upload_args

  def GetModelGardenName(self) -> str:
    """Returns the name of the model in Model Garden."""
    return f'meta/llama3@meta-llama-3-{self.model_size}'