perfkitbenchmarker/providers/aws/aws_jump_start.py (188 lines of code) (raw):
"""Implementation of a model using Jumpstart on AWS Sagemaker.
Uses the Amazon SageMaker Python library to deploy & manage that model.
One time step:
- Create a SageMaker execution role (ARN) for your account. See here for details:
https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-geospatial-roles-create-execution-role.html
"""
import json
import logging
import re
from typing import Any
from absl import flags
from perfkitbenchmarker import errors
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker.providers.aws import util
from perfkitbenchmarker.resources import managed_ai_model
from perfkitbenchmarker.resources import managed_ai_model_spec
FLAGS = flags.FLAGS
# Template for the IAM role ARN the model runs as. {account_number} is filled
# in with the AWS account id when a JumpStartModelInRegistry is constructed.
EXECUTION_ARN_BASE = 'arn:aws:iam::{account_number}:role/sagemaker-full-access'
# Name of the helper script that is prepared on the client VM and invoked for
# the create / prompt / delete operations.
# File located at google3/third_party/py/perfkitbenchmarker/scripts/
AWS_RUNNER_SCRIPT = 'aws_jump_start_runner.py'
class JumpStartModelInRegistry(managed_ai_model.BaseManagedAiModel):
  """Represents a SageMaker JumpStart model deployed to an endpoint on AWS.

  Creating, prompting and deleting the model are delegated to a helper python
  script that is run on the client VM, so that this module does not have to
  import the AWS libraries itself.

  Attributes:
    model_spec: The spec this model was constructed from.
    model_name: Initially the model name configured in the spec, e.g. llama2.
      After creation it is overwritten with the "Model name" reported by the
      helper script.
    model_id: The JumpStart model id passed to the helper script.
    model_version: The JumpStart model version passed to the helper script.
    name: The name of the created model in private model registry.
    account_id: The AWS account id.
    endpoint_name: The name of the deployed endpoint, if initialized.
    execution_arn: The role the model uses to run.
    vm: A vm to run commands on.
    python_script: The path to the helper python script.
  """

  CLOUD = 'AWS'
  INTERFACE = ['SDK']

  model_spec: 'JumpStartModelSpec'
  model_name: str
  model_id: str
  model_version: str
  name: str
  account_id: str
  endpoint_name: str | None
  execution_arn: str
  python_script: str

  def __init__(
      self,
      vm: virtual_machine.BaseVirtualMachine,
      model_spec: managed_ai_model_spec.BaseManagedAiModelSpec,
      **kwargs,
  ):
    """Initializes the model from its spec.

    Args:
      vm: The client VM used to run the AWS CLI and the helper script.
      model_spec: The spec describing the model; must be a JumpStartModelSpec.
      **kwargs: Forwarded to the base class constructor.

    Raises:
      errors.Config.InvalidValue: If model_spec is not a JumpStartModelSpec.
    """
    super().__init__(model_spec, vm, **kwargs)
    if not isinstance(model_spec, JumpStartModelSpec):
      raise errors.Config.InvalidValue(
          f'Invalid model spec class: "{model_spec.__class__.__name__}". '
          'Must be a JumpStartModelSpec. It had config values of '
          f'{model_spec.model_name} & {model_spec.cloud}'
      )
    self.model_spec = model_spec
    self.model_name = model_spec.model_name
    self.model_id = model_spec.model_id
    self.model_version = model_spec.model_version
    self.account_id = util.GetAccount()
    self.execution_arn = EXECUTION_ARN_BASE.format(
        account_number=self.account_id
    )
    # Only known once _Create has parsed it from the helper script output.
    self.endpoint_name = None
    self.metadata.update({
        'model_name': self.model_name,
        'model_size': self.model_spec.model_size,
    })
    # Set in _CreateDependencies once the script is available on the VM.
    self.python_script = ''

  def _InitializeNewModel(self) -> 'JumpStartModelInRegistry':
    """Returns a new instance of the same class."""
    return self.__class__(vm=self.vm, model_spec=self.model_spec)

  def GetRegionFromZone(self, zone: str) -> str:
    """Returns the AWS region that the given zone belongs to."""
    return util.GetRegionFromZone(zone)

  def ListExistingEndpoints(self, region: str | None = None) -> list[str]:
    """Returns list of endpoint names.

    Args:
      region: The region to list endpoints in. Defaults to this model's region.

    Returns:
      The EndpointName of every endpoint reported by the AWS CLI.
    """
    if region is None:
      region = self.region
    out, _, _ = self.vm.RunCommand(
        ['aws', 'sagemaker', 'list-endpoints', f'--region={region}']
    )
    return [
        endpoint['EndpointName'] for endpoint in json.loads(out)['Endpoints']
    ]

  def _RunPythonScript(self, args: list[str]) -> tuple[str, str]:
    """Calls the on-client-vm python script with appropriate arguments.

    We do this rather than just run the python code in this file to avoid
    importing the AWS libraries.

    Args:
      args: Additional arguments for the python script.

    Returns:
      Tuple of [stdout, stderr].
    """
    # Failures are not raised here; callers inspect stdout / stderr instead.
    out, err, _ = self.vm.RunCommand(
        self._GetPythonScriptCommand(args),
        raise_on_failure=False,
        timeout=60 * 30,
        stack_level=2,
    )
    return out, err

  def _GetPythonScriptCommand(self, args: list[str]) -> str:
    """Returns the command to run the python script with the given args."""
    # When run without the region variable, get the error:
    # "ARN should be scoped to correct region: us-west-2"
    return (
        f'export AWS_DEFAULT_REGION={self.region} && '
        # These arguments are needed for all operations.
        'python3'
        f' {self.python_script} --region={self.region} --model_id={self.model_id} '
        f'--model_version={self.model_version} ' + ' '.join(args)
    )

  def _GetPromptArgs(
      self, prompt: str, max_tokens: int, temperature: float
  ) -> list[str]:
    """Returns the helper script arguments for a prompt operation.

    Shared by _SendPrompt and GetPromptCommand so both always build the same
    command line.

    Args:
      prompt: The prompt text to send.
      max_tokens: Value for the script's --max_tokens flag.
      temperature: Value for the script's --temperature flag.
    """
    # NOTE(review): the prompt is interpolated into the command line unquoted;
    # presumably callers pass an already shell-safe value - confirm, or quote
    # it here (e.g. with shlex.quote).
    return [
        '--operation=prompt',
        f'--endpoint_name={self.endpoint_name}',
        f'--prompt={prompt}',
        f'--max_tokens={max_tokens}',
        f'--temperature={temperature}',
    ]

  def _SendPrompt(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> list[str]:
    """Sends a prompt to the model and returns the response.

    Args:
      prompt: The prompt text to send.
      max_tokens: Value for the script's --max_tokens flag.
      temperature: Value for the script's --temperature flag.
      **kwargs: Unused.

    Returns:
      A single-element list holding the text found between the 'Response>>>>'
      and '====' markers in the script's stdout.

    Raises:
      errors.Resource.GetError: If the markers are not found in stdout.
    """
    out, err = self._RunPythonScript(
        self._GetPromptArgs(prompt, max_tokens, temperature)
    )
    # DOTALL lets the captured response span multiple lines.
    matches = re.search('Response>>>>(.*)====', out, flags=re.DOTALL)
    if not matches:
      raise errors.Resource.GetError(
          'Could not find response in endpoint call stdout.\nStdout:'
          f' {out}\nStderr:{err}',
      )
    return [matches.group(1)]

  def GetPromptCommand(
      self, prompt: str, max_tokens: int, temperature: float, **kwargs: Any
  ) -> str:
    """Returns the command that sends the given prompt, without running it.

    Args:
      prompt: The prompt text to send.
      max_tokens: Value for the script's --max_tokens flag.
      temperature: Value for the script's --temperature flag.
      **kwargs: Unused.
    """
    return self._GetPythonScriptCommand(
        self._GetPromptArgs(prompt, max_tokens, temperature)
    )

  def _Create(self) -> None:
    """Creates the underlying resource.

    Runs the helper script's create operation and records the endpoint and
    model names it prints.

    Raises:
      errors.Resource.CreationError: If either name is missing from the
        script's stdout.
    """
    logging.info('Creating Jump Start Model: %s', self.model_id)
    out, err = self._RunPythonScript(
        ['--operation=create', f'--role={self.execution_arn}']
    )
    # TODO(user): Handle errors rather than swallowing them.
    # Unfortunately even a correct run gives some errors.

    def _FindNameMatch(out: str, resource_type: str) -> str:
      """Finds the name of the resource in the output of the python script."""
      # The script prints names as "<resource type>: <name>".
      matches = re.search(f'{resource_type}: <(.+?)>', out)
      if not matches:
        raise errors.Resource.CreationError(
            f'Could not find {resource_type} in python create output.\nStdout:'
            f' {out}\nStderr:{err}',
        )
      return matches.group(1)

    self.endpoint_name = _FindNameMatch(out, 'Endpoint name')
    # From here on model_name is the created model's name, not the spec's.
    self.model_name = _FindNameMatch(out, 'Model name')

  def _PostCreate(self) -> None:
    """Adds tags & metadata after creation timing."""
    self._AddTags('endpoint', self.endpoint_name)
    self._AddTags('model', self.model_name)
    describe_cmd = (
        f'aws sagemaker describe-endpoint --region={self.region} '
        f'--endpoint-name={self.endpoint_name}'
    )
    describe_out, _, _ = self.vm.RunCommand(describe_cmd)
    describe_json = json.loads(describe_out)
    # The instance type lives on the endpoint config, not the endpoint itself.
    endpoint_config_name = describe_json['EndpointConfigName']
    describe_endpoint_config_cmd = (
        f'aws sagemaker describe-endpoint-config --region={self.region} '
        f'--endpoint-config-name={endpoint_config_name}'
    )
    describe_endpoint_config_out, _, _ = self.vm.RunCommand(
        describe_endpoint_config_cmd
    )
    describe_endpoint_config_json = json.loads(describe_endpoint_config_out)
    self.metadata.update({
        'machine_type': describe_endpoint_config_json['ProductionVariants'][0][
            'InstanceType'
        ]
    })

  def _AddTags(self, resource_type: str, resource_name: str) -> None:
    """Adds tags to the resource with the given type & name.

    Args:
      resource_type: The resource type segment of the SageMaker ARN, e.g.
        'endpoint' or 'model'.
      resource_name: The name of the resource to tag.
    """
    arn = (
        f'arn:aws:sagemaker:{self.region}:{self.account_id}:'
        f'{resource_type}/{resource_name}'
    )
    cmd = (
        f'aws sagemaker add-tags --region={self.region} --resource-arn={arn} '
        + '--tags '
        + ' '.join(util.MakeFormattedDefaultTags())
    )
    self.vm.RunCommand(cmd)

  def _CreateDependencies(self) -> None:
    """Installs the CLI / python packages and prepares the helper script."""
    self.vm.Install('pip')
    self.vm.Install('awscli')
    self.vm.RunCommand('pip install sagemaker')
    self.vm.RunCommand('pip install absl-py')
    self.python_script = self.vm.PrepareResourcePath(AWS_RUNNER_SCRIPT)
    super()._CreateDependencies()

  def _Delete(self) -> None:
    """Deletes the underlying resource."""
    # The endpoint name is only known after a successful _Create.
    assert self.endpoint_name
    self._RunPythonScript(
        ['--operation=delete', f'--endpoint_name={self.endpoint_name}']
    )
class JumpStartModelSpec(managed_ai_model_spec.BaseManagedAiModelSpec):
  """Base spec shared by Sagemaker JumpStart models.

  Attributes:
    model_id: Id of the JumpStart model.
    model_version: Version of the JumpStart model.
  """

  CLOUD = 'AWS'

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Initializes the spec; all arguments are forwarded to the base spec."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    # Declarations only: no value is assigned here. Subclasses are expected to
    # set both attributes.
    self.model_id: str
    self.model_version: str
class JumpStartLlama2Spec(JumpStartModelSpec):
  """Spec describing a Llama2 model served through JumpStart.

  Source is this python notebook:
  https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-2-text-completion.ipynb
  """

  MODEL_NAME = 'llama2'
  MODEL_SIZE = ['7b', '70b']

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Initializes the spec and derives the model id from the model size."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    size = self.model_size
    self.model_id = f'meta-textgeneration-llama-2-{size}-f'
    self.model_version = '2.*'
class JumpStartLlama3Spec(JumpStartModelSpec):
  """Spec describing a Llama3 model served through JumpStart.

  Source is this python notebook:
  https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/llama-3-text-completion.ipynb
  """

  MODEL_NAME = 'llama3'
  MODEL_SIZE = ['8b', '70b']

  def __init__(self, component_full_name, flag_values=None, **kwargs):
    """Initializes the spec and derives the model id from the model size."""
    super().__init__(component_full_name, flag_values=flag_values, **kwargs)
    size = self.model_size
    self.model_id = f'meta-textgeneration-llama-3-{size}'
    self.model_version = '2.*'