# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import threading
from datetime import datetime
from typing import Any, Literal, Optional

import pymysql
from google.cloud.sql.connector import Connector, RefreshStrategy
from pydantic import BaseModel
from sqlalchemy import Engine, create_engine, text
from sqlalchemy.engine.base import Engine

import models

from .. import datastore

# Identifier of this datastore provider. It matches the `kind` literal that
# Config requires and is what Client.kind reports.
MYSQL_IDENTIFIER = "cloudsql-mysql"


class Config(BaseModel, datastore.AbstractConfig):
    """Connection settings for the Cloud SQL for MySQL datastore.

    ``project``, ``region`` and ``instance`` are joined into the Cloud SQL
    instance connection name (``project:region:instance``); ``user``,
    ``password`` and ``database`` are passed to every new connection.
    """

    kind: Literal["cloudsql-mysql"]
    project: str
    region: str
    instance: str
    user: str
    password: str
    database: str


class Client(datastore.Client[Config]):
    """Datastore client backed by a Cloud SQL for MySQL instance.

    Queries run synchronously through a SQLAlchemy engine. Every async method
    except ``close`` offloads its ``*_sync`` counterpart to the running event
    loop's default executor.

    Methods returning ``tuple[..., Optional[str]]`` always return ``None`` as
    the second element in this implementation.
    """

    __pool: Engine
    __db_name: str
    # Cloud SQL connector shared by every engine built through create_sync();
    # constructed lazily the first time a pooled connection is opened.
    __connector: Optional[Connector] = None
    # The engine's creator callback can run concurrently on several executor
    # threads; this lock makes the lazy Connector construction happen once
    # instead of racing (and leaking the losers' background resources).
    __connector_lock = threading.Lock()

    @datastore.classproperty
    def kind(cls):
        """Identifier of this datastore provider."""
        return MYSQL_IDENTIFIER

    def __init__(self, pool: Engine, db_name: str):
        """Wrap an already-configured SQLAlchemy engine.

        Args:
            pool: Engine used for every query issued by this client.
            db_name: Name of the database the engine connects to.
        """
        self.__pool = pool
        self.__db_name = db_name

    @classmethod
    def create_sync(cls, config: Config) -> "Client":
        """Build a client whose engine connects through the Cloud SQL connector.

        The engine opens no connection here; it calls ``getconn`` whenever its
        pool needs a new DB-API connection.

        Args:
            config: Instance location and credentials.

        Returns:
            A new ``Client`` wrapping the engine.
        """
        # Cloud SQL instance connection name.
        instance_connection_name = (
            f"{config.project}:{config.region}:{config.instance}"
        )

        def getconn() -> pymysql.Connection:
            with cls.__connector_lock:
                if cls.__connector is None:
                    cls.__connector = Connector(refresh_strategy=RefreshStrategy.LAZY)
                connector = cls.__connector
            return connector.connect(
                instance_connection_name,
                "pymysql",
                user=config.user,
                password=config.password,
                db=config.database,
                # Each statement commits on its own; no method below calls
                # commit() explicitly.
                autocommit=True,
            )

        # create_engine returns an Engine or raises, so no None check needed.
        pool = create_engine("mysql+pymysql://", creator=getconn)
        return cls(pool, config.database)

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Async wrapper around ``create_sync``, run in the default executor."""
        loop = asyncio.get_running_loop()
        client = await loop.run_in_executor(None, cls.create_sync, config)
        return client

    @staticmethod
    def _insert_rows(conn, statement, rows: list[dict[str, Any]]) -> None:
        """Execute ``statement`` once per row dict; do nothing for no rows.

        SQLAlchemy treats an empty parameter list as a single execution with
        no parameters, which fails on the statement's unbound placeholders,
        so empty inputs must be skipped explicitly.
        """
        if rows:
            conn.execute(statement, parameters=rows)

    def initialize_data_sync(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Drop and recreate every table, then load the given records.

        All existing data is discarded, including booked tickets; the
        ``tickets`` table is recreated empty. Empty input lists leave the
        corresponding table empty.
        """
        with self.__pool.connect() as conn:
            # If the table already exists, drop it to avoid conflicts
            conn.execute(text("DROP TABLE IF EXISTS airports"))
            conn.execute(
                text(
                    """
                    CREATE TABLE airports(
                      id INT PRIMARY KEY,
                      iata TEXT,
                      name TEXT,
                      city TEXT,
                      country TEXT
                    )
                    """
                )
            )
            self._insert_rows(
                conn,
                text(
                    """INSERT INTO airports VALUES (:id, :iata, :name, :city, :country)"""
                ),
                [
                    {
                        "id": a.id,
                        "iata": a.iata,
                        "name": a.name,
                        "city": a.city,
                        "country": a.country,
                    }
                    for a in airports
                ],
            )

            conn.execute(text("DROP TABLE IF EXISTS amenities CASCADE"))
            conn.execute(
                text(
                    """
                    CREATE TABLE amenities(
                      id INT PRIMARY KEY,
                      name TEXT,
                      description TEXT,
                      location TEXT,
                      terminal TEXT,
                      category TEXT,
                      hour TEXT,
                      sunday_start_hour TIME,
                      sunday_end_hour TIME,
                      monday_start_hour TIME,
                      monday_end_hour TIME,
                      tuesday_start_hour TIME,
                      tuesday_end_hour TIME,
                      wednesday_start_hour TIME,
                      wednesday_end_hour TIME,
                      thursday_start_hour TIME,
                      thursday_end_hour TIME,
                      friday_start_hour TIME,
                      friday_end_hour TIME,
                      saturday_start_hour TIME,
                      saturday_end_hour TIME,
                      content TEXT NOT NULL,
                      embedding vector(768) USING VARBINARY NOT NULL
                    )
                    """
                )
            )
            # Embeddings are sent as their Python list repr (e.g. "[0.1, 0.2]")
            # and converted server-side by string_to_vector.
            self._insert_rows(
                conn,
                text(
                    """
                    INSERT INTO amenities VALUES (:id, :name, :description, :location,
                    :terminal, :category, :hour, :sunday_start_hour, :sunday_end_hour,
                    :monday_start_hour, :monday_end_hour, :tuesday_start_hour,
                    :tuesday_end_hour, :wednesday_start_hour, :wednesday_end_hour,
                    :thursday_start_hour, :thursday_end_hour, :friday_start_hour,
                    :friday_end_hour, :saturday_start_hour, :saturday_end_hour, :content, string_to_vector(:embedding))
                    """
                ),
                [
                    {
                        "id": a.id,
                        "name": a.name,
                        "description": a.description,
                        "location": a.location,
                        "terminal": a.terminal,
                        "category": a.category,
                        "hour": a.hour,
                        "sunday_start_hour": a.sunday_start_hour,
                        "sunday_end_hour": a.sunday_end_hour,
                        "monday_start_hour": a.monday_start_hour,
                        "monday_end_hour": a.monday_end_hour,
                        "tuesday_start_hour": a.tuesday_start_hour,
                        "tuesday_end_hour": a.tuesday_end_hour,
                        "wednesday_start_hour": a.wednesday_start_hour,
                        "wednesday_end_hour": a.wednesday_end_hour,
                        "thursday_start_hour": a.thursday_start_hour,
                        "thursday_end_hour": a.thursday_end_hour,
                        "friday_start_hour": a.friday_start_hour,
                        "friday_end_hour": a.friday_end_hour,
                        "saturday_start_hour": a.saturday_start_hour,
                        "saturday_end_hour": a.saturday_end_hour,
                        "content": a.content,
                        "embedding": f"{a.embedding}",
                    }
                    for a in amenities
                ],
            )

            conn.execute(text("DROP TABLE IF EXISTS flights"))
            conn.execute(
                text(
                    """
                    CREATE TABLE flights(
                      id INTEGER PRIMARY KEY,
                      airline TEXT,
                      flight_number TEXT,
                      departure_airport TEXT,
                      arrival_airport TEXT,
                      departure_time TIMESTAMP,
                      arrival_time TIMESTAMP,
                      departure_gate TEXT,
                      arrival_gate TEXT
                    )
                    """
                )
            )
            self._insert_rows(
                conn,
                text(
                    """
                    INSERT INTO flights VALUES (:id, :airline, :flight_number,
                    :departure_airport, :arrival_airport, :departure_time,
                    :arrival_time, :departure_gate, :arrival_gate)
                    """
                ),
                [
                    {
                        "id": f.id,
                        "airline": f.airline,
                        "flight_number": f.flight_number,
                        "departure_airport": f.departure_airport,
                        "arrival_airport": f.arrival_airport,
                        "departure_time": f.departure_time,
                        "arrival_time": f.arrival_time,
                        "departure_gate": f.departure_gate,
                        "arrival_gate": f.arrival_gate,
                    }
                    for f in flights
                ],
            )

            # Tickets are only written by insert_ticket; start empty.
            conn.execute(text("DROP TABLE IF EXISTS tickets"))
            conn.execute(
                text(
                    """
                    CREATE TABLE tickets(
                        user_id TEXT,
                        user_name TEXT,
                        user_email TEXT,
                        airline TEXT,
                        flight_number TEXT,
                        departure_airport TEXT,
                        arrival_airport TEXT,
                        departure_time TIMESTAMP,
                        arrival_time TIMESTAMP
                    )
                    """
                )
            )

            conn.execute(text("DROP TABLE IF EXISTS policies"))
            conn.execute(
                text(
                    """
                    CREATE TABLE policies(
                      id INT PRIMARY KEY,
                      content TEXT NOT NULL,
                      embedding vector(768) USING VARBINARY NOT NULL
                    )
                    """
                )
            )
            self._insert_rows(
                conn,
                text(
                    """
                    INSERT INTO policies VALUES (:id, :content, string_to_vector(:embedding))
                    """
                ),
                [
                    {
                        "id": p.id,
                        "content": p.content,
                        "embedding": f"{p.embedding}",
                    }
                    for p in policies
                ],
            )

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Async wrapper around ``initialize_data_sync``."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, self.initialize_data_sync, airports, amenities, flights, policies
        )

    def export_data_sync(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Read back every airport, amenity, flight and policy, ordered by id.

        Amenity opening hours are formatted as ``HH:MM`` strings and vector
        embeddings are converted with ``vector_to_string`` so the rows can be
        validated into the pydantic models.
        """
        with self.__pool.connect() as conn:
            airport_rows = (
                conn.execute(text("""SELECT * FROM airports ORDER BY id ASC"""))
                .mappings()
                .fetchall()
            )
            amenity_rows = (
                conn.execute(
                    text(
                        """
                        SELECT id,
                               name,
                               description,
                               location,
                               terminal,
                               category,
                               hour,
                               DATE_FORMAT(sunday_start_hour,  '%H:%i') AS sunday_start_hour,
                               DATE_FORMAT(sunday_end_hour,    '%H:%i') AS sunday_end_hour,
                               DATE_FORMAT(monday_start_hour,  '%H:%i') AS monday_start_hour,
                               DATE_FORMAT(monday_end_hour,    '%H:%i') AS monday_end_hour,
                               DATE_FORMAT(tuesday_start_hour, '%H:%i') AS tuesday_start_hour,
                               DATE_FORMAT(tuesday_end_hour,   '%H:%i') AS tuesday_end_hour,
                               DATE_FORMAT(wednesday_start_hour, '%H:%i') AS wednesday_start_hour,
                               DATE_FORMAT(wednesday_end_hour,   '%H:%i') AS wednesday_end_hour,
                               DATE_FORMAT(thursday_start_hour,  '%H:%i') AS thursday_start_hour,
                               DATE_FORMAT(thursday_end_hour,    '%H:%i') AS thursday_end_hour,
                               DATE_FORMAT(friday_start_hour,   '%H:%i') AS friday_start_hour,
                               DATE_FORMAT(friday_end_hour,     '%H:%i') AS friday_end_hour,
                               DATE_FORMAT(saturday_start_hour,  '%H:%i') AS saturday_start_hour,
                               DATE_FORMAT(saturday_end_hour,    '%H:%i') AS saturday_end_hour,
                               content,
                               vector_to_string(embedding) as embedding
                          FROM amenities ORDER BY id ASC
                        """
                    )
                )
                .mappings()
                .fetchall()
            )
            flight_rows = (
                conn.execute(text("""SELECT * FROM flights ORDER BY id ASC"""))
                .mappings()
                .fetchall()
            )
            policy_rows = (
                conn.execute(
                    text(
                        """SELECT id, content, vector_to_string(embedding) as embedding FROM policies ORDER BY id ASC"""
                    )
                )
                .mappings()
                .fetchall()
            )

        airports = [models.Airport.model_validate(a) for a in airport_rows]
        amenities = [models.Amenity.model_validate(a) for a in amenity_rows]
        flights = [models.Flight.model_validate(f) for f in flight_rows]
        policies = [models.Policy.model_validate(p) for p in policy_rows]
        return airports, amenities, flights, policies

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Async wrapper around ``export_data_sync``."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.export_data_sync)

    def get_airport_by_id_sync(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Fetch one airport by primary key; ``(None, None)`` when absent."""
        with self.__pool.connect() as conn:
            s = text("""SELECT * FROM airports WHERE id=:id""")
            result = conn.execute(s, {"id": id}).mappings().fetchone()

        if result is None:
            return None, None
        return models.Airport.model_validate(result), None

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Async wrapper around ``get_airport_by_id_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(None, self.get_airport_by_id_sync, id)
        return res, sql

    def get_airport_by_iata_sync(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Fetch the first airport whose IATA code matches, case-insensitively.

        The comparison uses ``LIKE``, so ``%`` and ``_`` in ``iata`` act as
        wildcards. Returns ``(None, None)`` when nothing matches.
        """
        with self.__pool.connect() as conn:
            s = text("""SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(:iata)""")
            result = conn.execute(s, {"iata": iata}).mappings().fetchone()

        if result is None:
            return None, None
        return models.Airport.model_validate(result), None

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Async wrapper around ``get_airport_by_iata_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(None, self.get_airport_by_iata_sync, iata)
        return res, sql

    def search_airports_sync(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Search airports by case-insensitive substring on the given fields.

        A ``None`` filter is ignored. At most 10 airports are returned.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM airports
                WHERE (:country IS NULL OR LOWER(country) LIKE CONCAT('%', LOWER(:country), '%'))
                AND (:city IS NULL OR LOWER(city) LIKE CONCAT('%', LOWER(:city), '%'))
                AND (:name IS NULL OR LOWER(name) LIKE CONCAT('%', LOWER(:name), '%'))
                LIMIT 10;
                """
            )
            params = {
                "country": country,
                "city": city,
                "name": name,
            }
            results = conn.execute(s, parameters=params).mappings().fetchall()

        return [models.Airport.model_validate(r) for r in results], None

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Async wrapper around ``search_airports_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None, self.search_airports_sync, country, city, name
        )
        return res, sql

    def get_amenity_sync(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Fetch one amenity's descriptive columns by id; ``(None, None)`` if absent.

        Opening hours, content and embedding are not selected.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT id, name, description, location, terminal, category, hour
                FROM amenities WHERE id=:id
                """
            )
            result = conn.execute(s, parameters={"id": id}).mappings().fetchone()

        if result is None:
            return None, None
        return models.Amenity.model_validate(result), None

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Async wrapper around ``get_amenity_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(None, self.get_amenity_sync, id)
        return res, sql

    def amenities_search_sync(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Return the ``top_k`` amenities nearest to ``query_embedding``.

        Ranking uses ``APPROX_DISTANCE`` with cosine distance. NOTE:
        ``similarity_threshold`` is accepted but not applied by this query.

        Returns:
            Row mappings with the name, description, location, terminal,
            category and hour columns, and ``None``.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT name, description, location, terminal, category, hour
                  FROM amenities
                  ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :top_k
                """
            )
            params = {
                "query": f"{query_embedding}",
                "top_k": top_k,
            }
            results = conn.execute(s, parameters=params).mappings().fetchall()

        return list(results), None

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Async wrapper around ``amenities_search_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None,
            self.amenities_search_sync,
            query_embedding,
            similarity_threshold,
            top_k,
        )
        return res, sql

    def get_flight_sync(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Fetch one flight by id; ``(None, None)`` when absent."""
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE id = :flight_id
                """
            )
            result = (
                conn.execute(s, parameters={"flight_id": flight_id})
                .mappings()
                .fetchone()
            )

        if result is None:
            return None, None
        return models.Flight.model_validate(result), None

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Async wrapper around ``get_flight_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(None, self.get_flight_sync, flight_id)
        return res, sql

    def search_flights_by_number_sync(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Return up to 10 flights with exactly this airline and flight number."""
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE airline = :airline
                  AND flight_number = :number
                  LIMIT 10
                """
            )
            params = {
                "airline": airline,
                "number": number,
            }
            results = conn.execute(s, parameters=params).mappings().fetchall()

        return [models.Flight.model_validate(r) for r in results], None

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Async wrapper around ``search_flights_by_number_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None, self.search_flights_by_number_sync, airline, number
        )
        return res, sql

    def search_flights_by_airports_sync(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Return up to 10 flights departing on ``date`` between the airports.

        Airport filters are case-insensitive ``LIKE`` matches and are ignored
        when ``None``. The day window is ``[date 00:00, date + 1 day)``.

        Raises:
            ValueError: if ``date`` is not formatted as ``YYYY-MM-DD``.
        """
        day_start = datetime.strptime(date, "%Y-%m-%d")
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE (CAST(:departure_airport AS CHAR(255)) IS NULL OR LOWER(departure_airport) LIKE LOWER(:departure_airport))
                  AND (CAST(:arrival_airport AS CHAR(255)) IS NULL OR LOWER(arrival_airport) LIKE LOWER(:arrival_airport))
                  AND departure_time >= CAST(:datetime AS DATETIME)
                  AND (departure_time < DATE_ADD(CAST(:datetime AS DATETIME), interval 1 day))
                  LIMIT 10
                """
            )
            params = {
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "datetime": day_start,
            }
            results = conn.execute(s, parameters=params).mappings().fetchall()

        return [models.Flight.model_validate(r) for r in results], None

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Async wrapper around ``search_flights_by_airports_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None,
            self.search_flights_by_airports_sync,
            date,
            departure_airport,
            arrival_airport,
        )
        return res, sql

    def validate_ticket_sync(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Find the flight a ticket refers to, or ``(None, None)`` if none exists.

        Text fields are compared case-insensitively with ``LIKE``;
        ``departure_time`` must equal the stored time after a MySQL
        ``CAST(... AS DATETIME)``.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT * FROM flights
                  WHERE LOWER(airline) LIKE LOWER(:airline)
                  AND LOWER(flight_number) LIKE LOWER(:flight_number)
                  AND LOWER(departure_airport) LIKE LOWER(:departure_airport)
                  AND departure_time =  CAST(:departure_time AS DATETIME)
                  LIMIT 10
                """
            )
            params = {
                "airline": airline,
                "flight_number": flight_number,
                "departure_airport": departure_airport,
                "departure_time": departure_time,
            }
            result = conn.execute(s, parameters=params).mappings().fetchone()

        if result is None:
            return None, None
        return models.Flight.model_validate(result), None

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Async wrapper around ``validate_ticket_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None,
            self.validate_ticket_sync,
            airline,
            flight_number,
            departure_airport,
            departure_time,
        )
        return res, sql

    def insert_ticket_sync(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        """Insert one ticket row; committed immediately via connection autocommit."""
        with self.__pool.connect() as conn:
            s = text(
                """
                INSERT INTO tickets (
                    user_id,
                    user_name,
                    user_email,
                    airline,
                    flight_number,
                    departure_airport,
                    arrival_airport,
                    departure_time,
                    arrival_time
                ) VALUES (
                    :user_id,
                    :user_name,
                    :user_email,
                    :airline,
                    :flight_number,
                    :departure_airport,
                    :arrival_airport,
                    :departure_time,
                    :arrival_time
                );
                """
            )
            params = {
                "user_id": user_id,
                "user_name": user_name,
                "user_email": user_email,
                "airline": airline,
                "flight_number": flight_number,
                "departure_airport": departure_airport,
                "arrival_airport": arrival_airport,
                "departure_time": departure_time,
                "arrival_time": arrival_time,
            }
            conn.execute(s, params)

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ):
        """Async wrapper around ``insert_ticket_sync``."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            self.insert_ticket_sync,
            user_id,
            user_name,
            user_email,
            airline,
            flight_number,
            departure_airport,
            arrival_airport,
            departure_time,
            arrival_time,
        )

    def list_tickets_sync(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """Return every ticket booked by ``user_id`` as row mappings.

        ``user_id`` and ``user_email`` are not included in the selected columns.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT user_name, airline, flight_number, departure_airport, arrival_airport, departure_time, arrival_time FROM tickets
                WHERE user_id = :user_id
                """
            )
            results = (
                conn.execute(s, parameters={"user_id": user_id}).mappings().fetchall()
            )

        return list(results), None

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[models.Ticket], Optional[str]]:
        """Async wrapper around ``list_tickets_sync``.

        NOTE(review): despite the annotation, the list holds SQLAlchemy row
        mappings rather than ``models.Ticket`` instances.
        """
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(None, self.list_tickets_sync, user_id)
        return res, sql

    def policies_search_sync(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Return the content of the ``top_k`` policies nearest to the query.

        Ranking uses ``APPROX_DISTANCE`` with cosine distance. NOTE:
        ``similarity_threshold`` is accepted but not applied by this query.
        """
        with self.__pool.connect() as conn:
            s = text(
                """
                SELECT content
                  FROM policies
                  ORDER BY APPROX_DISTANCE(embedding, string_to_vector(:query), 'distance_measure=cosine') LIMIT :top_k
                """
            )
            params = {
                "query": f"{query_embedding}",
                "top_k": top_k,
            }
            results = conn.execute(s, parameters=params).mappings().fetchall()

        return [r["content"] for r in results], None

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Async wrapper around ``policies_search_sync``."""
        loop = asyncio.get_running_loop()
        res, sql = await loop.run_in_executor(
            None,
            self.policies_search_sync,
            query_embedding,
            similarity_threshold,
            top_k,
        )
        return res, sql

    async def close(self):
        """Dispose of the engine's connection pool.

        The class-level Cloud SQL connector is shared and is left open.
        """
        self.__pool.dispose()
