# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any

from aiosql.utils import VAR_REF

if TYPE_CHECKING:
    from aiosql.types import SQLOperationType


class MaybeAcquire:
    """Async context manager yielding a connection from a pool-or-connection.

    When ``client`` exposes an ``acquire`` attribute it is treated as a pool:
    a connection is acquired on entry and handed back through
    ``client.release`` on exit.  Otherwise ``client`` itself is yielded and
    nothing is released.

    ``driver`` is stored on the instance but not used by this class.
    """

    def __init__(self, client, driver=None) -> None:
        self._driver = driver
        self.client = client

    async def __aenter__(self):
        """Return an acquired connection, or ``client`` when it cannot acquire."""
        can_acquire = "acquire" in dir(self.client)
        if not can_acquire:
            # Nothing was acquired, so there is nothing to release on exit.
            self._managed_conn = None
            return self.client

        self._managed_conn = await self.client.acquire()
        return self._managed_conn

    async def __aexit__(self, exc_type, exc, tb):
        """Release the connection acquired on entry, if there was one."""
        if self._managed_conn is None:
            return
        await self.client.release(self._managed_conn)


class AsyncOracleDBAdapter:
    """Async driver adapter that binds query parameters positionally.

    ``process_sql`` rewrites named ``:var`` references to ``$N`` placeholders
    and records, per query name, the order in which variable names were first
    seen.  ``maybe_order_params`` uses that record to lay dict parameters out
    in the matching positional order before they are passed to the connection.

    NOTE(review): the ``$N`` placeholders and the ``fetch``/``fetchrow``/
    ``fetchval`` connection calls look like an asyncpg-style API -- confirm
    that the driver used with this adapter provides them.
    """

    is_aio_driver = True

    def __init__(self) -> None:
        # Maps query name -> variable names in order of first appearance.
        self.var_sorted: dict[str, Any] = defaultdict(list)  # type: ignore[assignment]

    def process_sql(self, query_name: str, op_type: SQLOperationType, sql: str) -> str:
        """Replace ``:name`` variables in ``sql`` with ``$N`` placeholders.

        Each distinct variable name gets the next number on first use and the
        same number on every later use; the names are remembered under
        ``query_name`` in ``var_sorted``.  Matches inside single or double
        quotes are left untouched.  ``op_type`` is not used here.

        Returns:
            The rewritten SQL text.
        """
        pieces = []
        cursor = 0  # offset in ``sql`` up to which text has been emitted

        for found in VAR_REF.finditer(sql):
            groups = found.groupdict()
            # Quoted text is copied through unchanged with the next slice.
            if not (groups["dquote"] is None and groups["squote"] is None):
                continue

            name = groups["var_name"]
            seen = self.var_sorted[query_name]
            try:
                position = seen.index(name)
            except ValueError:
                # First occurrence: it takes the next free slot.
                position = len(seen)
                seen.append(name)

            # Keep everything up to and including the lead character, then
            # emit the placeholder in place of the ":" and the variable name.
            pieces.append(sql[cursor : found.start() + len(groups["lead"])])
            pieces.append(f"${position + 1}")
            cursor = found.end()

        pieces.append(sql[cursor:])
        return "".join(pieces)

    def maybe_order_params(self, query_name: str, parameters: dict | tuple | Any) -> list | tuple:
        """Return ``parameters`` in positional order for ``query_name``.

        A tuple is returned unchanged.  A dict is converted to a list of its
        values, ordered by the variable names recorded for ``query_name``.

        Raises:
            ValueError: if ``parameters`` is neither a dict nor a tuple.
            KeyError: if a dict lacks one of the recorded variable names.
        """
        if isinstance(parameters, tuple):
            return parameters
        if not isinstance(parameters, dict):
            msg = f"Parameters expected to be dict or tuple, received {parameters}"
            raise ValueError(msg)
        return [parameters[name] for name in self.var_sorted[query_name]]

    async def select(self, conn, query_name, sql, parameters, record_class=None):
        """Run ``connection.fetch`` and return its rows.

        When ``record_class`` is given, each row is converted with
        ``record_class(**dict(row))`` and a list of those objects is returned.
        """
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            rows = await connection.fetch(sql, *args)
            if record_class is None:
                return rows
            return [record_class(**dict(row)) for row in rows]

    async def select_one(self, conn, query_name, sql, parameters, record_class=None):
        """Run ``connection.fetchrow`` and return its result.

        A non-``None`` row is converted with ``record_class(**dict(row))``
        when ``record_class`` is given.
        """
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            row = await connection.fetchrow(sql, *args)
            if row is None or record_class is None:
                return row
            return record_class(**dict(row))

    async def select_value(self, conn, query_name, sql, parameters):
        """Run ``connection.fetchval`` and return its result."""
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            value = await connection.fetchval(sql, *args)
            return value

    @asynccontextmanager
    async def select_cursor(self, conn, query_name, sql, parameters):
        """Yield a cursor over a prepared statement, inside a transaction.

        The statement is prepared first; the cursor is created and yielded
        within ``connection.transaction()``.
        """
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            prepared = await connection.prepare(sql)
            async with connection.transaction():
                cursor = prepared.cursor(*args)
                yield cursor

    async def insert_returning(self, conn, query_name, sql, parameters):
        """Run ``connection.fetchrow`` and return what the statement gave back.

        Returns ``None`` for a falsy row, the sole element of a one-element
        row, and the row itself otherwise.
        """
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            row = await connection.fetchrow(sql, *args)
            if not row:
                return None
            if len(row) == 1:
                return row[0]
            return row

    async def insert_update_delete(self, conn, query_name, sql, parameters):
        """Run ``connection.execute`` and return its result unchanged."""
        args = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            # TODO extract integer last result
            outcome = await connection.execute(sql, *args)
            return outcome

    async def insert_update_delete_many(self, conn, query_name, sql, parameters):
        """Order every parameter set, then run ``connection.executemany``."""
        batch = [self.maybe_order_params(query_name, item) for item in parameters]
        async with MaybeAcquire(conn) as connection:
            outcome = await connection.executemany(sql, batch)
            return outcome

    async def execute_script(self, conn, sql):
        """Run ``connection.execute`` on ``sql`` with no parameters."""
        async with MaybeAcquire(conn) as connection:
            outcome = await connection.execute(sql)
            return outcome
