# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Literal, Optional

import asyncpg
from pgvector.asyncpg import register_vector
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

import models

from .. import datastore
from ..helpers import format_sql

POSTGRES_IDENTIFIER = "postgres"


class Config(BaseModel, datastore.AbstractConfig):
    # Connection settings for the Postgres datastore, validated by pydantic.

    # Discriminator: only the literal string "postgres" is accepted.
    kind: Literal["postgres"]
    # Server address; typed as an IP address (v4 or v6), defaulting to loopback.
    host: IPv4Address | IPv6Address = IPv4Address("127.0.0.1")
    # Server port; defaults to 5432.
    port: int = 5432
    # Credentials and database name; all three are required (no defaults).
    user: str
    password: str
    database: str


class Client(datastore.Client[Config]):
    __async_engine: AsyncEngine

    @datastore.classproperty
    def kind(cls):
        """Return the identifier string of this datastore (``POSTGRES_IDENTIFIER``)."""
        return POSTGRES_IDENTIFIER

    def __init__(self, async_engine: AsyncEngine):
        """Wrap an already-built engine; ``create`` builds one from a ``Config``."""
        self.__async_engine = async_engine

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Build a ``Client`` whose engine opens connections through asyncpg.

        Args:
            config: Host, port, credentials and database name to connect with.

        Returns:
            A new client wrapping the created engine.

        Raises:
            TypeError: If ``create_async_engine`` returns ``None``.
        """

        async def _open_connection() -> asyncpg.Connection:
            # Every new connection has the pgvector type registered on it
            # before it is handed to the engine.
            connection: asyncpg.Connection = await asyncpg.connection.connect(
                host=str(config.host),
                port=config.port,
                user=config.user,
                password=config.password,
                database=config.database,
            )
            await register_vector(connection)
            return connection

        # The URL only selects the dialect/driver; the actual connections are
        # produced by the async_creator callback above.
        engine = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=_open_connection,
        )
        if engine is not None:
            return cls(engine)
        raise TypeError("async_engine not instantiated")

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Drop, recreate and repopulate the datastore tables.

        The ``airports``, ``amenities``, ``flights`` and ``policies`` tables
        are dropped, recreated and filled from the given records; ``tickets``
        is dropped and recreated empty. The ``vector`` extension is created if
        it does not exist before the tables that use the ``vector`` type. All
        statements run on one connection and a single commit is issued after
        the last one.

        Args:
            airports: Rows for the ``airports`` table.
            amenities: Rows for the ``amenities`` table.
            flights: Rows for the ``flights`` table.
            policies: Rows for the ``policies`` table.

        An empty list leaves its table empty: the INSERT is skipped, because
        executing it with an empty parameter list would run the statement with
        its bind parameters unset.
        """
        # Column names in table-definition order; the INSERT statements use
        # positional VALUES, so this order must match the CREATE TABLE order.
        airport_columns = ("id", "iata", "name", "city", "country")
        amenity_columns = (
            "id",
            "name",
            "description",
            "location",
            "terminal",
            "category",
            "hour",
            "sunday_start_hour",
            "sunday_end_hour",
            "monday_start_hour",
            "monday_end_hour",
            "tuesday_start_hour",
            "tuesday_end_hour",
            "wednesday_start_hour",
            "wednesday_end_hour",
            "thursday_start_hour",
            "thursday_end_hour",
            "friday_start_hour",
            "friday_end_hour",
            "saturday_start_hour",
            "saturday_end_hour",
            "content",
            "embedding",
        )
        flight_columns = (
            "id",
            "airline",
            "flight_number",
            "departure_airport",
            "arrival_airport",
            "departure_time",
            "arrival_time",
            "departure_gate",
            "arrival_gate",
        )
        policy_columns = ("id", "content", "embedding")

        async with self.__async_engine.connect() as conn:

            async def rebuild(
                table: str,
                ddl: str,
                columns: tuple[str, ...],
                rows: list[Any],
            ) -> None:
                """Drop ``table``, run ``ddl``, then bulk-insert ``rows`` if any."""
                # If the table already exists, drop it to avoid conflicts
                await conn.execute(text(f"DROP TABLE IF EXISTS {table} CASCADE"))
                await conn.execute(text(ddl))
                if not rows:
                    # Nothing to load; see the docstring for why the INSERT
                    # must not be executed with an empty parameter list.
                    return
                placeholders = ", ".join(f":{column}" for column in columns)
                await conn.execute(
                    text(f"INSERT INTO {table} VALUES ({placeholders})"),
                    [
                        {column: getattr(row, column) for column in columns}
                        for row in rows
                    ],
                )

            await rebuild(
                "airports",
                """
                CREATE TABLE airports(
                  id INT PRIMARY KEY,
                  iata TEXT,
                  name TEXT,
                  city TEXT,
                  country TEXT
                )
                """,
                airport_columns,
                airports,
            )

            # The vector type must exist before the tables that declare
            # vector(768) columns are created.
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))

            await rebuild(
                "amenities",
                """
                CREATE TABLE amenities(
                  id INT PRIMARY KEY,
                  name TEXT,
                  description TEXT,
                  location TEXT,
                  terminal TEXT,
                  category TEXT,
                  hour TEXT,
                  sunday_start_hour TIME,
                  sunday_end_hour TIME,
                  monday_start_hour TIME,
                  monday_end_hour TIME,
                  tuesday_start_hour TIME,
                  tuesday_end_hour TIME,
                  wednesday_start_hour TIME,
                  wednesday_end_hour TIME,
                  thursday_start_hour TIME,
                  thursday_end_hour TIME,
                  friday_start_hour TIME,
                  friday_end_hour TIME,
                  saturday_start_hour TIME,
                  saturday_end_hour TIME,
                  content TEXT NOT NULL,
                  embedding vector(768) NOT NULL
                )
                """,
                amenity_columns,
                amenities,
            )

            await rebuild(
                "flights",
                """
                CREATE TABLE flights(
                  id INTEGER PRIMARY KEY,
                  airline TEXT,
                  flight_number TEXT,
                  departure_airport TEXT,
                  arrival_airport TEXT,
                  departure_time TIMESTAMP,
                  arrival_time TIMESTAMP,
                  departure_gate TEXT,
                  arrival_gate TEXT
                )
                """,
                flight_columns,
                flights,
            )

            # tickets is only recreated; no initial rows are loaded into it.
            await rebuild(
                "tickets",
                """
                CREATE TABLE tickets(
                    user_id TEXT,
                    user_name TEXT,
                    user_email TEXT,
                    airline TEXT,
                    flight_number TEXT,
                    departure_airport TEXT,
                    arrival_airport TEXT,
                    departure_time TIMESTAMP,
                    arrival_time TIMESTAMP
                )
                """,
                (),
                [],
            )

            await rebuild(
                "policies",
                """
                CREATE TABLE policies(
                  id INT PRIMARY KEY,
                  content TEXT NOT NULL,
                  embedding vector(768) NOT NULL
                )
                """,
                policy_columns,
                policies,
            )

            await conn.commit()

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Read back every airport, amenity, flight and policy, ordered by id.

        Returns:
            A 4-tuple ``(airports, amenities, flights, policies)`` of validated
            model instances.
        """
        async with self.__async_engine.connect() as conn:

            async def fetch_all(table: str):
                """Return all rows of ``table`` as mappings, ordered by id."""
                result = await conn.execute(
                    text(f"SELECT * FROM {table} ORDER BY id ASC")
                )
                return result.mappings().fetchall()

            # The queries run one after another on purpose: a single
            # AsyncConnection is not meant to be driven by several tasks at
            # once, and with background tasks a failure in the first query
            # would close the connection while the others were still pending.
            airport_results = await fetch_all("airports")
            amenity_results = await fetch_all("amenities")
            flights_results = await fetch_all("flights")
            policy_results = await fetch_all("policies")

        airports = [models.Airport.model_validate(a) for a in airport_results]
        amenities = [models.Amenity.model_validate(a) for a in amenity_results]
        flights = [models.Flight.model_validate(f) for f in flights_results]
        policies = [models.Policy.model_validate(p) for p in policy_results]

        return airports, amenities, flights, policies

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up one airport by primary key.

        Args:
            id: Value matched against ``airports.id``.

        Returns:
            ``(airport, formatted_sql)`` when a row exists, where
            ``formatted_sql`` is whatever ``format_sql(sql, params)`` yields;
            ``(None, None)`` when no row matches.
        """
        sql = """SELECT * FROM airports WHERE id=:id"""
        params = {"id": id}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            row = cursor.mappings().fetchone()

        if row is not None:
            return models.Airport.model_validate(row), format_sql(sql, params)
        return None, None

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Look up one airport by IATA code using an ``ILIKE`` comparison.

        Args:
            iata: Pattern compared against ``airports.iata`` with ``ILIKE``
                (case-insensitive; LIKE wildcards in the value are honoured).

        Returns:
            ``(airport, formatted_sql)`` for the first matching row, or
            ``(None, None)`` when nothing matches.
        """
        sql = """SELECT * FROM airports WHERE iata ILIKE :iata"""
        params = {"iata": iata}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            row = cursor.mappings().fetchone()

        if row is not None:
            return models.Airport.model_validate(row), format_sql(sql, params)
        return None, None

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Search airports by any combination of country, city and name.

        A filter left as ``None`` is ignored. ``country`` and ``city`` are
        compared with ``ILIKE``; ``name`` is matched as a case-insensitive
        substring. At most 10 rows are returned.

        Returns:
            ``(airports, formatted_sql)``; the list may be empty.
        """
        sql = """
                SELECT * FROM airports
                  WHERE (CAST(:country AS TEXT) IS NULL OR country ILIKE :country)
                  AND (CAST(:city AS TEXT) IS NULL OR city ILIKE :city)
                  AND (CAST(:name AS TEXT) IS NULL OR name ILIKE '%' || :name || '%')
                  LIMIT 10
                """
        params = {"country": country, "city": city, "name": name}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        found = [models.Airport.model_validate(row) for row in rows]
        return found, format_sql(sql, params)

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Look up one amenity by primary key.

        Only the id, name, description, location, terminal, category and hour
        columns are selected; the result is validated into ``models.Amenity``.

        Returns:
            ``(amenity, formatted_sql)`` when a row exists, else
            ``(None, None)``.
        """
        sql = """
                SELECT id, name, description, location, terminal, category, hour
                FROM amenities WHERE id=:id
                """
        params = {"id": id}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            row = cursor.mappings().fetchone()

        if row is not None:
            return models.Amenity.model_validate(row), format_sql(sql, params)
        return None, None

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Find the amenities whose embedding is closest to ``query_embedding``.

        Rows are kept when ``embedding <=> query_embedding`` is below
        ``similarity_threshold``, ordered by that distance (smallest first),
        and capped at ``top_k``.

        Returns:
            ``(rows, formatted_sql)`` where each row is a mapping with the
            name, description, location, terminal, category and hour columns.
        """
        sql = """
                SELECT name, description, location, terminal, category, hour
                FROM amenities
                WHERE (embedding <=> :query_embedding) < :similarity_threshold
                ORDER BY (embedding <=> :query_embedding)
                LIMIT :top_k
                """
        params = {
            "query_embedding": query_embedding,
            "similarity_threshold": similarity_threshold,
            "top_k": top_k,
        }
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        return list(rows), format_sql(sql, params)

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Look up one flight by primary key.

        Args:
            flight_id: Value matched against ``flights.id``.

        Returns:
            ``(flight, formatted_sql)`` when a row exists, else
            ``(None, None)``.
        """
        sql = """
                SELECT * FROM flights
                  WHERE id = :flight_id
                """
        params = {"flight_id": flight_id}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            row = cursor.mappings().fetchone()

        if row is not None:
            return models.Flight.model_validate(row), format_sql(sql, params)
        return None, None

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Find flights by exact airline and flight number (at most 10 rows).

        Args:
            airline: Compared for equality with ``flights.airline``.
            number: Compared for equality with ``flights.flight_number``.

        Returns:
            ``(flights, formatted_sql)``; the list may be empty.
        """
        sql = """
                SELECT * FROM flights
                  WHERE airline = :airline
                  AND flight_number = :number
                  LIMIT 10
                """
        params = {"airline": airline, "number": number}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        found = [models.Flight.model_validate(row) for row in rows]
        return found, format_sql(sql, params)

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Find flights departing on ``date``, optionally filtered by airport.

        Args:
            date: Day to search, formatted ``%Y-%m-%d``. Flights whose
                departure time falls in the 24 hours starting at midnight of
                that day are matched.
            departure_airport: ``ILIKE`` filter on the departure airport, or
                ``None`` to skip the filter.
            arrival_airport: ``ILIKE`` filter on the arrival airport, or
                ``None`` to skip the filter.

        Returns:
            ``(flights, formatted_sql)`` with at most 10 flights.

        Raises:
            ValueError: If ``date`` does not match ``%Y-%m-%d``.
        """
        sql = """
                SELECT * FROM flights
                  WHERE (CAST(:departure_airport AS TEXT) IS NULL OR departure_airport ILIKE :departure_airport)
                  AND (CAST(:arrival_airport AS TEXT) IS NULL OR arrival_airport ILIKE :arrival_airport)
                  AND departure_time >= CAST(:datetime AS timestamp)
                  AND departure_time < CAST(:datetime AS timestamp) + interval '1 day'
                  LIMIT 10
                """
        async with self.__async_engine.connect() as conn:
            day_start = datetime.strptime(date, "%Y-%m-%d")
            params = {
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "datetime": day_start,
            }
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        found = [models.Flight.model_validate(row) for row in rows]
        return found, format_sql(sql, params)

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Check that a flight matching the given ticket details exists.

        Args:
            airline: ``ILIKE`` pattern for the airline.
            flight_number: ``ILIKE`` pattern for the flight number.
            departure_airport: ``ILIKE`` pattern for the departure airport.
            departure_time: Exact departure time, ``%Y-%m-%d %H:%M:%S``.

        Returns:
            ``(flight, formatted_sql)`` for the first matching flight, or
            ``(None, None)`` when there is none.

        Raises:
            ValueError: If ``departure_time`` is not in the expected format.
        """
        departs_at = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
        sql = """
                    SELECT * FROM flights
                    WHERE airline ILIKE :airline
                    AND flight_number ILIKE :flight_number
                    AND departure_airport ILIKE :departure_airport
                    AND departure_time = :departure_time
                """
        params = {
            "airline": airline,
            "flight_number": flight_number,
            "departure_airport": departure_airport,
            "departure_time": departs_at,
        }
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            row = cursor.mappings().fetchone()

        if row is not None:
            return models.Flight.model_validate(row), format_sql(sql, params)
        return None, None

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ) -> None:
        """Insert one row into ``tickets`` and commit it.

        Args:
            user_id: Stored in ``tickets.user_id``.
            user_name: Stored in ``tickets.user_name``.
            user_email: Stored in ``tickets.user_email``.
            airline: Stored in ``tickets.airline``.
            flight_number: Stored in ``tickets.flight_number``.
            departure_airport: Stored in ``tickets.departure_airport``.
            arrival_airport: Stored in ``tickets.arrival_airport``.
            departure_time: Departure time, ``%Y-%m-%d %H:%M:%S``.
            arrival_time: Arrival time, ``%Y-%m-%d %H:%M:%S``.

        Raises:
            ValueError: If either time string is not in the expected format.
            Exception: ``"Ticket Insertion failure"`` if the database reports
                that no row was inserted; nothing is committed in that case.
        """
        departure_time_datetime = datetime.strptime(departure_time, "%Y-%m-%d %H:%M:%S")
        arrival_time_datetime = datetime.strptime(arrival_time, "%Y-%m-%d %H:%M:%S")

        s = text(
            """
                INSERT INTO tickets (
                    user_id,
                    user_name,
                    user_email,
                    airline,
                    flight_number,
                    departure_airport,
                    arrival_airport,
                    departure_time,
                    arrival_time
                ) VALUES (
                    :user_id,
                    :user_name,
                    :user_email,
                    :airline,
                    :flight_number,
                    :departure_airport,
                    :arrival_airport,
                    :departure_time,
                    :arrival_time
                );
            """
        )
        params = {
            "user_id": user_id,
            "user_name": user_name,
            "user_email": user_email,
            "airline": airline,
            "flight_number": flight_number,
            "departure_airport": departure_airport,
            "arrival_airport": arrival_airport,
            "departure_time": departure_time_datetime,
            "arrival_time": arrival_time_datetime,
        }

        async with self.__async_engine.connect() as conn:
            result = await conn.execute(s, params)
            # A result object is always truthy, so it cannot signal failure;
            # the affected-row count can. Only an explicit 0 is treated as a
            # failure -- a negative count means the driver did not report one.
            # Raising before the commit leaves the transaction uncommitted.
            if result.rowcount == 0:
                raise Exception("Ticket Insertion failure")
            await conn.commit()

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """List the tickets stored for one user.

        Args:
            user_id: Compared for equality with ``tickets.user_id``.

        Returns:
            ``(rows, formatted_sql)`` where each row is a mapping with the
            user name, airline, flight number, airports and times of a ticket.
        """
        sql = """
                SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets
                WHERE user_id = :user_id
                """
        params = {"user_id": user_id}
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        return list(rows), format_sql(sql, params)

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Find the policy texts whose embedding is closest to ``query_embedding``.

        Rows are kept when ``embedding <=> query_embedding`` is below
        ``similarity_threshold``, ordered by that distance (smallest first),
        and capped at ``top_k``.

        Returns:
            ``(contents, formatted_sql)`` where ``contents`` holds the
            ``content`` column of each matching policy.
        """
        sql = """
                SELECT content
                FROM policies
                WHERE (embedding <=> :query_embedding) < :similarity_threshold
                ORDER BY (embedding <=> :query_embedding)
                LIMIT :top_k
                """
        params = {
            "query_embedding": query_embedding,
            "similarity_threshold": similarity_threshold,
            "top_k": top_k,
        }
        async with self.__async_engine.connect() as conn:
            cursor = await conn.execute(text(sql), params)
            rows = cursor.mappings().fetchall()

        contents = [row["content"] for row in rows]
        return contents, format_sql(sql, params)

    async def close(self):
        """Dispose of the underlying async engine."""
        await self.__async_engine.dispose()
