nl2sql_library/nl2sql/llms/google_palm.py (29 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Used to get an instance of the PaLM LLM
"""

import os

from google.cloud import secretmanager
from langchain.llms.google_palm import GooglePalm


class ExtendedPalm(GooglePalm):
    """
    Adds utility functions to GooglePalm
    """

    def get_num_tokens(self, text: str) -> int:
        """
        Returns the token count for some text.

        NOTE(review): ``count_message_tokens`` is the chat-model token
        counter; its count may differ slightly from the text model in use.
        Confirm this is acceptable for prompt-budget calculations.
        """
        return self.client.count_message_tokens(prompt=text)["token_count"]

    def get_max_input_tokens(self) -> int:
        """
        Returns the maximum number of input tokens allowed by the
        configured model (``self.model_name``).
        """
        # Use the configured model rather than a hard-coded name, so the
        # limit stays correct when a different PaLM model is selected.
        return self.client.get_model(self.model_name).input_token_limit


def _access_secret(
    project_id: str, secret_id: str, secret_version_id: str
) -> str:
    """Fetch a Secret Manager secret version and return it as UTF-8 text."""
    # Implicit concatenation keeps the resource path free of stray
    # whitespace (a backslash continuation inside the literal would not).
    name = (
        f"projects/{project_id}/secrets/{secret_id}"
        f"/versions/{secret_version_id}"
    )
    client = secretmanager.SecretManagerServiceClient()
    response = client.access_secret_version(name=name)
    return response.payload.data.decode("UTF-8")


def get_secretmanager_authed_palm(
    project_id: str | None = None,
    secret_id: str = "palm-api-key",
    secret_version_id: str = "latest",
    **kwargs,
) -> ExtendedPalm:
    """
    Returns an ExtendedPalm instance authenticated with an API key that is
    read from Google Cloud Secret Manager.

    Args:
        project_id: GCP project that holds the secret. Defaults to the
            ``GOOGLE_CLOUD_PROJECT`` environment variable.
        secret_id: Name of the secret containing the PaLM API key.
        secret_version_id: Version of the secret to read.
        **kwargs: Extra keyword arguments forwarded to ``ExtendedPalm``.
            They override the defaults ``n=3``, ``verbose=True``,
            ``temperature=0.3`` and ``max_output_tokens=1024``.

    Returns:
        A configured ``ExtendedPalm`` instance.

    Raises:
        ValueError: If no project id is given and ``GOOGLE_CLOUD_PROJECT``
            is unset or empty.
        TypeError: If ``google_api_key`` is passed in ``kwargs``; the key
            always comes from Secret Manager.
    """
    if project_id is None:
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
    if not project_id:
        raise ValueError(
            "project_id was not provided and GOOGLE_CLOUD_PROJECT is not set"
        )
    if "google_api_key" in kwargs:
        raise TypeError(
            "google_api_key must not be passed; it is read from Secret Manager"
        )

    # Defaults first, so caller-supplied kwargs can override them instead of
    # triggering a "multiple values for keyword argument" TypeError.
    palm_kwargs = {
        "n": 3,
        "verbose": True,
        "temperature": 0.3,
        "max_output_tokens": 1024,
    }
    palm_kwargs.update(kwargs)
    palm_kwargs["google_api_key"] = _access_secret(
        project_id, secret_id, secret_version_id
    )
    return ExtendedPalm(**palm_kwargs)