backend-apis/app/utils/utils_palm.py (138 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils for PaLM
"""
import asyncio
import base64
import functools
import typing
import numpy as np
from google.api_core.exceptions import GoogleAPICallError
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from google.protobuf.json_format import MessageToDict
from vertexai.preview.language_models import (
TextEmbeddingInput,
TextEmbeddingModel,
TextGenerationModel,
)
model = TextGenerationModel.from_pretrained(model_name="text-bison@002")
text_embedding_model = TextEmbeddingModel.from_pretrained(
"textembedding-gecko@003"
)
def text_generation(
prompt: str,
max_output_tokens: int = 1024,
temperature: float = 0.2,
top_k: int = 40,
top_p: float = 0.8,
) -> str:
"""
Args:
model:
prompt:
max_output_tokens:
temperature:
top_k:
top_p:
Returns:
"""
return model.predict(
prompt=prompt,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
).text
class EmbeddingResponse(typing.NamedTuple):
"""Embedding Response"""
text_embedding: list[float]
image_embedding: list[float]
class EmbeddingPredictionClient:
"""Wrapper around Prediction Service Client."""
def __init__(
self,
project: str,
location: str = "us-central1",
api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
self.location = location
self.project = project
self.set_client_options(api_regional_endpoint)
def set_client_options(self, api_endpoint: str):
"""
Args:
api_endpoint:
"""
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
self.client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
def get_embedding(self, text: str = "", image_bytes: bytes = b""):
"""
Args:
text:
image_bytes:
Raises:
ValueError:
Returns:
"""
if not text and not image_bytes:
raise ValueError(
"At least one of text or image_bytes must be specified."
)
instance = struct_pb2.Value()
if text:
instance.struct_value.update({"text": text})
if image_bytes:
encoded_content = base64.b64encode(image_bytes).decode("utf-8")
instance.struct_value.update(
{"image": {"bytesBase64Encoded": encoded_content}}
)
instances = [instance]
endpoint = (
f"projects/{self.project}/locations/{self.location}"
"/publishers/google/models/multimodalembedding@001"
)
response = self.client.predict(endpoint=endpoint, instances=instances)
response_dict = MessageToDict(
response._pb # pylint: disable=protected-access
)
text_embedding = []
if text:
text_emb_value = response_dict["predictions"][0]["textEmbedding"]
text_embedding = [float(v) for v in text_emb_value]
image_embedding = []
if image_bytes:
image_emb_value = response_dict["predictions"][0]["imageEmbedding"]
image_embedding = [float(v) for v in image_emb_value]
return EmbeddingResponse(
text_embedding=text_embedding, image_embedding=image_embedding
)
def reduce_embedding_dimension(
vector_text: list[float] | None = None,
vector_image: list[float] | None = None,
) -> list:
"""
Args:
vector_text:
vector_image:
Returns:
"""
if vector_image and vector_text:
matrix = np.array([vector_text, vector_image])
max_pooled_rows = np.sum(matrix, axis=0)
else:
max_pooled_rows = np.array(vector_text or vector_image)
return list(max_pooled_rows)
async def async_predict_text_llm(
prompt: str,
max_output_tokens: int = 1024,
temperature: float = 0.2,
top_k: int = 40,
top_p: float = 0.8,
) -> str:
"""
Args:
model:
prompt:
max_output_tokens:
temperature:
top_k:
top_p:
Returns:
"""
loop = asyncio.get_running_loop()
generated_response = None
try:
generated_response = await loop.run_in_executor(
None,
functools.partial(
model.predict,
prompt=prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_k=top_k,
top_p=top_p,
),
)
except GoogleAPICallError as e:
print(e)
return ""
if generated_response and generated_response.text:
generated_response = generated_response.text.replace("```json", "")
generated_response = generated_response.replace("```JSON", "")
generated_response = generated_response.replace("```", "")
return generated_response
return ""
async def run_predict_text_llm(
prompts: list, temperature: float = 0.2
) -> list:
"""
Args:
prompts:
model:
temperature:
Returns:
"""
tasks = [
async_predict_text_llm(prompt=prompt, temperature=temperature)
for prompt in prompts
]
results = await asyncio.gather(*tasks)
return results
def get_text_embeddings(input_text: str) -> list:
"""
Args:
input_text:
Returns:
"""
text_input = TextEmbeddingInput(text=input_text, task_type="CLUSTERING")
embeddings = text_embedding_model.get_embeddings(texts=[text_input])[
0
].values
return embeddings