# packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-components/backend/python/app/bedrock.py

import json
import logging
import os

from anthropic import AnthropicBedrock
from pydantic import BaseModel

from app.config import BEDROCK_PRICING, DEFAULT_EMBEDDING_CONFIG
from app.config import DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG
from app.config import DEFAULT_MISTRAL_GENERATION_CONFIG
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.utils import get_bedrock_client, is_anthropic_model
logger = logging.getLogger(__name__)

BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
ENABLE_MISTRAL = os.environ.get("ENABLE_MISTRAL", "") == "true"

# Select the default generation config at import time based on ENABLE_MISTRAL.
DEFAULT_GENERATION_CONFIG = (
    DEFAULT_MISTRAL_GENERATION_CONFIG
    if ENABLE_MISTRAL
    else DEFAULT_CLAUDE_GENERATION_CONFIG
)

client = get_bedrock_client()
anthropic_client = AnthropicBedrock()


class InvocationMetrics(BaseModel):
    input_tokens: int
    output_tokens: int

def compose_args(
messages: list[MessageModel],
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
    # Anthropic models go through the AnthropicBedrock client; every other
    # model goes through the plain Bedrock runtime client.
    model_id = get_model_id(model)
if is_anthropic_model(model_id):
return compose_args_for_anthropic_client(
messages, model, instruction, stream, generation_params
)
else:
return compose_args_for_other_client(
messages, model, instruction, stream, generation_params
)
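
# Example usage (an illustrative sketch, not a call made in this module; the
# message list and instruction are placeholders, and passing the dict straight
# to the AnthropicBedrock messages API simply reflects the shape compose_args
# produces for Anthropic models):
#
#   args = compose_args(messages, "claude-v3-haiku", instruction="Be concise.")
#   response = anthropic_client.messages.create(**args)
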
def compose_args_for_other_client(
messages: list[MessageModel],
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
    """Compose invoke arguments for non-Anthropic (currently Mistral) models.

    Only text content parts are forwarded; "system" and "instruction" roles
    are filtered out of the message list, and any instruction is passed via
    the top-level "system" key instead.
    """
    arg_messages = []
for message in messages:
if message.role not in ["system", "instruction"]:
content: list[dict] = []
for c in message.content:
if c.content_type == "text":
content.append(
{
"type": "text",
"text": c.body,
}
)
m = {"role": message.role, "content": content}
arg_messages.append(m)
args = {
**DEFAULT_MISTRAL_GENERATION_CONFIG,
**(
{
"max_tokens": generation_params.max_tokens,
"top_k": generation_params.top_k,
"top_p": generation_params.top_p,
"temperature": generation_params.temperature,
"stop_sequences": generation_params.stop_sequences,
}
if generation_params
else {}
),
"model": get_model_id(model),
"messages": arg_messages,
"stream": stream,
}
if instruction:
args["system"] = instruction
return args
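
# For a single user message "hello", the returned dict looks roughly like
# the following (the sampling values come from DEFAULT_MISTRAL_GENERATION_CONFIG;
# the numbers below are placeholders, not the real defaults):
#
#   {
#       "max_tokens": 512,
#       "top_k": 50,
#       "top_p": 0.9,
#       "temperature": 0.5,
#       "stop_sequences": [],
#       "model": "mistral.mistral-7b-instruct-v0:2",
#       "messages": [{"role": "user",
#                     "content": [{"type": "text", "text": "hello"}]}],
#       "stream": False,
#   }
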
def compose_args_for_anthropic_client(
messages: list[MessageModel],
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
"""Compose arguments for Anthropic client.
Ref: https://docs.anthropic.com/claude/reference/messages_post
"""
arg_messages = []
for message in messages:
if message.role not in ["system", "instruction"]:
content: list[dict] = []
for c in message.content:
if c.content_type == "text":
content.append(
{
"type": "text",
"text": c.body,
}
)
elif c.content_type == "image":
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": c.media_type,
"data": c.body,
},
}
)
m = {"role": message.role, "content": content}
arg_messages.append(m)
args = {
**DEFAULT_GENERATION_CONFIG,
**(
{
"max_tokens": generation_params.max_tokens,
"top_k": generation_params.top_k,
"top_p": generation_params.top_p,
"temperature": generation_params.temperature,
"stop_sequences": generation_params.stop_sequences,
}
if generation_params
else {}
),
"model": get_model_id(model),
"messages": arg_messages,
"stream": stream,
}
if instruction:
args["system"] = instruction
return args
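
# Unlike the non-Anthropic variant, image parts are kept and forwarded in the
# Anthropic "base64 source" shape, e.g. (media type and data are placeholders):
#
#   {
#       "role": "user",
#       "content": [
#           {
#               "type": "image",
#               "source": {
#                   "type": "base64",
#                   "media_type": "image/png",
#                   "data": "<base64-encoded image bytes>",
#               },
#           },
#           {"type": "text", "text": "What is in this image?"},
#       ],
#   }
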
def calculate_price(
model: str, input_tokens: int, output_tokens: int, region: str = BEDROCK_REGION
) -> float:
    """Estimate the cost in USD, using per-1,000-token prices from BEDROCK_PRICING.

    Falls back to the "default" pricing table when the region or the model
    is missing from the region-specific table.
    """
    input_price = (
BEDROCK_PRICING.get(region, {})
.get(model, {})
.get("input", BEDROCK_PRICING["default"][model]["input"])
)
output_price = (
BEDROCK_PRICING.get(region, {})
.get(model, {})
.get("output", BEDROCK_PRICING["default"][model]["output"])
)
return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0
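
# Worked example (illustrative prices, not the real BEDROCK_PRICING values):
# with input at $0.003 and output at $0.015 per 1,000 tokens, a call that
# uses 2,000 input tokens and 500 output tokens costs
#   0.003 * 2000 / 1000 + 0.015 * 500 / 1000 = 0.006 + 0.0075 = $0.0135
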
def get_model_id(model: str) -> str:
# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
if model == "claude-v2":
return "anthropic.claude-v2:1"
elif model == "claude-instant-v1":
return "anthropic.claude-instant-v1"
elif model == "claude-v3-sonnet":
return "anthropic.claude-3-sonnet-20240229-v1:0"
elif model == "claude-v3-haiku":
return "anthropic.claude-3-haiku-20240307-v1:0"
elif model == "claude-v3-opus":
return "anthropic.claude-3-opus-20240229-v1:0"
elif model == "mistral-7b-instruct":
return "mistral.mistral-7b-instruct-v0:2"
elif model == "mixtral-8x7b-instruct":
return "mistral.mixtral-8x7b-instruct-v0:1"
elif model == "mistral-large":
return "mistral.mistral-large-2402-v1:0"
    else:
        raise NotImplementedError(f"Unsupported model: {model}")

def calculate_query_embedding(question: str) -> list[float]:
model_id = DEFAULT_EMBEDDING_CONFIG["model_id"]
# Currently only supports "cohere.embed-multilingual-v3"
assert model_id == "cohere.embed-multilingual-v3"
payload = json.dumps({"texts": [question], "input_type": "search_query"})
accept = "application/json"
content_type = "application/json"
response = client.invoke_model(
accept=accept, contentType=content_type, body=payload, modelId=model_id
)
output = json.loads(response.get("body").read())
embedding = output.get("embeddings")[0]
return embedding
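
# The Cohere embed response body decodes to a dict like
# {"embeddings": [[0.012, -0.034, ...]], ...}, with one vector per input
# text, so index [0] picks the vector for the single query string. (Keys
# beyond "embeddings" are omitted here.)
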
def calculate_document_embeddings(documents: list[str]) -> list[list[float]]:
    BATCH_SIZE = 10
    model_id = DEFAULT_EMBEDDING_CONFIG["model_id"]
    # Currently only supports "cohere.embed-multilingual-v3"
    assert model_id == "cohere.embed-multilingual-v3"

    def _calculate_document_embeddings(documents: list[str]) -> list[list[float]]:
        payload = json.dumps({"texts": documents, "input_type": "search_document"})
        accept = "application/json"
        content_type = "application/json"
        response = client.invoke_model(
            accept=accept, contentType=content_type, body=payload, modelId=model_id
        )
        output = json.loads(response.get("body").read())
        embeddings = output.get("embeddings")
        return embeddings

    embeddings = []
    for i in range(0, len(documents), BATCH_SIZE):
        # Split documents into batches to avoid exceeding the payload size limit
        batch = documents[i : i + BATCH_SIZE]
        embeddings += _calculate_document_embeddings(batch)
    return embeddings
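
# Batching walkthrough: for 25 documents, range(0, 25, 10) yields start
# indices 0, 10, and 20, so invoke_model is called three times with batches
# of 10, 10, and 5 texts, and the vectors are concatenated in the original
# document order.
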
def get_bedrock_response(args: dict) -> dict:
client = get_bedrock_client()
    messages = args["messages"]
    # Flatten the chat history into a single prompt string; only the first
    # content part of each message is inspected, and non-text parts are
    # skipped.
    prompt = "\n".join(
        [
            message["content"][0]["text"]
            for message in messages
            if message["content"][0]["type"] == "text"
        ]
    )
model_id = args["model"]
    is_mistral_model = model_id.startswith("mistral")
    if is_mistral_model:
        # Wrap the prompt in Mistral's instruction-format markers.
        prompt = f"<s>[INST] {prompt} [/INST]"
logger.info(f"Final Prompt: {prompt}")
body = json.dumps(
{
"prompt": prompt,
"max_tokens": args["max_tokens"],
"temperature": args["temperature"],
"top_p": args["top_p"],
"top_k": args["top_k"],
}
)
logger.info(f"The args before invoke bedrock: {args}")
if args["stream"]:
try:
response = client.invoke_model_with_response_stream(
modelId=model_id,
body=body,
)
# Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/invoke_model_with_response_stream.html
response_body = response
except Exception as e:
logger.error(e)
else:
response = client.invoke_model(
modelId=model_id,
body=body,
)
# Ref: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/invoke_model.html
response_body = json.loads(response.get("body").read())
    invocation_metrics = InvocationMetrics(
        # Bedrock reports token counts via response HTTP headers.
        input_tokens=response["ResponseMetadata"]["HTTPHeaders"][
            "x-amzn-bedrock-input-token-count"
        ],
        output_tokens=response["ResponseMetadata"]["HTTPHeaders"][
            "x-amzn-bedrock-output-token-count"
        ],
    )
    response_body["amazon-bedrock-invocationMetrics"] = invocation_metrics
return response_body
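
# Consuming a streamed response (a sketch based on the boto3 reference linked
# above; the "outputs"/"text" chunk layout is Mistral's documented Bedrock
# response shape):
#
#   result = get_bedrock_response(args)
#   for event in result["body"]:
#       chunk = json.loads(event["chunk"]["bytes"])
#       print(chunk["outputs"][0]["text"], end="")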