def add_dataset_to_query()

in nl2sql_src/nl2sql_generic.py [0:0]


    def add_dataset_to_query(self, sql_query):
        """
        Qualify the tables in the FROM clauses of a SQL query with the
        dataset configured on this instance, and make casts safe.

        All backticks are first removed from the query. Every
        ``FROM <table>`` reference is then rewritten in one pass:

        - a reference whose first dot-separated component equals the
          first component of ``self.dataset_id`` is treated as already
          qualified and is only wrapped in backticks;
        - any other reference becomes ``<dataset_id>.`<table>```.

        ``FROM`` keywords that belong to ``EXTRACT(<part> FROM ...)`` and
        ``FROM`` sources that are function calls (e.g. ``UNNEST(...)``)
        are left untouched. Finally ``CAST(`` calls are rewritten to
        ``SAFE_CAST(``.

        Parameters:
        - sql_query (str): The original SQL query. May be empty or None.

        Returns:
        str: The modified SQL query, or "" when ``sql_query`` is empty
        or None.

        Note: the dataset prefix is read from ``self.dataset_id``
        (presumably of the form ``project.dataset`` - confirm against the
        class constructor). Tables introduced by JOIN are not rewritten.
        """
        logger.info(f"Original query : {sql_query}")
        if not sql_query:
            return ""

        dataset = self.dataset_id

        # Backticks are dropped up front and re-added around each table
        # reference below, so already-quoted names are not quoted twice.
        sql_query = sql_query.replace('`', '')

        # FROM followed by a complete table reference: word characters
        # optionally joined by dots (qualified names) or hyphens
        # (project ids such as "my-project").
        from_table = re.compile(r'\bFROM\b\s+(\w+(?:[.\-]\w+)*)',
                                re.IGNORECASE)
        # Text that ends right inside "EXTRACT(<part>" or
        # "EXTRACT(WEEK(<weekday>)": the following FROM is part of the
        # EXTRACT expression, not a FROM clause.
        extract_head = re.compile(
            r'\bEXTRACT\s*\(\s*\w+(?:\s*\(\s*\w+\s*\))?\s*$',
            re.IGNORECASE)
        # An opening parenthesis right after the name means a function
        # call such as UNNEST(...), not a table.
        call_tail = re.compile(r'\s*\(')

        def qualify(found):
            query = found.string
            # Only the text immediately before this FROM is inspected, so
            # an EXTRACT elsewhere in the query cannot suppress the
            # rewrite of a real table.
            if extract_head.search(query, 0, found.start()):
                return found.group(0)
            if call_tail.match(query, found.end()):
                return found.group(0)

            table = found.group(1)
            if table.split('.')[0] == dataset.split('.')[0]:
                # The generated SQL already carries the project prefix;
                # just restore the quoting.
                return f'FROM `{table}`'
            return f'FROM {dataset}.`{table}`'

        # A single substitution pass over the original text: each
        # occurrence is rewritten exactly once and the rest of the query
        # is never touched.
        sql_query = from_table.sub(qualify, sql_query)

        # Rewrite only real CAST( calls; the word boundary keeps
        # SAFE_CAST and identifiers such as BROADCAST intact.
        sql_query = re.sub(r'\bCAST\b(?=\s*\()', 'SAFE_CAST', sql_query)
        return sql_query