perfkitbenchmarker/scripts/aws_jump_start_runner.py (97 lines of code) (raw):

r"""Command line utility for manipulating AWS SageMaker/JumpStart models.

Example call with required values for all operations on first line:
```
python aws_jump_start_runner.py --model_id=meta-textgeneration-llama-2-7b-f --model_version=2.* --region=us-west-2 \
--operation=create --role=my-role:sagemaker-full-access
```

Operations (--model_id, --model_version and --region are always required):
  create: deploys the model to a new endpoint (requires --role) and prints the
    model and endpoint names.
  prompt: sends --prompt to an existing endpoint (requires --endpoint_name and
    --prompt) and prints the response.
  delete: deletes the model(s) behind an existing endpoint and then the
    endpoint itself (requires --endpoint_name).
"""

from absl import app
from absl import flags
from sagemaker import predictor as predictor_lib
from sagemaker.jumpstart.model import JumpStartModel

_ENDPOINT_NAME = flags.DEFINE_string(
    'endpoint_name',
    '',
    help='Name of an existing endpoint. Not needed if an endpoint is created.',
)
_ROLE = flags.DEFINE_string(
    'role', '', help='AWS role the model/endpoint will run under.'
)
_MODEL_ID = flags.DEFINE_string(
    'model_id', '', help='Id of the model. Required.'
)
_MODEL_VERSION = flags.DEFINE_string(
    'model_version', '', help='Version of the model. Required.'
)
_REGION = flags.DEFINE_string(
    'region', '', help='Name of the region model/endpoint are in. Required.'
)
_OPERATION = flags.DEFINE_enum(
    'operation',
    default='prompt',
    enum_values=['create', 'delete', 'prompt'],
    help='Operation that will be done against the endpoint.',
)
_PROMPT = flags.DEFINE_string(
    'prompt', '', help='Prompt sent to model, used in prompt mode'
)
_MAX_TOKENS = flags.DEFINE_integer(
    'max_tokens',
    512,
    help='Max tokens returned in response, used in prompt mode',
)
_TEMPERATURE = flags.DEFINE_float(
    'temperature',
    1.0,
    help='Temperature / randomness of responses, used in prompt mode',
)


def _RequireFlags(*holders: flags.FlagHolder) -> None:
  """Raises app.UsageError naming every flag in `holders` that is empty.

  Used instead of `assert` so validation survives `python -O` and the user
  gets a usage message naming the missing flags rather than a bare traceback.

  Args:
    *holders: Flag holders whose values must be non-empty.

  Raises:
    app.UsageError: if any of the flags has a falsy (e.g. empty) value.
  """
  missing = [f'--{holder.name}' for holder in holders if not holder.value]
  if missing:
    raise app.UsageError(f'Missing required flag(s): {", ".join(missing)}')


def Create():
  """Deploys the configured JumpStart model to a new endpoint.

  Prints the model name and the endpoint name, each wrapped in `<...>`.

  Returns:
    The predictor returned by `JumpStartModel.deploy` for the new endpoint.

  Raises:
    app.UsageError: if --role is empty.
  """
  _RequireFlags(_ROLE)
  model = JumpStartModel(
      model_id=_MODEL_ID.value,
      model_version=_MODEL_VERSION.value,
      region=_REGION.value,
      role=_ROLE.value,
  )
  print('Model name: <' + model.name + '>')
  # Deployment is done with the model's EULA accepted (needed by gated
  # models such as the Llama 2 family used in the module example).
  predictor = model.deploy(accept_eula=True)
  print('Endpoint name: <' + predictor.endpoint_name + '>')
  return predictor


def GetModel():
  """Returns a predictor attached to the existing --endpoint_name endpoint.

  Raises:
    app.UsageError: if --endpoint_name is empty.
  """
  _RequireFlags(_ENDPOINT_NAME)
  return predictor_lib.retrieve_default(
      _ENDPOINT_NAME.value,
      region=_REGION.value,
      model_id=_MODEL_ID.value,
      model_version=_MODEL_VERSION.value,
  )


def SendPrompt(predictor: predictor_lib.Predictor):
  """Sends --prompt as a single user turn to `predictor` and prints the reply.

  Args:
    predictor: Predictor attached to a deployed endpoint.

  Raises:
    app.UsageError: if --prompt is empty.
  """
  _RequireFlags(_PROMPT)

  def PrintDialog(response):
    # Reads response[0]['generation'] as a {'role', 'content'} message, i.e.
    # a chat-style reply to the first (only) dialog in the payload.
    assumed_role = response[0]['generation']['role'].capitalize()
    content = response[0]['generation']['content']
    print(f'Response>>>> {assumed_role}: {content}')
    print('\n====\n')

  payload = {
      # A list of dialogs; this sends one dialog holding one user message.
      'inputs': [[
          {'role': 'user', 'content': _PROMPT.value},
      ]],
      'parameters': {
          'max_new_tokens': _MAX_TOKENS.value,
          'temperature': _TEMPERATURE.value,
      },
  }
  response = predictor.predict(payload, custom_attributes='accept_eula=true')
  PrintDialog(response)


def Delete(predictor: predictor_lib.Predictor):
  """Deletes the model(s) behind `predictor`'s endpoint, then the endpoint.

  Args:
    predictor: Predictor attached to the endpoint to tear down.
  """
  # NOTE(review): the model is deleted first, presumably because
  # delete_model resolves model names via the endpoint's config, which
  # delete_endpoint removes -- verify before reordering.
  predictor.delete_model()
  predictor.delete_endpoint()


def main(argv):
  """Validates the flags needed by --operation, then performs it.

  Raises:
    app.UsageError: if a flag required for the chosen operation is empty;
      app.run reports this as a usage error.
  """
  del argv  # Unused.
  operation = _OPERATION.value
  # Validate everything the chosen operation needs up front, so the user sees
  # all missing flags at once, before any AWS call is made.
  per_operation_flags = {
      'create': (_ROLE,),
      'delete': (_ENDPOINT_NAME,),
      'prompt': (_ENDPOINT_NAME, _PROMPT),
  }
  _RequireFlags(
      _MODEL_ID,
      _MODEL_VERSION,
      _REGION,
      *per_operation_flags.get(operation, ()),
  )

  if operation == 'create':
    Create()
  elif operation == 'delete':
    Delete(GetModel())
  elif operation == 'prompt':
    SendPrompt(GetModel())


if __name__ == '__main__':
  app.run(main)