def __init__()

in nl2sql_src/nl2sql_generic.py [0:0]


    def __init__(self,
                 project_id,
                 dataset_id,
                 metadata_json_path=None,
                 model_name="gemini-pro",
                 tuned_model=True):
        """Set up the LLM client, the BigQuery SQLAlchemy engine and metadata.

        Args:
            project_id: GCP project id; combined with ``dataset_id`` into
                ``self.dataset_id`` as ``"<project_id>.<dataset_id>"``.
            dataset_id: Dataset name within ``project_id``.
            metadata_json_path: Optional path to a UTF-8 JSON file. When
                given (and truthy), its parsed contents are stored in
                ``self.metadata_json``; otherwise that attribute is ``None``.
            model_name: Model name passed to ``VertexAI``.
            tuned_model: Only consulted when ``model_name`` is
                ``'text-bison@002'``; if True, the hard-coded tuned model
                resource below is used instead of the base model.
        """
        self.dataset_id = f"{project_id}.{dataset_id}"
        self.metadata_json = None
        self.model_name = model_name

        # Shared VertexAI settings; the tuned-model argument is added only
        # when applicable so it is omitted entirely for other models.
        llm_kwargs = {
            "temperature": 0,
            "model_name": self.model_name,
            "max_output_tokens": 1024,
        }
        if model_name == 'text-bison@002' and tuned_model:
            llm_kwargs["tuned_model_name"] = (
                'projects/174482663155/locations/'
                'us-central1/models/6975883408262037504')
        self.llm = VertexAI(**llm_kwargs)

        logger.info(f"Current LLM model : {self.model_name}")

        # "proj.dataset" -> "bigquery://proj/dataset"
        self.engine = sqlalchemy.engine.create_engine(
            f"bigquery://{self.dataset_id.replace('.', '/')}")
        if metadata_json_path:
            # Context manager guarantees the file handle is closed
            # (the previous bare open() leaked it).
            with open(metadata_json_path, encoding="utf-8") as metadata_file:
                self.metadata_json = json.load(metadata_file)