in nl2sql_src/nl2sql_generic.py [0:0]
def __init__(self,
             project_id,
             dataset_id,
             metadata_json_path=None,
             model_name="gemini-pro",
             tuned_model=True,
             tuned_model_name=None):
    """
    Initialise the BigQuery NL2SQL helper.

    Creates the Vertex AI LLM client and a SQLAlchemy engine for the
    BigQuery dataset. Optionally loads table metadata from a JSON file.

    Args:
        project_id: GCP project that owns the BigQuery dataset.
        dataset_id: BigQuery dataset to query. It is combined with
            ``project_id`` into ``self.dataset_id`` as
            ``"<project_id>.<dataset_id>"``.
        metadata_json_path: Optional path to a UTF-8 encoded JSON file.
            When given, its parsed contents are stored in
            ``self.metadata_json``. Otherwise that attribute stays None.
        model_name: Vertex AI model name passed to ``VertexAI``.
        tuned_model: When True and ``model_name`` is 'text-bison@002',
            the tuned model given by ``tuned_model_name`` is used.
        tuned_model_name: Optional full resource name of the tuned model.
            It is only consulted for 'text-bison@002' with ``tuned_model``
            set. When None, the previously hard-coded tuned model resource
            is used.

    Raises:
        OSError: If ``metadata_json_path`` cannot be opened.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """
    self.dataset_id = f"{project_id}.{dataset_id}"
    self.metadata_json = None
    self.model_name = model_name

    # Settings shared by the base and tuned LLM variants.
    llm_kwargs = {
        "temperature": 0,
        "model_name": self.model_name,
        "max_output_tokens": 1024,
    }
    if model_name == 'text-bison@002' and tuned_model:
        if tuned_model_name is None:
            tuned_model_name = ('projects/174482663155/locations/'
                                'us-central1/models/6975883408262037504')
        llm_kwargs["tuned_model_name"] = tuned_model_name
    self.llm = VertexAI(**llm_kwargs)
    logger.info("Current LLM model : %s", self.model_name)

    # Build "bigquery://<project>/<dataset>" from the parts directly.
    # Replacing every '.' in the joined id would corrupt domain-scoped
    # project IDs such as "example.com:my-project".
    self.engine = sqlalchemy.engine.create_engine(
        f"bigquery://{project_id}/{dataset_id}")

    if metadata_json_path:
        # The context manager guarantees the file handle is closed.
        with open(metadata_json_path, encoding="utf-8") as metadata_file:
            self.metadata_json = json.load(metadata_file)