#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
"""Notion source module responsible to fetch documents from the Notion Platform."""

import asyncio
import json
import os
import re
from copy import copy
from functools import cached_property, partial
from typing import Any, Awaitable, Callable
from urllib.parse import unquote

import aiohttp
import fastjsonschema
from aiohttp.client_exceptions import ClientResponseError
from notion_client import APIResponseError, AsyncClient

from connectors.filtering.validation import (
    AdvancedRulesValidator,
    SyncRuleValidationResult,
)
from connectors.logger import logger
from connectors.source import BaseDataSource, ConfigurableFieldValueError
from connectors.utils import CancellableSleeps, RetryStrategy, retryable

# Retry policy applied by the retryable Notion API calls in this module.
RETRIES = 3
RETRY_INTERVAL = 2
# Seconds to wait after a rate-limited response that names no retry delay.
DEFAULT_RETRY_SECONDS = 30
# Connection limit handed to the aiohttp TCP connector.
MAX_CONCURRENT_CLIENT_SUPPORT = 30

# API root; the OVERRIDE_URL environment variable wins whenever it is set.
BASE_URL = os.environ.get("OVERRIDE_URL", "https://api.notion.com")


class NotFound(Exception):
    """Raised when a request made through the aiohttp session returns HTTP 404."""


class NotionClient:
    """Notion API client.

    Wraps two transports: the ``notion_client.AsyncClient`` SDK object, used for
    the paginated Notion endpoints, and an aiohttp session, used for plain GET
    requests against arbitrary URLs.
    """

    def __init__(self, configuration):
        """Keep the configuration and read the integration secret from it.

        Args:
            configuration: Mapping-like configuration providing
                ``notion_secret_key``.
        """
        self._sleeps = CancellableSleeps()
        self.configuration = configuration
        self._logger = logger
        self.notion_secret_key = self.configuration["notion_secret_key"]

    def set_logger(self, logger_):
        """Replace the logger used by this client."""
        self._logger = logger_

    @staticmethod
    def _retry_after_seconds(headers, header_name):
        """Return the retry delay announced by a rate-limited response.

        Args:
            headers: Response headers mapping; may be ``None`` or empty.
            header_name (str): Name of the header carrying the delay.

        Returns:
            int: The header value as whole seconds, or ``DEFAULT_RETRY_SECONDS``
            when the header is absent, empty or not an integer.
        """
        raw_value = headers.get(header_name) if headers else None
        if not raw_value:
            return DEFAULT_RETRY_SECONDS
        try:
            # Header values arrive as strings; the sleep call needs a number.
            return int(raw_value)
        except (TypeError, ValueError):
            return DEFAULT_RETRY_SECONDS

    @cached_property
    def _get_client(self):
        """Notion SDK client, created on first access and cached."""
        return AsyncClient(
            auth=self.notion_secret_key,
            base_url=BASE_URL,
        )

    @cached_property
    def session(self):
        """Generate aiohttp client session.

        Returns:
            aiohttp.ClientSession: An instance of Client Session
        """
        connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT_CLIENT_SUPPORT)

        return aiohttp.ClientSession(
            connector=connector,
            raise_for_status=True,
        )

    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def get_via_session(self, url):
        """Issue a GET request through the aiohttp session and yield the response.

        Args:
            url (str): URL to fetch.

        Yields:
            The response object, valid while the generator stays suspended.

        Raises:
            NotFound: When the server answers with HTTP 404.
            ClientResponseError: For any other error status; a 429 is re-raised
                only after sleeping for the announced (or default) delay.
        """
        self._logger.debug(f"Fetching data from url {url}")
        try:
            async with self.session.get(url=url) as response:
                yield response
        except ClientResponseError as e:
            if e.status == 429:
                retry_seconds = self._retry_after_seconds(e.headers, "Retry-After")
                self._logger.debug(
                    f"Rate Limit reached: retry in {retry_seconds} seconds"
                )
                await self._sleeps.sleep(retry_seconds)
                raise
            elif e.status == 404:
                raise NotFound from e
            else:
                raise

    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def fetch_results(
        self, function: Callable[..., Awaitable[Any]], next_cursor=None, **kwargs: Any
    ):
        """Call one page of a paginated SDK endpoint.

        Args:
            function: Awaitable SDK callable accepting ``start_cursor``.
            next_cursor: Cursor of the page to request; ``None`` for the first.
            **kwargs: Extra keyword arguments forwarded to ``function``.

        Returns:
            Whatever ``function`` returns for that page.

        Raises:
            Exception: "Rate limit exceeded." after sleeping, when the API
                reports rate limiting.
            APIResponseError: Any other API error, re-raised unchanged.
        """
        try:
            return await function(start_cursor=next_cursor, **kwargs)
        except APIResponseError as exception:
            if exception.code == "rate_limited" or exception.status == 429:
                retry_after = self._retry_after_seconds(
                    exception.headers, "retry-after"
                )
                request_info = f"Request: {function.__name__} (next_cursor: {next_cursor}, kwargs: {kwargs})"
                self._logger.info(
                    f"Connector will attempt to retry after {retry_after} seconds. {request_info}"
                )
                await self._sleeps.sleep(retry_after)
                msg = "Rate limit exceeded."
                raise Exception(msg) from exception
            else:
                raise

    async def async_iterate_paginated_api(
        self, function: Callable[..., Awaitable[Any]], **kwargs: Any
    ):
        """Return an async iterator over the results of any paginated Notion API.

        Args:
            function: Awaitable SDK callable accepting ``start_cursor``.
            **kwargs: Arguments for ``function``; an optional ``start_cursor``
                entry selects the first page.

        Yields:
            Each entry of every page's ``results`` list.
        """
        next_cursor = kwargs.pop("start_cursor", None)
        while True:
            response = await self.fetch_results(function, next_cursor, **kwargs)
            if not response:
                # An empty answer carries no new cursor; asking again with the
                # same one would never terminate.
                return
            for result in response.get("results") or []:
                yield result
            next_cursor = response.get("next_cursor")
            if not response.get("has_more") or next_cursor is None:
                return

    async def fetch_owner(self):
        """Fetch integration authorized owner"""
        await self._get_client.users.me()

    async def close(self):
        """Cancel pending sleeps, close both transports and drop the cached objects."""
        self._sleeps.cancel()
        await self._get_client.aclose()
        await self.session.close()
        # Deleting the cached properties lets a later access build fresh ones.
        del self._get_client
        del self.session

    async def fetch_users(self):
        """Iterate over user information retrieved from the API.
        Yields:
        dict: User document information excluding bots."""
        async for user_document in self.async_iterate_paginated_api(
            self._get_client.users.list
        ):
            if user_document.get("type") != "bot":
                yield user_document

    async def fetch_child_blocks(self, block_id):
        """Fetch child blocks recursively for a given block ID.

        Direct children of type ``child_database``, ``child_page`` or
        ``unsupported`` are not yielded; a ``child_database`` child is expanded
        into the records returned by querying that database. Descendants below
        the first level are yielded without that type filter.

        "external_object" validation errors and "object_not_found" errors are
        logged as warnings and end the iteration; other API errors propagate.

        Args:
            block_id (str): The ID of the parent block.
        Yields:
            dict: Child block information."""

        async def fetch_children_recursively(block):
            if block.get("has_children") is True:
                async for child_block in self.async_iterate_paginated_api(
                    self._get_client.blocks.children.list, block_id=block.get("id")
                ):
                    yield child_block

                    async for grandchild in fetch_children_recursively(child_block):  # pyright: ignore
                        yield grandchild

        try:
            async for block in self.async_iterate_paginated_api(
                self._get_client.blocks.children.list, block_id=block_id
            ):
                if block.get("type") not in [
                    "child_database",
                    "child_page",
                    "unsupported",
                ]:
                    yield block
                    if block.get("has_children") is True:
                        async for child in fetch_children_recursively(block):
                            yield child
                if block.get("type") == "child_database":
                    async for record in self.query_database(block.get("id")):
                        yield record
        except APIResponseError as error:
            # A body without "message" must not turn the membership test into
            # a TypeError, hence the empty-string fallback.
            if error.code == "validation_error" and "external_object" in (
                json.loads(error.body).get("message") or ""
            ):
                self._logger.warning(
                    f"Encountered external object with id: {block_id}. Skipping : {error}"
                )
            elif error.code == "object_not_found":
                self._logger.warning(f"Object not found: {error}")
            else:
                raise

    async def fetch_by_query(self, query):
        """Run a search and yield every hit.

        When the query filters on databases, each database hit is followed by
        the records obtained from querying that database.

        Args:
            query (dict): Keyword arguments for the SDK ``search`` call.

        Yields:
            dict: Search hits and, for database searches, database records.
        """
        async for document in self.async_iterate_paginated_api(
            self._get_client.search, **query
        ):
            yield document
            if query and query.get("filter", {}).get("value") == "database":
                async for database in self.query_database(document.get("id")):
                    yield database

    async def fetch_comments(self, block_id):
        """Yield every comment listed for the given block ID."""
        async for block_comment in self.async_iterate_paginated_api(
            self._get_client.comments.list, block_id=block_id
        ):
            yield block_comment

    async def query_database(self, database_id, body=None):
        """Yield the records of a database.

        Args:
            database_id (str): ID of the database to query.
            body (dict, optional): Extra keyword arguments for the SDK
                ``databases.query`` call. Defaults to no extra arguments.

        Yields:
            dict: One record per result.
        """
        if body is None:
            body = {}
        async for result in self.async_iterate_paginated_api(
            self._get_client.databases.query, database_id=database_id, **body
        ):
            yield result


class NotionAdvancedRulesValidator(AdvancedRulesValidator):
    """Validate Notion advanced sync rules against a schema and the live workspace."""

    # Each entry names a database and may carry a filter for querying it.
    DATABASE_QUERY_DEFINITION = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "database_id": {"type": "string", "minLength": 1},
                "filter": {
                    "type": "object",
                    "properties": {"property": {"type": "string", "minLength": 1}},
                    "additionalProperties": {
                        "type": ["object", "array"],
                    },
                },
            },
            "required": ["database_id"],
        },
    }

    # Each entry is a search string with an optional page/database filter.
    SEARCH_QUERY_DEFINITION = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "filter": {
                    "type": "object",
                    "properties": {
                        "value": {"type": "string", "enum": ["page", "database"]},
                    },
                    "required": ["value"],
                },
            },
            "required": ["query"],
        },
    }

    # Top-level rules object: at least one of the two sections, nothing else.
    RULES_OBJECT_SCHEMA_DEFINITION = {
        "type": "object",
        "properties": {
            "database_query_filters": DATABASE_QUERY_DEFINITION,
            "searches": SEARCH_QUERY_DEFINITION,
        },
        "minProperties": 1,
        "additionalProperties": False,
    }

    SCHEMA = fastjsonschema.compile(definition=RULES_OBJECT_SCHEMA_DEFINITION)

    def __init__(self, source):
        """Remember the data source whose Notion client is used for lookups.

        Args:
            source: Data source exposing ``notion_client`` and ``get_entities``.
        """
        self.source = source
        self._logger = logger

    async def validate(self, advanced_rules):
        """Validate the advanced rules.

        Args:
            advanced_rules (dict): Advanced rules object; empty rules are valid.

        Returns:
            SyncRuleValidationResult: Outcome of the validation.
        """
        if not advanced_rules:
            return SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            )

        self._logger.info("Remote validation started")
        return await self._remote_validation(advanced_rules)

    async def _known_database_ids(self):
        """Return the dash-less IDs of all databases found by a database search."""
        query = {"filter": {"property": "object", "value": "database"}}
        client = self.source.notion_client
        return {
            document.get("id").replace("-", "")
            async for document in client.async_iterate_paginated_api(
                client._get_client.search, **query
            )
        }

    async def _unknown_database_ids(self, database_query_filters):
        """Return the configured database IDs that the database search did not find."""
        known_ids = await self._known_database_ids()
        return [
            database.get("database_id")
            for database in database_query_filters
            # IDs are compared without dashes so both notations match.
            if database.get("database_id").replace("-", "") not in known_ids
        ]

    async def _validate_searches(self, searches):
        """Check that every page/database title named in ``searches`` exists.

        Raises:
            ConfigurableFieldValueError: Propagated from ``get_entities`` when a
                title cannot be found.
        """
        page_titles = []
        database_titles = []
        for search in searches:
            target = search.get("filter", {}).get("value")
            if target == "page":
                page_titles.append(search.get("query"))
            elif target == "database":
                database_titles.append(search.get("query"))
        if page_titles:
            await self.source.get_entities("page", page_titles)
        if database_titles:
            await self.source.get_entities("database", database_titles)

    async def _remote_validation(self, advanced_rules):
        """Validate the rules against the schema, then against the workspace.

        Sections are checked in the order their keys appear in
        ``advanced_rules``; the first failing section decides the result.

        Args:
            advanced_rules (dict): Non-empty advanced rules object.

        Returns:
            SyncRuleValidationResult: Valid result, or an invalid one carrying
            the schema message, the unknown database IDs, or the title error.
        """
        try:
            NotionAdvancedRulesValidator.SCHEMA(advanced_rules)
        except fastjsonschema.JsonSchemaValueException as e:
            return SyncRuleValidationResult(
                rule_id=SyncRuleValidationResult.ADVANCED_RULES,
                is_valid=False,
                validation_message=e.message,
            )

        # Match section names exactly; the workspace's database listing is
        # fetched only when a database section is actually present.
        for rule_name in advanced_rules:
            if rule_name == "database_query_filters":
                self._logger.info("Validating database query filters")
                invalid_database = await self._unknown_database_ids(
                    advanced_rules.get(rule_name)
                )
                if invalid_database:
                    return SyncRuleValidationResult(
                        SyncRuleValidationResult.ADVANCED_RULES,
                        is_valid=False,
                        validation_message=f"Invalid database id: {', '.join(invalid_database)}",
                    )
            elif rule_name == "searches":
                self._logger.info("Validating search filters")
                try:
                    await self._validate_searches(advanced_rules.get(rule_name))
                except ConfigurableFieldValueError as error:
                    return SyncRuleValidationResult(
                        SyncRuleValidationResult.ADVANCED_RULES,
                        is_valid=False,
                        validation_message=str(error),
                    )
        self._logger.info("Remote validation successful")
        return SyncRuleValidationResult.valid_result(
            SyncRuleValidationResult.ADVANCED_RULES
        )


class NotionDataSource(BaseDataSource):
    """Notion"""

    name = "Notion"
    service_type = "notion"
    advanced_rules_enabled = True
    incremental_sync_enabled = True

    def __init__(self, configuration):
        """Setup the connection to the Notion instance.

        Args:
            configuration (DataSourceConfiguration): Instance of DataSourceConfiguration class.
        """
        super().__init__(configuration=configuration)
        self.notion_client = NotionClient(configuration=configuration)

        self.configuration = configuration
        self.index_comments = self.configuration["index_comments"]
        self.pages = self.configuration["pages"]
        self.databases = self.configuration["databases"]
        self._logger = logger
        self._sleeps = CancellableSleeps()
        self.concurrent_downloads = self.configuration["concurrent_downloads"]

    def _set_internal_logger(self):
        """Hand this source's logger to the Notion client."""
        self.notion_client.set_logger(self._logger)

    async def ping(self):
        """Check connectivity by fetching the integration owner.

        Raises:
            Exception: Whatever the lookup raised, after logging it.
        """
        try:
            await self.notion_client.fetch_owner()
            self._logger.info("Successfully connected to Notion.")
        except Exception:
            self._logger.exception("Error while connecting to Notion.")
            raise

    async def close(self):
        """Close the Notion client."""
        await self.notion_client.close()

    @staticmethod
    def _entity_title(entity_type, result):
        """Return the lower-cased title of a search result.

        Args:
            entity_type (str): ``"database"`` or ``"page"``.
            result (dict): Search result to read the title from.

        Returns:
            str: Lower-cased title; an empty string when the result carries no
            title entries (for example an untitled page).
        """
        if entity_type == "database":
            # `or [{}]` also covers a present-but-empty title list.
            title_parts = result.get("title") or [{}]
            return title_parts[0].get("plain_text", "").lower()
        title_parts = result.get("properties", {}).get("title", {}).get("title") or [
            {}
        ]
        return title_parts[0].get("text", {}).get("content", "").lower()

    async def get_entities(self, entity_type, entity_titles):
        """Search for a database or page with the given title.

        Args:
            entity_type (str): ``"page"`` or ``"database"``.
            entity_titles (list): Titles to look up; ``["*"]`` skips the lookup.

        Returns:
            list: Search results whose title equals one of ``entity_titles``,
            compared case-insensitively. Empty for ``["*"]``.

        Raises:
            ConfigurableFieldValueError: When at least one title was not found.
            Exception: Any error raised while searching, after logging it.
        """
        invalid_titles = []
        exact_match_results = []
        if entity_titles != ["*"]:
            try:
                data = {
                    "query": " ".join(entity_titles),
                    "filter": {"value": entity_type, "property": "object"},
                }
                search_results = [
                    response
                    async for response in self.notion_client.fetch_by_query(data)
                ]

                found_titles = set()
                # For any other entity type no title can be read, so every
                # requested title ends up reported as invalid.
                if entity_type in ("database", "page"):
                    wanted_titles = {title.lower() for title in entity_titles}
                    for result in search_results:
                        title = self._entity_title(entity_type, result)
                        found_titles.add(title)
                        if title in wanted_titles:
                            exact_match_results.append(result)
                invalid_titles = [
                    title
                    for title in entity_titles
                    if title.lower() not in found_titles
                ]
            except Exception as e:
                self._logger.exception(f"Error searching for {entity_type}: {e}")
                raise
            if invalid_titles:
                msg = f"Invalid {entity_type} titles found: {', '.join(invalid_titles)}"
                raise ConfigurableFieldValueError(msg)
        return exact_match_results

    async def validate_config(self):
        """Validates if user configured databases and pages are available in notion."""
        await super().validate_config()
        await asyncio.gather(
            self.get_entities("page", self.configuration.get("pages", [])),
            self.get_entities("database", self.configuration.get("databases", [])),
        )

    @classmethod
    def get_default_configuration(cls):
        """Get the default configuration for Notion.
        Returns:
            dict: Default configuration.
        """
        return {
            "notion_secret_key": {
                "display": "text",
                "label": "Notion Secret Key",
                "order": 1,
                "required": True,
                "sensitive": True,
                "type": "str",
            },
            "databases": {
                "label": "List of Databases",
                "display": "text",
                "order": 2,
                "required": True,
                "type": "list",
            },
            "pages": {
                "label": "List of Pages",
                "display": "text",
                "order": 3,
                "required": True,
                "type": "list",
            },
            "index_comments": {
                "display": "toggle",
                "label": "Enable indexing comments",
                "order": 4,
                "tooltip": "Enabling this will increase the amount of network calls to the source, and may decrease performance",
                "type": "bool",
                "value": False,
            },
            "concurrent_downloads": {
                "default_value": 30,
                "display": "numeric",
                "label": "Maximum concurrent downloads",
                "order": 5,
                "required": False,
                "type": "int",
                "ui_restrictions": ["advanced"],
            },
        }

    def advanced_rules_validators(self):
        """Return the validators applied to advanced sync rules."""
        return [NotionAdvancedRulesValidator(self)]

    def tweak_bulk_options(self, options):
        """Tweak bulk options as per concurrent downloads support by Notion

        Args:
            options (dict): Config bulker options.
        """

        options["concurrent_downloads"] = self.concurrent_downloads

    async def get_file_metadata(self, attachment_metadata, file_url):
        """Fill in name, extension and size of an attachment from its URL response.

        Args:
            attachment_metadata (dict): Attachment document; updated in place.
            file_url (str): URL the attachment is served from.

        Returns:
            dict: ``attachment_metadata`` with ``name``, ``extension`` and
            ``size`` set.
        """
        response_stream = self.notion_client.get_via_session(url=file_url)
        try:
            response = await anext(response_stream)
            file_name = unquote(response.url.path.split("/")[-1])
            # Take the extension from the file name only, so a dot elsewhere in
            # the path (or no dot at all) cannot produce a bogus extension.
            attachment_metadata["extension"] = os.path.splitext(file_name)[1]
            attachment_metadata["size"] = response.content_length
            attachment_metadata["name"] = file_name
        finally:
            # Close the generator explicitly instead of leaving its cleanup to
            # garbage collection.
            aclose = getattr(response_stream, "aclose", None)
            if aclose is not None:
                await aclose()
        return attachment_metadata

    async def get_content(self, attachment, file_url, timestamp=None, doit=False):
        """Extracts the content for Apache TIKA supported file types.

        Args:
            attachment (dictionary): Formatted attachment document.
            file_url (str): URL of the attachment; an empty value skips it.
            timestamp (timestamp, optional): Timestamp of attachment last modified. Defaults to None.
                Not read by this method.
            doit (boolean, optional): Boolean value for whether to get content or not. Defaults to False.
                Not read by this method.

        Returns:
            dictionary: Content document with _id, _timestamp and attachment content
        """
        if not file_url:
            self._logger.debug(
                f"skipping attachment with id {attachment['id']} as url is empty"
            )
            return
        attachment = await self.get_file_metadata(attachment, file_url)
        attachment_size = int(attachment["size"])

        attachment_name = attachment["name"]
        attachment_extension = attachment["extension"]

        if not self.can_file_be_downloaded(
            attachment_extension, attachment_name, attachment_size
        ):
            return
        document = {
            "_id": f"{attachment['_id']}",
            "_timestamp": attachment["_timestamp"],
        }

        return await self.download_and_extract_file(
            document,
            attachment_name,
            attachment_extension,
            partial(
                self.generic_chunked_download_func,
                partial(self.notion_client.get_via_session, url=file_url),
            ),
        )

    def _format_doc(self, data):
        """Format document for handling empty values & type casting.

        Falsy values are dropped, ``_id`` mirrors ``id``, ``_timestamp`` mirrors
        ``last_edited_time`` when present, and ``properties`` is replaced by its
        string form under ``details``.

        Args:
            data (dict): Fetched record from Notion.

        Returns:
            dict: Formatted document.
        """
        data = {key: value for key, value in data.items() if value}
        data["_id"] = data["id"]
        if "last_edited_time" in data:
            data["_timestamp"] = data["last_edited_time"]

        if "properties" in data:
            data["details"] = str(data["properties"])
            del data["properties"]
        return data

    def generate_query(self):
        """Yield the search queries derived from the configured pages and databases.

        Yields:
            dict: An empty query when both lists are ``["*"]``; otherwise one
            query per configured page, then one per configured database, with
            ``"*"`` turned into an empty search string.
        """
        if self.pages == ["*"] and self.databases == ["*"]:
            yield {}
        else:
            for page in self.pages:
                yield {
                    "query": "" if page == "*" else page,
                    "filter": {"value": "page", "property": "object"},
                }
            for database in self.databases:
                yield {
                    "query": "" if database == "*" else database,
                    "filter": {"value": "database", "property": "object"},
                }

    def is_connected_property_block(self, page_database):
        """Tell whether any property name looks like ``Related to ... (...)``.

        Args:
            page_database (dict): Page or database document.

        Returns:
            bool: True when a matching property name exists, False otherwise
            (including when there is no ``properties`` entry).
        """
        properties = page_database.get("properties")
        if properties is None:
            return False
        return any(re.match(r"^Related to.*\(.*\)$", field) for field in properties)

    async def retrieve_and_process_blocks(self, query):
        """Yield the documents for one search query.

        For every hit the formatted hit itself is yielded, followed by its child
        blocks unless the hit is a connected-property block. File blocks come
        with a lazy download callable. When comment indexing is enabled, the
        comments of all collected blocks are yielded last.

        Args:
            query (dict): Search query passed to the Notion client.

        Yields:
            tuple: ``(document, download_callable_or_None)``.
        """
        block_ids_store = []
        async for page_database in self.notion_client.fetch_by_query(query=query):
            block_id = page_database.get("id")
            if self.index_comments is True:
                block_ids_store.append(block_id)

            yield self._format_doc(page_database), None
            self._logger.info(f"Fetching child blocks for block {block_id}")

            if self.is_connected_property_block(page_database):
                self._logger.debug(
                    f"Skipping children of block with id: {block_id} as not supported by API"
                )
                continue

            async for child_block in self.notion_client.fetch_child_blocks(
                block_id=block_id
            ):
                if self.index_comments is True:
                    block_ids_store.append(child_block.get("id"))

                if child_block.get("type") != "file":
                    yield self._format_doc(child_block), None
                else:
                    file_url = child_block.get("file", {}).get("file", {}).get("url")
                    child_block = self._format_doc(child_block)
                    # The download callable gets its own copy of the document.
                    yield (
                        child_block,
                        partial(self.get_content, copy(child_block), file_url),
                    )

        if self.index_comments is True:
            for block_id in block_ids_store:
                self._logger.info(f"Fetching comments for block {block_id}")
                async for comment in self.notion_client.fetch_comments(
                    block_id=block_id
                ):
                    yield self._format_doc(comment), None

    async def _docs_from_searches(self, searches):
        """Yield the documents for every entry of the ``searches`` rule section."""
        for search in searches:
            # Build the query on a copy so the caller's rules stay untouched.
            search_query = dict(search)
            if "filter" in search:
                search_query["filter"] = {**search["filter"], "property": "object"}
            self._logger.info(
                f"Fetching databases and pages using search query: {search_query}"
            )
            async for data in self.retrieve_and_process_blocks(search_query):
                yield data

    async def _docs_from_database_filters(self, database_query_filters):
        """Yield the records for every entry of the ``database_query_filters`` section.

        Raises:
            ConfigurableFieldValueError: When the API rejects a database query.
        """
        for database_query_filter in database_query_filters:
            filter_criteria = (
                {"filter": database_query_filter.get("filter")}
                if "filter" in database_query_filter
                else {}
            )
            try:
                database_id = database_query_filter.get("database_id")
                self._logger.info(
                    f"Fetching records for database with id: {database_id}"
                )
                async for database in self.notion_client.query_database(
                    database_id, filter_criteria
                ):
                    yield self._format_doc(database), None
            except APIResponseError as e:
                msg = f"Please make sure to include correct filter field, {e}"
                raise ConfigurableFieldValueError(msg) from e

    async def get_docs(self, filtering=None):
        """Executes the logic to fetch following Notion objects: Users, Pages, Databases, Files,
           Comments, Blocks, Child Blocks in async manner.

        With advanced rules, only the rule sections are processed, in the order
        their keys appear; otherwise users are fetched first, followed by the
        results of the configured page/database queries.

        Args:
            filtering (filtering, None): Filtering Rules. Defaults to None.
        Yields:
            dict: Documents from Notion.
        """
        if filtering and filtering.has_advanced_rules():
            advanced_rules = filtering.get_advanced_rules()
            # Section names are matched exactly rather than by substring.
            for rule_name in advanced_rules:
                if rule_name == "searches":
                    async for data in self._docs_from_searches(
                        advanced_rules.get(rule_name)
                    ):
                        yield data
                elif rule_name == "database_query_filters":
                    async for data in self._docs_from_database_filters(
                        advanced_rules.get(rule_name)
                    ):
                        yield data

        else:
            self._logger.info("Fetching users")
            async for user_document in self.notion_client.fetch_users():
                yield self._format_doc(user_document), None

            for query in self.generate_query():
                self._logger.info(f"Fetching pages and databases using query {query}")
                async for data in self.retrieve_and_process_blocks(query):
                    yield data
