connectors/sources/oracle.py (411 lines of code) (raw):

# # Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. # """Oracle source module is responsible to fetch documents from Oracle.""" import asyncio import os from functools import cached_property, partial from urllib.parse import quote from asyncpg.exceptions._base import InternalClientError from sqlalchemy import create_engine, text from sqlalchemy.exc import ProgrammingError from connectors.source import BaseDataSource from connectors.sources.generic_database import ( DEFAULT_FETCH_SIZE, DEFAULT_RETRY_COUNT, Queries, configured_tables, fetch, is_wildcard, map_column_names, ) from connectors.utils import iso_utc DEFAULT_PROTOCOL = "TCP" DEFAULT_ORACLE_HOME = "" SID = "sid" SERVICE_NAME = "service_name" class OracleQueries(Queries): """Class contains methods which return query""" def ping(self): """Query to ping source""" return "SELECT 1+1 FROM DUAL" def all_tables(self, **kwargs): """Query to get all tables""" return ( f"SELECT TABLE_NAME FROM all_tables where OWNER = UPPER('{kwargs['user']}')" ) def table_primary_key(self, **kwargs): """Query to get the primary key""" return f"SELECT cols.column_name FROM all_constraints cons, all_cons_columns cols WHERE cols.table_name = '{kwargs['table']}' AND cons.constraint_type = 'P' AND cons.constraint_name = cols.constraint_name AND cons.owner = UPPER('{kwargs['user']}') AND cons.owner = cols.owner ORDER BY cols.table_name, cols.position" def table_data(self, **kwargs): """Query to get the table data""" return f"SELECT * FROM {kwargs['table']}" def table_last_update_time(self, **kwargs): """Query to get the last update time of the table""" return f"SELECT SCN_TO_TIMESTAMP(MAX(ora_rowscn)) from {kwargs['table']}" def table_data_count(self, **kwargs): """Query to get the number of rows in the table""" return f"SELECT COUNT(*) FROM {kwargs['table']}" def all_schemas(self): """Query to get all schemas of database""" pass # Multiple schemas not supported in Oracle class OracleClient: def __init__( self, host, port, user, password, connection_source, sid, service_name, tables, protocol, oracle_home, wallet_config, logger_, retry_count=DEFAULT_RETRY_COUNT, fetch_size=DEFAULT_FETCH_SIZE, ): self.host = host self.port = port self.user = user self.password = password self.connection_source = connection_source self.sid = sid self.service_name = service_name self.tables = tables self.protocol = protocol self.oracle_home = oracle_home self.wallet_config = wallet_config self.retry_count = retry_count self.fetch_size = fetch_size self.connection = None self.queries = OracleQueries() self._logger = logger_ def set_logger(self, logger_): self._logger = logger_ def close(self): if self.connection is not None: self.connection.close() @cached_property def engine(self): """Create sync engine for oracle""" if self.connection_source == SID: dsn = f"(DESCRIPTION=(ADDRESS=(PROTOCOL={self.protocol})(HOST={self.host})(PORT={self.port}))(CONNECT_DATA=(SID={self.sid})))" else: dsn = f"(DESCRIPTION=(ADDRESS=(PROTOCOL={self.protocol})(HOST={self.host})(PORT={self.port}))(CONNECT_DATA=(service_name={self.service_name})))" connection_string = ( f"oracle+oracledb://{self.user}:{quote(self.password)}@{dsn}" ) if self.oracle_home != "": os.environ["ORACLE_HOME"] = self.oracle_home return create_engine( connection_string, thick_mode={ "lib_dir": f"{self.oracle_home}/lib", "config_dir": self.wallet_config, }, ) else: return create_engine(connection_string) async def get_cursor(self, query): """Executes the passed query on the Non-Async supported Database server and return cursor. Args: query (str): Database query to be executed. Returns: cursor: Synchronous cursor """ self._logger.debug(f"Retrieving the cursor for query '{query}'") try: loop = asyncio.get_running_loop() if self.connection is None: self.connection = await loop.run_in_executor( executor=None, func=self.engine.connect, # pyright: ignore ) cursor = await loop.run_in_executor( executor=None, func=partial(self.connection.execute, statement=text(query)), ) return cursor except Exception as exception: self._logger.warning( f"Something went wrong while getting cursor; error: {exception}" ) raise async def ping(self): return await anext( fetch( cursor_func=partial(self.get_cursor, self.queries.ping()), fetch_size=1, retry_count=self.retry_count, ) ) async def get_tables_to_fetch(self): tables = configured_tables(self.tables) if is_wildcard(tables): self._logger.info( "Fetching all tables as the configuration field 'tables' is set to '*'" ) async for row in fetch( cursor_func=partial( self.get_cursor, self.queries.all_tables( user=self.user, ), ), fetch_size=self.fetch_size, retry_count=self.retry_count, ): yield row[0] else: self._logger.info(f"Fetching user-configured tables '{tables}'") for table in tables: yield table async def get_table_row_count(self, table): [row_count] = await anext( fetch( cursor_func=partial( self.get_cursor, self.queries.table_data_count( table=table, ), ), fetch_size=1, retry_count=self.retry_count, ) ) return row_count async def get_table_primary_key(self, table): self._logger.debug(f"Extracting primary keys for table '{table}'") primary_keys = [ key async for [key] in fetch( cursor_func=partial( self.get_cursor, self.queries.table_primary_key( user=self.user, table=table, ), ), fetch_size=self.fetch_size, retry_count=self.retry_count, ) ] self._logger.debug(f"Found primary keys for table '{table}'") return primary_keys async def get_table_last_update_time(self, table): self._logger.debug(f"Fetching last updated time for table '{table}'") [last_update_time] = await anext( fetch( cursor_func=partial( self.get_cursor, self.queries.table_last_update_time( table=table, ), ), fetch_size=1, retry_count=self.retry_count, ) ) return last_update_time async def data_streamer(self, table): """Streaming data from a table Args: table (str): Table. Raises: exception: Raise an exception after retrieving Yields: list: It will first yield the column names, then data in each row """ self._logger.debug(f"Streaming records from database for table '{table}'") record_count = 0 async for data in fetch( cursor_func=partial( self.get_cursor, self.queries.table_data( table=table, ), ), fetch_columns=True, fetch_size=self.fetch_size, retry_count=self.retry_count, ): record_count += 1 yield data self._logger.info(f"Found {record_count} records for table '{table}'") class OracleDataSource(BaseDataSource): """Oracle Database""" name = "Oracle Database" service_type = "oracle" def __init__(self, configuration): """Setup connection to the Oracle database-server configured by user Args: configuration (DataSourceConfiguration): Instance of DataSourceConfiguration class. """ super().__init__(configuration=configuration) self.database = ( self.configuration["sid"] if self.configuration["connection_source"] == SID else self.configuration["service_name"] ) self.oracle_client = OracleClient( host=self.configuration["host"], port=self.configuration["port"], user=self.configuration["username"], password=self.configuration["password"], connection_source=self.configuration["connection_source"], sid=self.configuration["sid"], service_name=self.configuration["service_name"], tables=self.configuration["tables"], protocol=self.configuration["oracle_protocol"], oracle_home=self.configuration["oracle_home"], wallet_config=self.configuration["wallet_configuration_path"], retry_count=self.configuration["retry_count"], fetch_size=self.configuration["fetch_size"], logger_=self._logger, ) def _set_internal_logger(self): self.oracle_client.set_logger(self._logger) @classmethod def get_default_configuration(cls): return { "host": { "label": "Host", "order": 1, "type": "str", }, "port": { "display": "numeric", "label": "Port", "order": 2, "type": "int", }, "username": { "label": "Username", "order": 3, "type": "str", }, "password": { "label": "Password", "order": 4, "sensitive": True, "type": "str", }, "connection_source": { "display": "dropdown", "label": "Connection Source", "options": [ {"label": "SID", "value": SID}, {"label": "Service Name", "value": SERVICE_NAME}, ], "order": 5, "type": "str", "value": SID, "tooltip": "Select 'Service Name' option if connecting to a pluggable database", }, "sid": { "depends_on": [{"field": "connection_source", "value": SID}], "label": "SID", "order": 6, "type": "str", }, "service_name": { "depends_on": [{"field": "connection_source", "value": SERVICE_NAME}], "label": "Service Name", "order": 7, "type": "str", }, "tables": { "display": "textarea", "label": "Comma-separated list of tables", "options": [], "order": 8, "type": "list", "value": "*", }, "fetch_size": { "default_value": DEFAULT_FETCH_SIZE, "display": "numeric", "label": "Rows fetched per request", "order": 9, "required": False, "type": "int", "ui_restrictions": ["advanced"], }, "retry_count": { "default_value": DEFAULT_RETRY_COUNT, "display": "numeric", "label": "Retries per request", "order": 10, "required": False, "type": "int", "ui_restrictions": ["advanced"], }, "oracle_protocol": { "default_value": DEFAULT_PROTOCOL, "display": "dropdown", "label": "Oracle connection protocol", "options": [ {"label": "TCP", "value": "TCP"}, {"label": "TCPS", "value": "TCPS"}, ], "order": 11, "type": "str", "value": DEFAULT_PROTOCOL, "ui_restrictions": ["advanced"], }, "oracle_home": { "default_value": DEFAULT_ORACLE_HOME, "label": "Path to Oracle Home", "order": 12, "required": False, "type": "str", "value": DEFAULT_ORACLE_HOME, "ui_restrictions": ["advanced"], }, "wallet_configuration_path": { "default_value": "", "label": "Path to SSL Wallet configuration files", "order": 13, "required": False, "type": "str", "ui_restrictions": ["advanced"], }, } async def close(self): self.oracle_client.close() async def ping(self): """Verify the connection with the database-server configured by user""" self._logger.debug("Validating that the Connector can connect to Oracle...") try: await self.oracle_client.ping() self._logger.debug("Successfully connected to Oracle") except Exception as e: msg = f"Can't connect to Oracle on {self.oracle_client.host}" raise Exception(msg) from e async def fetch_documents(self, table): """Fetches all the table entries and format them in Elasticsearch documents Args: table (str): Name of table Yields: Dict: Document to be indexed """ try: self._logger.info(f"Fetching records for table '{table}'") row_count = await self.oracle_client.get_table_row_count(table=table) if row_count > 0: # Query to get the table's primary key self._logger.debug(f"Total {row_count} rows found in table '{table}'") keys = await self.oracle_client.get_table_primary_key(table=table) keys = map_column_names(column_names=keys, tables=[table]) if keys: try: last_update_time = ( await self.oracle_client.get_table_last_update_time( table=table ) ) except Exception as e: self._logger.warning( f"Unable to fetch last updated time for table '{table}'; error: {e}" ) last_update_time = None streamer = self.oracle_client.data_streamer(table=table) column_names = await anext(streamer) column_names = map_column_names( column_names=column_names, tables=[table] ) async for row in streamer: row = dict(zip(column_names, row, strict=True)) keys_value = "" for key in keys: keys_value += f"{row.get(key)}_" if row.get(key) else "" row.update( { "_id": f"{self.database}_{table}_{keys_value}", "_timestamp": last_update_time or iso_utc(), "Database": self.database, "Table": table, } ) yield self.serialize(doc=row) else: self._logger.warning( f"Skipping '{table}' table from database '{self.database}' since no primary key is associated with it. Assign a primary key to the table to index it in the next sync interval." ) else: self._logger.warning(f"No records found for table '{table}'") except (InternalClientError, ProgrammingError) as exception: self._logger.warning( f"Something went wrong while fetching records from table '{table}'; error: {exception}" ) async def get_docs(self, filtering=None): """Executes the logic to fetch databases, tables and rows in async manner. Yields: dictionary: Row dictionary containing meta-data of the row. """ table_count = 0 async for table in self.oracle_client.get_tables_to_fetch(): table_count += 1 async for row in self.fetch_documents(table=table): yield row, None if table_count < 1: self._logger.warning(f"Fetched 0 tables for the database '{self.database}'")