# connectors/sources/notion.py

#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
"""Notion source module responsible to fetch documents from the Notion Platform."""
import asyncio
import json
import os
import re
from copy import copy
from functools import cached_property, partial
from typing import Any, Awaitable, Callable
from urllib.parse import unquote

import aiohttp
import fastjsonschema
from aiohttp.client_exceptions import ClientResponseError
from notion_client import APIResponseError, AsyncClient

from connectors.filtering.validation import (
    AdvancedRulesValidator,
    SyncRuleValidationResult,
)
from connectors.logger import logger
from connectors.source import BaseDataSource, ConfigurableFieldValueError
from connectors.utils import CancellableSleeps, RetryStrategy, retryable

RETRIES = 3
RETRY_INTERVAL = 2
DEFAULT_RETRY_SECONDS = 30
BASE_URL = "https://api.notion.com"
MAX_CONCURRENT_CLIENT_SUPPORT = 30

# Allows tests / alternate deployments to point the connector at a mock server.
if "OVERRIDE_URL" in os.environ:
    BASE_URL = os.environ["OVERRIDE_URL"]


class NotFound(Exception):
    """Raised when the Notion API answers 404 for a requested resource."""

    pass


class NotionClient:
    """Notion API client"""

    def __init__(self, configuration):
        self._sleeps = CancellableSleeps()
        self.configuration = configuration
        self._logger = logger
        self.notion_secret_key = self.configuration["notion_secret_key"]

    def set_logger(self, logger_):
        self._logger = logger_

    @cached_property
    def _get_client(self):
        # Official Notion SDK client, authenticated with the integration secret.
        return AsyncClient(
            auth=self.notion_secret_key,
            base_url=BASE_URL,
        )

    @cached_property
    def session(self):
        """Generate aiohttp client session.

        Returns:
            aiohttp.ClientSession: An instance of Client Session
        """
        connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT_CLIENT_SUPPORT)
        return aiohttp.ClientSession(
            connector=connector,
            raise_for_status=True,
        )

    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def get_via_session(self, url):
        """Stream a raw HTTP GET response (used for attachment downloads).

        Args:
            url (str): URL to fetch.

        Yields:
            aiohttp.ClientResponse: The open response object.

        Raises:
            NotFound: When the server answers HTTP 404 (not retried).
        """
        self._logger.debug(f"Fetching data from url {url}")
        try:
            async with self.session.get(url=url) as response:
                yield response
        except ClientResponseError as e:
            if e.status == 429:
                # Honor the server-provided backoff, then re-raise so
                # @retryable schedules another attempt.
                retry_seconds = e.headers.get("Retry-After") or DEFAULT_RETRY_SECONDS
                self._logger.debug(
                    f"Rate Limit reached: retry in {retry_seconds} seconds"
                )
                await self._sleeps.sleep(retry_seconds)
                raise
            elif e.status == 404:
                raise NotFound from e
            else:
                raise

    @retryable(
        retries=RETRIES,
        interval=RETRY_INTERVAL,
        strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
        skipped_exceptions=NotFound,
    )
    async def fetch_results(
        self, function: Callable[..., Awaitable[Any]], next_cursor=None, **kwargs: Any
    ):
        """Fetch one page of a paginated Notion SDK endpoint, honoring rate limits.

        Args:
            function: Paginated SDK coroutine (e.g. ``client.search``).
            next_cursor: Opaque pagination cursor from the previous page, or None.

        Returns:
            dict: The raw API response for this page.
        """
        try:
            return await function(start_cursor=next_cursor, **kwargs)
        except APIResponseError as exception:
            if exception.code == "rate_limited" or exception.status == 429:
                retry_after = (
                    exception.headers.get("retry-after") or DEFAULT_RETRY_SECONDS
                )
                request_info = f"Request: {function.__name__} (next_cursor: {next_cursor}, kwargs: {kwargs})"
                self._logger.info(
                    f"Connector will attempt to retry after {int(retry_after)} seconds. {request_info}"
                )
                await self._sleeps.sleep(int(retry_after))
                # Raise a generic exception so @retryable re-attempts the call.
                msg = "Rate limit exceeded."
                raise Exception(msg) from exception
            else:
                raise

    async def async_iterate_paginated_api(
        self, function: Callable[..., Awaitable[Any]], **kwargs: Any
    ):
        """Return an async iterator over the results of any paginated Notion API."""
        next_cursor = kwargs.pop("start_cursor", None)
        while True:
            response = await self.fetch_results(function, next_cursor, **kwargs)
            if not response:
                # Defensive guard: a falsy response would otherwise spin forever.
                return
            for result in response.get("results"):
                yield result
            next_cursor = response.get("next_cursor")
            if not response["has_more"] or next_cursor is None:
                return

    async def fetch_owner(self):
        """Fetch integration authorized owner"""
        await self._get_client.users.me()

    async def close(self):
        """Cancel pending sleeps and release the SDK client and HTTP session."""
        self._sleeps.cancel()
        await self._get_client.aclose()
        await self.session.close()
        # Invalidate the cached properties so any later access rebuilds them.
        del self._get_client
        del self.session

    async def fetch_users(self):
        """Iterate over user information retrieved from the API.

        Yields:
            dict: User document information excluding bots.
        """
        async for user_document in self.async_iterate_paginated_api(
            self._get_client.users.list
        ):
            if user_document.get("type") != "bot":
                yield user_document

    async def fetch_child_blocks(self, block_id):
        """Fetch child blocks recursively for a given block ID.

        Args:
            block_id (str): The ID of the parent block.

        Yields:
            dict: Child block information.
        """

        async def fetch_children_recursively(block):
            # Depth-first walk of a block's descendants.
            if block.get("has_children") is True:
                async for child_block in self.async_iterate_paginated_api(
                    self._get_client.blocks.children.list, block_id=block.get("id")
                ):
                    yield child_block

                    async for grandchild in fetch_children_recursively(child_block):  # pyright: ignore
                        yield grandchild

        try:
            async for block in self.async_iterate_paginated_api(
                self._get_client.blocks.children.list, block_id=block_id
            ):
                if block.get("type") not in [
                    "child_database",
                    "child_page",
                    "unsupported",
                ]:
                    yield block
                    if block.get("has_children") is True:
                        async for child in fetch_children_recursively(block):
                            yield child
                if block.get("type") == "child_database":
                    async for record in self.query_database(block.get("id")):
                        yield record
        except APIResponseError as error:
            # default "" keeps the membership test from raising TypeError
            # when the error payload carries no "message" field
            if error.code == "validation_error" and "external_object" in json.loads(
                error.body
            ).get("message", ""):
                self._logger.warning(
                    f"Encountered external object with id: {block_id}. Skipping : {error}"
                )
            elif error.code == "object_not_found":
                self._logger.warning(f"Object not found: {error}")
            else:
                raise

    async def fetch_by_query(self, query):
        """Run a search query; database hits are expanded into their records."""
        async for document in self.async_iterate_paginated_api(
            self._get_client.search, **query
        ):
            yield document
            if query and query.get("filter", {}).get("value") == "database":
                async for database in self.query_database(document.get("id")):
                    yield database

    async def fetch_comments(self, block_id):
        """Iterate over the comments attached to the given block."""
        async for block_comment in self.async_iterate_paginated_api(
            self._get_client.comments.list, block_id=block_id
        ):
            yield block_comment

    async def query_database(self, database_id, body=None):
        """Iterate over the records of a database, optionally filtered by *body*."""
        if body is None:
            body = {}
        async for result in self.async_iterate_paginated_api(
            self._get_client.databases.query, database_id=database_id, **body
        ):
            yield result


class NotionAdvancedRulesValidator(AdvancedRulesValidator):
    """Validates advanced sync rules for the Notion connector."""

    DATABASE_QUERY_DEFINITION = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "database_id": {"type": "string", "minLength": 1},
                "filter": {
                    "type": "object",
                    "properties": {"property": {"type": "string", "minLength": 1}},
                    "additionalProperties": {
                        "type": ["object", "array"],
                    },
                },
            },
            "required": ["database_id"],
        },
    }

    SEARCH_QUERY_DEFINITION = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "filter": {
                    "type": "object",
                    "properties": {
                        "value": {"type": "string", "enum": ["page", "database"]},
                    },
                    "required": ["value"],
                },
            },
            "required": ["query"],
        },
    }

    RULES_OBJECT_SCHEMA_DEFINITION = {
        "type": "object",
        "properties": {
            "database_query_filters": DATABASE_QUERY_DEFINITION,
            "searches": SEARCH_QUERY_DEFINITION,
        },
        "minProperties": 1,
        "additionalProperties": False,
    }

    SCHEMA = fastjsonschema.compile(definition=RULES_OBJECT_SCHEMA_DEFINITION)

    def __init__(self, source):
        self.source = source
        self._logger = logger

    async def validate(self, advanced_rules):
        """Validate advanced rules; empty rules are trivially valid."""
        if len(advanced_rules) == 0:
            return SyncRuleValidationResult.valid_result(
                SyncRuleValidationResult.ADVANCED_RULES
            )
        self._logger.info("Remote validation started")
        return await self._remote_validation(advanced_rules)

    async def _remote_validation(self, advanced_rules):
        """Check rules against the schema, then against live Notion data."""
        try:
            NotionAdvancedRulesValidator.SCHEMA(advanced_rules)
        except fastjsonschema.JsonSchemaValueException as e:
            return SyncRuleValidationResult(
                rule_id=SyncRuleValidationResult.ADVANCED_RULES,
                is_valid=False,
                validation_message=e.message,
            )
        invalid_database = []
        databases = []
        page_title = []
        database_title = []
        # Collect every accessible database id (dashes stripped for comparison).
        query = {"filter": {"property": "object", "value": "database"}}
        async for document in self.source.notion_client.async_iterate_paginated_api(
            self.source.notion_client._get_client.search, **query
        ):
            databases.append(document.get("id").replace("-", ""))
        for rule in advanced_rules:
            if "database_query_filters" in rule:
                self._logger.info("Validating database query filters")
                for database in advanced_rules.get("database_query_filters"):
                    database_id = (
                        database.get("database_id").replace("-", "")
                        if "-" in database.get("database_id")
                        else database.get("database_id")
                    )
                    if database_id not in databases:
                        invalid_database.append(database.get("database_id"))
                if invalid_database:
                    return SyncRuleValidationResult(
                        SyncRuleValidationResult.ADVANCED_RULES,
                        is_valid=False,
                        validation_message=f"Invalid database id: {', '.join(invalid_database)}",
                    )
            if "searches" in rule:
                self._logger.info("Validating search filters")
                for database_page in advanced_rules.get("searches"):
                    if database_page.get("filter", {}).get("value") == "page":
                        page_title.append(database_page.get("query"))
                    elif database_page.get("filter", {}).get("value") == "database":
                        database_title.append(database_page.get("query"))
                try:
                    # get_entities raises when any title cannot be found remotely.
                    if page_title:
                        await self.source.get_entities("page", page_title)
                    if database_title:
                        await self.source.get_entities("database", database_title)
                except ConfigurableFieldValueError as error:
                    return SyncRuleValidationResult(
                        SyncRuleValidationResult.ADVANCED_RULES,
                        is_valid=False,
                        validation_message=str(error),
                    )
        self._logger.info("Remote validation successful")
        return SyncRuleValidationResult.valid_result(
            SyncRuleValidationResult.ADVANCED_RULES
        )


class NotionDataSource(BaseDataSource):
    """Notion"""

    name = "Notion"
    service_type = "notion"
    advanced_rules_enabled = True
    incremental_sync_enabled = True

    def __init__(self, configuration):
        """Setup the connection to the Notion instance.

        Args:
            configuration (DataSourceConfiguration): Instance of DataSourceConfiguration class.
        """
        super().__init__(configuration=configuration)
        self.notion_client = NotionClient(configuration=configuration)
        self.configuration = configuration
        self.index_comments = self.configuration["index_comments"]
        self.pages = self.configuration["pages"]
        self.databases = self.configuration["databases"]
        self._logger = logger
        self._sleeps = CancellableSleeps()
        self.concurrent_downloads = self.configuration["concurrent_downloads"]

    def _set_internal_logger(self):
        self.notion_client.set_logger(self._logger)

    async def ping(self):
        """Verify connectivity by fetching the integration owner."""
        try:
            await self.notion_client.fetch_owner()
            self._logger.info("Successfully connected to Notion.")
        except Exception:
            self._logger.exception("Error while connecting to Notion.")
            raise

    async def close(self):
        await self.notion_client.close()

    async def get_entities(self, entity_type, entity_titles):
        """Search for a database or page with the given title."""
        invalid_titles = []
        found_titles = set()
        exact_match_results = []
        # "*" means "everything" — nothing to validate, return no exact matches.
        if entity_titles != ["*"]:
            try:
                data = {
                    "query": " ".join(entity_titles),
                    "filter": {"value": entity_type, "property": "object"},
                }
                search_results = []
                async for response in self.notion_client.fetch_by_query(data):
                    search_results.append(response)
                # Title extraction differs between page and database payloads.
                get_title = {
                    "database": lambda result: result.get("title", [{}])[0]
                    .get("plain_text", "")
                    .lower(),
                    "page": lambda result: result.get("properties", {})
                    .get("title", {})
                    .get("title", [{}])[0]
                    .get("text", {})
                    .get("content", "")
                    .lower(),
                }.get(entity_type)
                if get_title is not None:
                    found_titles = {get_title(result) for result in search_results}
                    exact_match_results = [
                        result
                        for result in search_results
                        if get_title(result).lower() in map(str.lower, entity_titles)
                    ]
                invalid_titles = [
                    title
                    for title in entity_titles
                    if title.lower() not in found_titles
                ]
            except Exception as e:
                self._logger.exception(f"Error searching for {entity_type}: {e}")
                raise
            if invalid_titles:
                msg = f"Invalid {entity_type} titles found: {', '.join(invalid_titles)}"
                raise ConfigurableFieldValueError(msg)
        return exact_match_results

    async def validate_config(self):
        """Validates if user configured databases and pages are available in notion."""
        await super().validate_config()
        await asyncio.gather(
            self.get_entities("page", self.configuration.get("pages", [])),
            self.get_entities("database", self.configuration.get("databases", [])),
        )

    @classmethod
    def get_default_configuration(cls):
        """Get the default configuration for Notion.

        Returns:
            dict: Default configuration.
        """
        return {
            "notion_secret_key": {
                "display": "text",
                "label": "Notion Secret Key",
                "order": 1,
                "required": True,
                "sensitive": True,
                "type": "str",
            },
            "databases": {
                "label": "List of Databases",
                "display": "text",
                "order": 2,
                "required": True,
                "type": "list",
            },
            "pages": {
                "label": "List of Pages",
                "display": "text",
                "order": 3,
                "required": True,
                "type": "list",
            },
            "index_comments": {
                "display": "toggle",
                "label": "Enable indexing comments",
                "order": 4,
                "tooltip": "Enabling this will increase the amount of network calls to the source, and may decrease performance",
                "type": "bool",
                "value": False,
            },
            "concurrent_downloads": {
                "default_value": 30,
                "display": "numeric",
                "label": "Maximum concurrent downloads",
                "order": 5,
                "required": False,
                "type": "int",
                "ui_restrictions": ["advanced"],
            },
        }

    def advanced_rules_validators(self):
        return [NotionAdvancedRulesValidator(self)]

    def tweak_bulk_options(self, options):
        """Tweak bulk options as per concurrent downloads support by Notion

        Args:
            options (dict): Config bulker options.
        """
        options["concurrent_downloads"] = self.concurrent_downloads

    async def get_file_metadata(self, attachment_metadata, file_url):
        """Fill extension, size and name of an attachment from its URL response."""
        response = await anext(self.notion_client.get_via_session(url=file_url))
        attachment_metadata["extension"] = "." + response.url.path.split(".")[-1]
        attachment_metadata["size"] = response.content_length
        attachment_metadata["name"] = unquote(response.url.path.split("/")[-1])
        return attachment_metadata

    async def get_content(self, attachment, file_url, timestamp=None, doit=False):
        """Extracts the content for Apache TIKA supported file types.

        Args:
            attachment (dictionary): Formatted attachment document.
            timestamp (timestamp, optional): Timestamp of attachment last modified. Defaults to None.
            doit (boolean, optional): Boolean value for whether to get content or not. Defaults to False.

        Returns:
            dictionary: Content document with _id, _timestamp and attachment content
        """
        if not file_url:
            self._logger.debug(
                f"skipping attachment with id {attachment['id']} as url is empty"
            )
            return
        attachment = await self.get_file_metadata(attachment, file_url)
        attachment_size = int(attachment["size"])
        attachment_name = attachment["name"]
        attachment_extension = attachment["extension"]
        if not self.can_file_be_downloaded(
            attachment_extension, attachment_name, attachment_size
        ):
            return
        document = {
            "_id": f"{attachment['_id']}",
            "_timestamp": attachment["_timestamp"],
        }
        return await self.download_and_extract_file(
            document,
            attachment_name,
            attachment_extension,
            partial(
                self.generic_chunked_download_func,
                partial(self.notion_client.get_via_session, url=file_url),
            ),
        )

    def _format_doc(self, data):
        """Format document for handling empty values & type casting.

        Args:
            data (dict): Fetched record from Notion.

        Returns:
            dict: Formatted document.
        """
        data = {key: value for key, value in data.items() if value}
        data["_id"] = data["id"]
        if "last_edited_time" in data:
            data["_timestamp"] = data["last_edited_time"]
        if "properties" in data:
            data["details"] = str(data["properties"])
            del data["properties"]
        return data

    def generate_query(self):
        """Yield one search query per configured page/database title."""
        if self.pages == ["*"] and self.databases == ["*"]:
            # Wildcard everywhere: a single unfiltered search fetches everything.
            yield {}
        else:
            for page in self.pages:
                yield {
                    "query": "" if page == "*" else page,
                    "filter": {"value": "page", "property": "object"},
                }
            for database in self.databases:
                yield {
                    "query": "" if database == "*" else database,
                    "filter": {"value": "database", "property": "object"},
                }

    def is_connected_property_block(self, page_database):
        """Return True if the record holds a relation ("Related to ... (...)") property."""
        properties = page_database.get("properties")
        if properties is None:
            return False
        for field in properties.keys():
            if re.match(r"^Related to.*\(.*\)$", field):
                return True
        return False

    async def retrieve_and_process_blocks(self, query):
        """Yield (doc, download_coroutine) tuples for every block matching *query*."""
        block_ids_store = []
        async for page_database in self.notion_client.fetch_by_query(query=query):
            block_id = page_database.get("id")
            if self.index_comments is True:
                block_ids_store.append(block_id)
            yield self._format_doc(page_database), None
            self._logger.info(f"Fetching child blocks for block {block_id}")
            if self.is_connected_property_block(page_database):
                self._logger.debug(
                    f"Skipping children of block with id: {block_id} as not supported by API"
                )
                continue
            async for child_block in self.notion_client.fetch_child_blocks(
                block_id=block_id
            ):
                if self.index_comments is True:
                    block_ids_store.append(child_block.get("id"))
                if child_block.get("type") != "file":
                    yield self._format_doc(child_block), None
                else:
                    file_url = child_block.get("file", {}).get("file", {}).get("url")
                    child_block = self._format_doc(child_block)
                    yield (
                        child_block,
                        partial(self.get_content, copy(child_block), file_url),
                    )
        if self.index_comments is True:
            for block_id in block_ids_store:
                self._logger.info(f"Fetching comments for block {block_id}")
                async for comment in self.notion_client.fetch_comments(
                    block_id=block_id
                ):
                    yield self._format_doc(comment), None

    async def get_docs(self, filtering=None):
        """Executes the logic to fetch following Notion objects: Users, Pages,
        Databases, Files, Comments, Blocks, Child Blocks in async manner.

        Args:
            filtering (filtering, None): Filtering Rules. Defaults to None.

        Yields:
            dict: Documents from Notion.
        """
        if filtering and filtering.has_advanced_rules():
            advanced_rules = filtering.get_advanced_rules()
            for rule in advanced_rules:
                if "searches" in rule:
                    for database_page in advanced_rules.get("searches"):
                        # Notion search API requires the "property" key as well.
                        if "filter" in database_page:
                            database_page["filter"]["property"] = "object"
                        self._logger.info(
                            f"Fetching databases and pages using search query: {database_page}"
                        )
                        async for data in self.retrieve_and_process_blocks(
                            database_page
                        ):
                            yield data
                if "database_query_filters" in rule:
                    for database_query_filter in advanced_rules.get(
                        "database_query_filters"
                    ):
                        filter_criteria = (
                            {"filter": database_query_filter.get("filter")}
                            if "filter" in database_query_filter
                            else {}
                        )
                        try:
                            database_id = database_query_filter.get("database_id")
                            self._logger.info(
                                f"Fetching records for database with id: {database_id}"
                            )
                            async for database in self.notion_client.query_database(
                                database_id, filter_criteria
                            ):
                                yield self._format_doc(database), None
                        except APIResponseError as e:
                            msg = (
                                f"Please make sure to include correct filter field, {e}"
                            )
                            raise ConfigurableFieldValueError(msg) from e
        else:
            self._logger.info("Fetching users")
            async for user_document in self.notion_client.fetch_users():
                yield self._format_doc(user_document), None
            for query in self.generate_query():
                self._logger.info(f"Fetching pages and databases using query {query}")
                async for data in self.retrieve_and_process_blocks(query):
                    yield data