nl2sql_library/nl2sql/llms/google_palm.py (29 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Used to get an instance of the PaLM LLM
"""
import os
from google.cloud import secretmanager
from langchain.llms.google_palm import GooglePalm
class ExtendedPalm(GooglePalm):
    """
    GooglePalm LLM with token-counting utilities.

    Adds helpers to count the tokens in a piece of text and to look up the
    input-token limit of the model this instance is configured with.
    """

    def get_num_tokens(self, text: str) -> int:
        """
        Returns the token count for some text.

        Args:
            text: The text whose tokens should be counted.

        Returns:
            The ``token_count`` value reported by the PaLM client.
        """
        # NOTE(review): count_message_tokens belongs to the chat ("discuss")
        # API; without an explicit model it presumably counts with the
        # client's default chat model rather than self.model_name. Confirm
        # the counts are close enough for the text model in use.
        return self.client.count_message_tokens(prompt=text)["token_count"]

    def get_max_input_tokens(self) -> int:
        """
        Returns the maximum number of input tokens allowed.

        Returns:
            The ``input_token_limit`` of the model named by
            ``self.model_name``.
        """
        # Query the model this instance is actually configured with rather
        # than a hard-coded name, so a custom ``model_name`` reports its own
        # limit (the default model yields the same value as before).
        return self.client.get_model(self.model_name).input_token_limit
def get_secretmanager_authed_palm(
    project_id: str | None = None,
    secret_id: str = "palm-api-key",
    secret_version_id: str = "latest",
    **kwargs,
) -> ExtendedPalm:
    """
    Returns an ExtendedPalm authenticated with an API key from Secret Manager.

    The PaLM API key is read from Google Cloud Secret Manager and passed to
    ExtendedPalm as ``google_api_key``.

    Args:
        project_id: GCP project that owns the secret. Defaults to the
            ``GOOGLE_CLOUD_PROJECT`` environment variable.
        secret_id: Name of the secret holding the PaLM API key.
        secret_version_id: Version of the secret to read.
        **kwargs: Extra keyword arguments forwarded to ``ExtendedPalm``.
            They override the defaults set here (``n=3``, ``verbose=True``,
            ``temperature=0.3``, ``max_output_tokens=1024``).

    Returns:
        An ExtendedPalm instance configured with the secret's payload,
        decoded as UTF-8, as its API key.

    Raises:
        ValueError: If no ``project_id`` is given and ``GOOGLE_CLOUD_PROJECT``
            is not set (or is empty).
    """
    if project_id is None:
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
    if not project_id:
        # Fail clearly instead of requesting "projects/None/secrets/...".
        raise ValueError(
            "No project_id given and GOOGLE_CLOUD_PROJECT is not set; "
            "cannot locate the PaLM API key secret."
        )

    # Use implicit concatenation: a backslash continuation inside the string
    # literal would copy the next line's indentation into the resource name.
    secret_name = (
        f"projects/{project_id}/secrets/{secret_id}"
        f"/versions/{secret_version_id}"
    )
    api_key = (
        secretmanager.SecretManagerServiceClient()
        .access_secret_version(name=secret_name)
        .payload.data.decode("UTF-8")
    )

    # Merge so caller-supplied kwargs override these defaults instead of
    # colliding with them ("got multiple values for keyword argument").
    llm_params = {
        "n": 3,
        "verbose": True,
        "temperature": 0.3,
        "max_output_tokens": 1024,
    }
    llm_params.update(kwargs)
    return ExtendedPalm(**llm_params, google_api_key=api_key)