# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from contextlib import contextmanager
import copy
import logging
from typing import TYPE_CHECKING
import warnings

import google.oauth2.service_account
from google.cloud import bigquery
from google.api_core import client_options
import ibis
import pandas

from data_validation import client_info, consts, exceptions
from data_validation.secret_manager import SecretManagerBuilder
from third_party.ibis.ibis_cloud_spanner.api import spanner_connect
from third_party.ibis.ibis_impala.api import impala_connect
from third_party.ibis.ibis_mssql.api import mssql_connect
from third_party.ibis.ibis_redshift.api import redshift_connect

if TYPE_CHECKING:
    import ibis.expr.schema as sch
    import ibis.expr.types as ir


# Remove Ibis' default SQL row limit (None = unlimited) so queries are not
# silently truncated.
ibis.options.sql.default_limit = None

# Filter Ibis MySQL error when loading client.table()
warnings.filterwarnings(
    "ignore",
    "`BaseBackend.database` is deprecated; use equivalent methods in the backend",
)


# Ibis backends implemented on top of SQLAlchemy; membership in this list
# drives how table/schema lookups are performed (see is_sqlalchemy_backend()).
IBIS_ALCHEMY_BACKENDS = [
    "mysql",
    "oracle",
    "postgres",
    "db2",
    "mssql",
    "redshift",
    "snowflake",
]


def _raise_missing_client_error(msg):
    def get_client_call(*args, **kwargs):
        raise Exception(msg)

    return get_client_call


# The connectors below are optional because they depend on client libraries
# (and in some cases licensing) that may not be installed.  When the import
# fails, the connect function is replaced with a stub that raises an
# install-hint error on first use (see _raise_missing_client_error).

# Teradata requires teradatasql and licensing
try:
    from third_party.ibis.ibis_teradata.api import teradata_connect
except Exception:
    msg = "pip install teradatasql (requires Teradata licensing)"
    teradata_connect = _raise_missing_client_error(msg)

# Oracle requires cx_Oracle driver
try:
    from third_party.ibis.ibis_oracle.api import oracle_connect
except Exception:
    oracle_connect = _raise_missing_client_error("pip install cx_Oracle")

# Snowflake requires snowflake-connector-python and snowflake-sqlalchemy
try:
    from third_party.ibis.ibis_snowflake.api import snowflake_connect
except Exception:
    snowflake_connect = _raise_missing_client_error(
        "pip install snowflake-connector-python && pip install snowflake-sqlalchemy"
    )

# DB2 requires ibm_db_sa
try:
    from third_party.ibis.ibis_db2.api import db2_connect
except Exception:
    db2_connect = _raise_missing_client_error("pip install ibm_db_sa")


def get_google_bigquery_client(
    project_id: str, credentials=None, api_endpoint: str = None
):
    """Return a google-cloud-bigquery Client for project_id.

    The client carries the DVT HTTP user-agent, forces a UTC session time
    zone for queries and, when supplied, targets a custom api_endpoint.
    """
    http_info = client_info.get_http_client_info()
    # Force UTC on every query issued through this client.
    default_job_config = bigquery.QueryJobConfig(
        connection_properties=[bigquery.ConnectionProperty("time_zone", "UTC")]
    )
    endpoint_options = (
        client_options.ClientOptions(api_endpoint=api_endpoint)
        if api_endpoint
        else None
    )
    return bigquery.Client(
        project=project_id,
        client_info=http_info,
        credentials=credentials,
        default_query_job_config=default_job_config,
        client_options=endpoint_options,
    )


def get_bigquery_client(
    project_id: str, dataset_id: str = "", credentials=None, api_endpoint: str = None
):
    """Return an Ibis BigQuery client for project_id/dataset_id."""
    ibis_client = ibis.bigquery.connect(
        project_id=project_id,
        dataset_id=dataset_id,
        credentials=credentials,
    )

    # Swap in our own BigQuery client object so that the correct user agent
    # is included and any api_endpoint is used.
    ibis_client.client = get_google_bigquery_client(
        project_id, credentials=credentials, api_endpoint=api_endpoint
    )
    return ibis_client


def get_pandas_client(table_name, file_path, file_type):
    """Return pandas client and env with file loaded into DataFrame

    table_name (str): Table name to use as reference for file data
    file_path (str): The local, s3, or GCS file path to the data
    file_type (str): The file type of the file (csv, json, orc or parquet)
    """
    # Dispatch table of supported file formats to their pandas readers.
    readers = {
        "csv": pandas.read_csv,
        "json": pandas.read_json,
        "orc": pandas.read_orc,
        "parquet": pandas.read_parquet,
    }
    if file_type not in readers:
        raise ValueError(f"Unknown Pandas File Type: {file_type}")

    df = readers[file_type](file_path)
    return ibis.pandas.connect({table_name: df})


def is_sqlalchemy_backend(client):
    """Return True when the client's backend is SQLAlchemy based."""
    try:
        # `in` already yields a bool; any failure (e.g. no .name attribute)
        # means this is not one of the SQLAlchemy backends.
        return client.name in IBIS_ALCHEMY_BACKENDS
    except Exception:
        return False


def is_oracle_client(client):
    """Return True when the client is connected to an Oracle backend."""
    try:
        backend_name = client.name
    except TypeError:
        # When no Oracle backend has been installed OracleBackend is not a class
        return False
    return backend_name == "oracle"


def get_ibis_table(client, schema_name, table_name, database_name=None):
    """Return Ibis Table for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object
    table_name (str): Table name of table object
    database_name (str): Database name (generally default is used)
    """
    backend = client.name
    if backend in ("oracle", "postgres", "db2", "mssql", "redshift"):
        # SQLAlchemy style backends accept both database and schema.
        return client.table(table_name, database=database_name, schema=schema_name)
    if backend == "pandas":
        return client.table(table_name, schema=schema_name)
    # Remaining backends treat the schema as the Ibis "database".
    return client.table(table_name, database=schema_name)


def get_ibis_query(client, query) -> "ir.Table":
    """Return Ibis Table from query expression for Supplied Client."""
    query_expr = client.sql(query)
    # Normalise all columns in the query to lower case.
    # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/992
    return query_expr.relabel({name: name.lower() for name in query_expr.columns})


def get_ibis_table_schema(client, schema_name: str, table_name: str) -> "sch.Schema":
    """Return Ibis Table Schema for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object, may not need this since Backend uses database
    table_name (str): Table name of table object
    """
    if not is_sqlalchemy_backend(client):
        return client.get_schema(table_name, schema_name)
    # SQLAlchemy backends expose the schema via the table expression.
    return client.table(table_name, schema=schema_name).schema()


def get_ibis_query_schema(client, query_str) -> "sch.Schema":
    """Return the Ibis Schema of the supplied query for the client."""
    if not is_sqlalchemy_backend(client):
        # NJ: I'm not happy about calling a private method but don't see how I can avoid it.
        #     Ibis does not expose a public method like it does for get_schema().
        return client._get_schema_using_query(query_str)
    return get_ibis_query(client, query_str).schema()


def list_schemas(client):
    """Return a list of schemas in the DB."""
    if not hasattr(client, "list_databases"):
        return [None]
    try:
        return client.list_databases()
    except NotImplementedError:
        # Some backends expose the method without implementing it.
        return [None]


def list_tables(client, schema_name, tables_only=True):
    """Return a list of tables in the DB schema."""
    if tables_only and client.name != "pandas":
        list_fn = client.dvt_list_tables
    else:
        list_fn = client.list_tables
    # These backends do not accept a database/schema argument.
    if client.name in ("db2", "mssql", "redshift", "snowflake", "pandas"):
        return list_fn()
    return list_fn(database=schema_name)


def get_all_tables(client, allowed_schemas=None, tables_only=True):
    """Return a list of tuples with database and table names.

    client (IbisClient): Client to use for tables
    allowed_schemas (List[str]): List of schemas to pull.
    """
    results = []
    for schema_name in list_schemas(client):
        if allowed_schemas and schema_name not in allowed_schemas:
            continue
        try:
            schema_tables = list_tables(client, schema_name, tables_only=tables_only)
        except Exception as e:
            # Best effort: log and carry on with the remaining schemas.
            logging.warning(f"List Tables Error: {schema_name} -> {e}")
            continue
        results.extend((schema_name, table_name) for table_name in schema_tables)

    return results


def get_data_client(connection_config):
    """Return DataClient client from given configuration.

    connection_config (dict): Connection configuration containing the source
        type plus any backend specific connect arguments, optionally with
        secret manager settings for resolving secret references.

    Raises:
        Exception: If the source type is not in CLIENT_LOOKUP.
        exceptions.DataClientConnectionFailure: If the client fails to connect.
    """
    # Work on a copy so the caller's configuration dict is never mutated.
    connection_config = copy.deepcopy(connection_config)
    source_type = connection_config.pop(consts.SOURCE_TYPE)
    secret_manager_type = connection_config.pop(consts.SECRET_MANAGER_TYPE, None)
    secret_manager_project_id = connection_config.pop(
        consts.SECRET_MANAGER_PROJECT_ID, None
    )

    # Resolve any secret references in the configuration values.
    if secret_manager_type is not None:
        sm = SecretManagerBuilder().build(secret_manager_type.lower())
        decrypted_connection_config = {
            config_item: sm.maybe_secret(
                secret_manager_project_id, connection_config[config_item]
            )
            for config_item in connection_config
        }
    else:
        decrypted_connection_config = connection_config

    # The ibis_bigquery.connect expects a credentials object, not a string.
    if consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH in decrypted_connection_config:
        key_path = decrypted_connection_config.pop(
            consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH
        )
        if key_path:
            decrypted_connection_config[
                "credentials"
            ] = google.oauth2.service_account.Credentials.from_service_account_file(
                key_path
            )

    if source_type not in CLIENT_LOOKUP:
        msg = 'ConfigurationError: Source type "{source_type}" is not supported'.format(
            source_type=source_type
        )
        raise Exception(msg)

    try:
        data_client = CLIENT_LOOKUP[source_type](**decrypted_connection_config)
        data_client._source_type = source_type
    except Exception as e:
        msg = 'Connection Type "{source_type}" could not connect: {error}'.format(
            source_type=source_type, error=str(e)
        )
        # Chain the original exception so the root cause/traceback is preserved.
        raise exceptions.DataClientConnectionFailure(msg) from e

    return data_client


@contextmanager
def get_data_client_ctx(*args, **kwargs):
    """Provide get_data_client() via a context manager."""
    client = None
    try:
        client = get_data_client(*args, **kwargs)
        yield client
    finally:
        # TODO When we upgrade Ibis beyond 5.x this try/except may become redundant.
        # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/1376
        try:
            if hasattr(client, "close"):
                client.close()
        except Exception as exc:
            # No need to reraise, we can silently fail if exiting throws up an issue.
            logging.warning("Exception closing connection: %s", str(exc))


def get_max_column_length(client):
    """Return the max column length supported by client.

    client (IbisClient): Client to use for tables
    """
    if not is_oracle_client(client):
        return 128
    # We can't reliably know which Version class client.version is stored in
    # because it is out of our control. Therefore using string identification
    # of Oracle <= 12.1 to avoid exceptions of this nature:
    #  TypeError: '<' not supported between instances of 'Version' and 'Version'
    version_str = str(client.version)
    if version_str[:2] in ("10", "11") or version_str[:4] == "12.1":
        return 30
    return 128


def get_max_in_list_size(client, in_list_over_expressions=False):
    """Return the backend specific IN list size limit, or None for no limit.

    in_list_over_expressions (bool): The caller has detected an IN list whose
        contents are expressions (e.g. casts) rather than simple literals.
    """
    if client.name == "snowflake":
        # Workaround for a Snowflake limitation on non-constant values:
        #   SQL compilation error: In-list contains more than 50 non-constant values
        return 50 if in_list_over_expressions else 16000
    if is_oracle_client(client):
        # This is a workaround for Oracle limitation:
        #   ORA-01795: maximum number of expressions in a list is 1000
        return 1000
    return None


# Map of source type constant -> connect function, used by get_data_client()
# to build a client from a connection configuration.
CLIENT_LOOKUP = {
    consts.SOURCE_TYPE_BIGQUERY: get_bigquery_client,
    consts.SOURCE_TYPE_IMPALA: impala_connect,
    consts.SOURCE_TYPE_MYSQL: ibis.mysql.connect,
    consts.SOURCE_TYPE_ORACLE: oracle_connect,
    consts.SOURCE_TYPE_FILESYSTEM: get_pandas_client,
    consts.SOURCE_TYPE_POSTGRES: ibis.postgres.connect,
    consts.SOURCE_TYPE_REDSHIFT: redshift_connect,
    consts.SOURCE_TYPE_TERADATA: teradata_connect,
    consts.SOURCE_TYPE_MSSQL: mssql_connect,
    consts.SOURCE_TYPE_SNOWFLAKE: snowflake_connect,
    consts.SOURCE_TYPE_SPANNER: spanner_connect,
    consts.SOURCE_TYPE_DB2: db2_connect,
}
