import asyncio
import asyncpg
import pandas as pd
import numpy as np
from pgvector.asyncpg import register_vector
from google.cloud.sql.connector import Connector
from langchain_community.embeddings import VertexAIEmbeddings
from google.cloud import bigquery
from dbconnectors import pgconnector
from agents import EmbedderAgent
from sqlalchemy.sql import text
from utilities import VECTOR_STORE, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION, BQ_OPENDATAQNA_DATASET_NAME, BQ_REGION, EMBEDDING_MODEL

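# Shared embedder instance, used by add_sql_embedding() to vectorise example questions.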
embedder = EmbedderAgent(EMBEDDING_MODEL)

async def store_schema_embeddings(table_details_embeddings, 
                            tablecolumn_details_embeddings, 
                            project_id,
                            instance_name,
                            database_name,
                            schema,
                            database_user,
                            database_password,
                            region,
                            VECTOR_STORE):
    """ 
    Store the vectorised table and column details in the DB table.
    This code may run for a few minutes.  
    """

    if VECTOR_STORE == "cloudsql-pgvector":
    
        loop = asyncio.get_running_loop()
        async with Connector(loop=loop) as connector:
            # Create connection to Cloud SQL database.
            conn: asyncpg.Connection = await connector.connect_async(
                f"{project_id}:{region}:{instance_name}",  # Cloud SQL instance connection name
                "asyncpg",
                user=f"{database_user}",
                password=f"{database_password}",
                db=f"{database_name}",
            )

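            # Ensure the pgvector extension exists, then register the vector type codec
            # on this connection so numpy arrays can be bound as vector parameters.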
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            await register_vector(conn)

            # await conn.execute(f"DROP SCHEMA IF EXISTS {pg_schema} CASCADE")        

            # await conn.execute(f"CREATE SCHEMA {pg_schema}")        

            # await conn.execute("DROP TABLE IF EXISTS table_details_embeddings")
            # Create the `table_details_embeddings` table to store vector embeddings.
            await conn.execute(
                """CREATE TABLE IF NOT EXISTS table_details_embeddings(
                                    source_type VARCHAR(100) NOT NULL,
                                    user_grouping VARCHAR(100) NOT NULL,
                                    table_schema VARCHAR(1024) NOT NULL,
                                    table_name VARCHAR(1024) NOT NULL,
                                    content TEXT,
                                    embedding vector(768))"""
            )

            # Store all the generated embeddings back into the database.
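            # MERGE (PostgreSQL 15+) upserts each row keyed on (user_grouping, table_name):
            # existing rows are refreshed in place, new ones are inserted.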
            for index, row in table_details_embeddings.iterrows():
                await conn.execute(
                    f"""
                    MERGE INTO table_details_embeddings AS target
                    USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema, $4::text AS table_name, $5::text AS content, $6::vector AS embedding) AS source
                    ON target.user_grouping = source.user_grouping AND target.table_name = source.table_name
                    WHEN MATCHED THEN 
                        UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                    WHEN NOT MATCHED THEN
                        INSERT (source_type, user_grouping, table_schema, table_name, content, embedding) 
                        VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.content, source.embedding);
                    """,
                    row["source_type"],
                    row["user_grouping"],
                    row["table_schema"],
                    row["table_name"],
                    row["content"],
                    np.array(row["embedding"]),
                )

            # await conn.execute("DROP TABLE IF EXISTS tablecolumn_details_embeddings")
            # Create the `tablecolumn_details_embeddings` table to store vector embeddings.
            await conn.execute(
                """CREATE TABLE IF NOT EXISTS tablecolumn_details_embeddings(
                                    source_type VARCHAR(100) NOT NULL,
                                    user_grouping VARCHAR(100) NOT NULL,
                                    table_schema VARCHAR(1024) NOT NULL,
                                    table_name VARCHAR(1024) NOT NULL,
                                    column_name VARCHAR(1024) NOT NULL,
                                    content TEXT,
                                    embedding vector(768))"""
            )

            # Store all the generated embeddings back into the database.
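            # Same upsert pattern as above, keyed on (user_grouping, table_name, column_name).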
            for index, row in tablecolumn_details_embeddings.iterrows():
                await conn.execute(
                    f"""
                    MERGE INTO tablecolumn_details_embeddings AS target
                    USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema, 
                                $4::text AS table_name, $5::text AS column_name, $6::text AS content, $7::vector AS embedding) AS source
                    ON target.user_grouping = source.user_grouping 
                    AND target.table_name = source.table_name 
                    AND target.column_name = source.column_name
                    WHEN MATCHED THEN 
                        UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                    WHEN NOT MATCHED THEN
                        INSERT (source_type, user_grouping, table_schema, table_name, column_name, content, embedding) 
                        VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.column_name, source.content, source.embedding);
                    """,
                    row["source_type"],
                    row["user_grouping"],
                    row["table_schema"],
                    row["table_name"],
                    row["column_name"],
                    row["content"],
                    np.array(row["embedding"]),
                )
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            await register_vector(conn)

            # await conn.execute("DROP TABLE IF EXISTS example_prompt_sql_embeddings")
            await conn.execute(
                        """CREATE TABLE IF NOT EXISTS example_prompt_sql_embeddings(
                                            user_grouping VARCHAR(1024) NOT NULL,
                                            example_user_question text NOT NULL,
                                            example_generated_sql text NOT NULL,
                                            embedding vector(768))"""
                        )

            await conn.close()


    elif VECTOR_STORE == "bigquery-vector": 
         
        client = bigquery.Client(project=project_id)

        #Store table embeddings
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.table_details_embeddings` (
            source_type string NOT NULL, user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, content string, embedding ARRAY<FLOAT64>)''')
        #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")

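        # Refresh pattern: delete any existing rows for the (user_grouping, table_name)
        # pairs being re-embedded, then append the new rows from the dataframe.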
        delete_conditions = table_details_embeddings[['user_grouping', 'table_name']].apply(tuple, axis=1).tolist()
        if delete_conditions:  # skip the DELETE when there is nothing to refresh
            where_clause = " OR ".join([f"(user_grouping = '{cond[0]}' AND table_name = '{cond[1]}')" for cond in delete_conditions])
            delete_query = f"""
            DELETE FROM `{project_id}.{schema}.table_details_embeddings`
            WHERE {where_clause}
            """
            client.query_and_wait(delete_query)
        
        client.load_table_from_dataframe(table_details_embeddings, f'{project_id}.{schema}.table_details_embeddings').result()


        #Store column embeddings
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.tablecolumn_details_embeddings` (
            source_type string NOT NULL,user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, column_name string NOT NULL,
            content string, embedding ARRAY<FLOAT64>)''')
        #job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")

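        # Same refresh pattern for column-level embeddings, keyed on
        # (user_grouping, table_name, column_name).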
        delete_conditions = tablecolumn_details_embeddings[['user_grouping', 'table_name', 'column_name']].apply(tuple, axis=1).tolist()
        if delete_conditions:  # skip the DELETE when there is nothing to refresh
            where_clause = " OR ".join([f"(user_grouping = '{cond[0]}' AND table_name = '{cond[1]}' AND column_name = '{cond[2]}')" for cond in delete_conditions])
            delete_query = f"""
            DELETE FROM `{project_id}.{schema}.tablecolumn_details_embeddings`
            WHERE {where_clause}
            """
            client.query_and_wait(delete_query)

        client.load_table_from_dataframe(tablecolumn_details_embeddings, f'{project_id}.{schema}.tablecolumn_details_embeddings').result()

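        # Create the examples table up front; add_sql_embedding() populates it later.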
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings` (
                              user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
                              embedding ARRAY<FLOAT64>)''')
                              
    else:
        raise ValueError("VECTOR_STORE must be 'cloudsql-pgvector' or 'bigquery-vector'.")
    return "Embeddings are stored successfully"

async def add_sql_embedding(user_question, generated_sql, database):
        """Embed an example question and store it, with its generated SQL, in the vector store."""
        emb = embedder.create(user_question)

        if VECTOR_STORE == "cloudsql-pgvector":
        #    sql=  f'''MERGE INTO example_prompt_sql_embeddings as tgt
        #    using (SELECT '{user_question}' as example_user_question) as src 
        #    on tgt.example_user_question=src.example_user_question 
        #    when not matched then
        #    insert (table_schema, example_user_question,example_generated_sql,embedding) 
        #    values('{database}','{user_question}','{generated_sql}','{(emb)}')
        #    when matched then update set
        #    table_schema = '{database}',
        #    example_generated_sql = '{generated_sql}',
        #    embedding = '{(emb)}' '''

        # #    print(sql)
        #    conn=pgconnector.pool.connect()
        #    await conn.execute(text(sql))
        #    pgconnector.retrieve_df(sql)
            loop = asyncio.get_running_loop()
            async with Connector(loop=loop) as connector:
                    # Create connection to Cloud SQL database.
                conn: asyncpg.Connection = await connector.connect_async(
                        f"{PROJECT_ID}:{PG_REGION}:{PG_INSTANCE}",  # Cloud SQL instance connection name
                        "asyncpg",
                        user=f"{PG_USER}",
                        password=f"{PG_PASSWORD}",
                        db=f"{PG_DATABASE}",
                    )

                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                await register_vector(conn)

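                # Replace any existing example for this (user_grouping, question) pair.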
                await conn.execute("DELETE FROM example_prompt_sql_embeddings WHERE user_grouping= $1 and example_user_question=$2",
                                    database,
                                    user_question)
                cleaned_sql =generated_sql.replace("\r", " ").replace("\n", " ")
                await conn.execute(
                                "INSERT INTO example_prompt_sql_embeddings (user_grouping, example_user_question, example_generated_sql, embedding) VALUES ($1, $2, $3, $4)",
                                database,
                                user_question,
                                cleaned_sql,
                                np.array(emb),
                            )

        elif VECTOR_STORE == "bigquery-vector":

            client = bigquery.Client(project=PROJECT_ID)
        
            client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}.example_prompt_sql_embeddings` (
                user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
                embedding ARRAY<FLOAT64>)''')
            client.query_and_wait(f'''DELETE FROM `{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}.example_prompt_sql_embeddings`
                                WHERE user_grouping= '{database}' and example_user_question= "{user_question}" '''
                                    )
                        # embedding=np.array(row["embedding"])
            cleaned_sql = generated_sql.replace("\r", " ").replace("\n", " ")
            client.query_and_wait(f'''INSERT INTO `{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}.example_prompt_sql_embeddings` 
                        VALUES ("{database}","{user_question}" , 
                        "{cleaned_sql}",{emb})''')
        return 1
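
# Illustrative invocation of add_sql_embedding (the question, SQL, and grouping below are hypothetical):
#   asyncio.run(add_sql_embedding(
#       user_question="How many orders were placed last month?",
#       generated_sql="SELECT COUNT(*) FROM orders WHERE order_date >= '2024-01-01'",
#       database="my_user_grouping"))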



if __name__ == '__main__': 
    from retrieve_embeddings import retrieve_embeddings
    from utilities import PG_SCHEMA, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION
    VECTOR_STORE = "cloudsql-pgvector"
    t, c = retrieve_embeddings(VECTOR_STORE, PG_SCHEMA) 
    asyncio.run(store_schema_embeddings(t, 
                            c, 
                            PROJECT_ID,
                            PG_INSTANCE,
                            PG_DATABASE,
                            PG_SCHEMA,
                            PG_USER,
                            PG_PASSWORD,
                            PG_REGION,
                            VECTOR_STORE=VECTOR_STORE))