awswrangler/data_api/rds.py (319 lines of code) (raw):

"""RDS Data API Connector.""" from __future__ import annotations import datetime as dt import logging import time import uuid from decimal import Decimal from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast import boto3 from typing_extensions import Literal import awswrangler.pandas as pd from awswrangler import _data_types, _databases, _utils, exceptions from awswrangler._sql_utils import identifier from awswrangler.data_api import _connector if TYPE_CHECKING: from mypy_boto3_rds_data.client import BotocoreClientError # type: ignore[attr-defined] from mypy_boto3_rds_data.type_defs import BatchExecuteStatementResponseTypeDef, ExecuteStatementResponseTypeDef _logger = logging.getLogger(__name__) _ExecuteStatementResponseType = TypeVar( "_ExecuteStatementResponseType", "ExecuteStatementResponseTypeDef", "BatchExecuteStatementResponseTypeDef" ) class RdsDataApi(_connector.DataApiConnector): """Provides access to the RDS Data API. Parameters ---------- resource_arn ARN for the RDS resource. database Target database name. secret_arn The ARN for the secret to be used for authentication. sleep Number of seconds to sleep between connection attempts to paused clusters - defaults to 0.5. backoff Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0. retries Maximum number of connection attempts to paused clusters - defaults to 10. boto3_session The default boto3 session will be used if **boto3_session** is ``None``. """ def __init__( self, resource_arn: str, database: str, secret_arn: str = "", sleep: float = 0.5, backoff: float = 1.0, retries: int = 30, boto3_session: boto3.Session | None = None, ) -> None: super().__init__() self.client = _utils.client(service_name="rds-data", session=boto3_session) self.resource_arn = resource_arn self.database = database self.secret_arn = secret_arn self.wait_config = _connector.WaitConfig(sleep, backoff, retries) self.results: dict[str, "ExecuteStatementResponseTypeDef" | "BatchExecuteStatementResponseTypeDef"] = {} def close(self) -> None: """Close underlying endpoint connections.""" self.client.close() def begin_transaction(self, database: str | None = None, schema: str | None = None) -> str: """Start an SQL transaction.""" if database is None: database = self.database kwargs = {} if database: kwargs["database"] = database if schema: kwargs["schema"] = schema response = self.client.begin_transaction( resourceArn=self.resource_arn, secretArn=self.secret_arn, **kwargs, ) return response["transactionId"] def commit_transaction(self, transaction_id: str) -> str: """Commit an SQL transaction.""" response = self.client.commit_transaction( resourceArn=self.resource_arn, secretArn=self.secret_arn, transactionId=transaction_id, ) return response["transactionStatus"] def rollback_transaction(self, transaction_id: str) -> str: """Roll back an SQL transaction.""" response = self.client.rollback_transaction( resourceArn=self.resource_arn, secretArn=self.secret_arn, transactionId=transaction_id, ) return response["transactionStatus"] def _execute_with_retry( self, sql: str, function: Callable[[str], _ExecuteStatementResponseType], ) -> str: sleep: float = self.wait_config.sleep total_tries: int = 0 total_sleep: float = 0 response: _ExecuteStatementResponseType | None = None last_exception: "BotocoreClientError" | None = None while total_tries < self.wait_config.retries: try: response = function(sql) _logger.debug( "Response received after %s tries and sleeping for a total of %s seconds", total_tries, total_sleep ) break except self.client.exceptions.BadRequestException as exception: last_exception = exception total_sleep += sleep _logger.debug("BadRequestException occurred: %s", exception) _logger.debug( "Cluster may be paused - sleeping for %s seconds for a total of %s before retrying", sleep, total_sleep, ) time.sleep(sleep) total_tries += 1 sleep *= self.wait_config.backoff if response is None: _logger.exception("Maximum BadRequestException retries reached for query %s", sql) raise last_exception # type: ignore[misc] request_id: str = uuid.uuid4().hex self.results[request_id] = response return request_id def _execute_statement( self, sql: str, database: str | None = None, transaction_id: str | None = None, parameters: list[dict[str, Any]] | None = None, ) -> str: if database is None: database = self.database additional_kwargs: dict[str, Any] = {} if transaction_id: additional_kwargs["transactionId"] = transaction_id if parameters: additional_kwargs["parameters"] = parameters def function(sql: str) -> "ExecuteStatementResponseTypeDef": return self.client.execute_statement( resourceArn=self.resource_arn, database=database, sql=sql, secretArn=self.secret_arn, includeResultMetadata=True, **additional_kwargs, ) return self._execute_with_retry(sql=sql, function=function) def _batch_execute_statement( self, sql: str | list[str], database: str | None = None, transaction_id: str | None = None, parameter_sets: list[list[dict[str, Any]]] | None = None, ) -> str: if isinstance(sql, list): raise exceptions.InvalidArgumentType("`sql` parameter cannot be list.") if database is None: database = self.database additional_kwargs: dict[str, Any] = {} if transaction_id: additional_kwargs["transactionId"] = transaction_id if parameter_sets: additional_kwargs["parameterSets"] = parameter_sets def function(sql: str) -> "BatchExecuteStatementResponseTypeDef": return self.client.batch_execute_statement( resourceArn=self.resource_arn, database=database, sql=sql, secretArn=self.secret_arn, **additional_kwargs, ) return self._execute_with_retry(sql=sql, function=function) def _get_statement_result(self, request_id: str) -> pd.DataFrame: try: result = cast("ExecuteStatementResponseTypeDef", self.results.pop(request_id)) except KeyError as exception: raise KeyError(f"Request {request_id} not found in results {self.results}") from exception if "records" not in result: return pd.DataFrame() rows: list[list[Any]] = [] column_types = [col.get("typeName") for col in result["columnMetadata"]] for record in result["records"]: row: list[Any] = [ _connector.DataApiConnector._get_column_value(column, col_type) # type: ignore[arg-type] for column, col_type in zip(record, column_types) ] rows.append(row) column_names: list[str] = [column["name"] for column in result["columnMetadata"]] dataframe = pd.DataFrame(rows, columns=column_names) return dataframe def connect( resource_arn: str, database: str, secret_arn: str = "", boto3_session: boto3.Session | None = None, **kwargs: Any ) -> RdsDataApi: """Create a RDS Data API connection. Parameters ---------- resource_arn ARN for the RDS resource. database Target database name. secret_arn The ARN for the secret to be used for authentication. boto3_session The default boto3 session will be used if **boto3_session** is ``None``. **kwargs Any additional kwargs are passed to the underlying RdsDataApi class. Returns ------- A RdsDataApi connection instance that can be used with `wr.rds.data_api.read_sql_query`. """ return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs) def read_sql_query( sql: str, con: RdsDataApi, database: str | None = None, parameters: list[dict[str, Any]] | None = None ) -> pd.DataFrame: """Run an SQL query on an RdsDataApi connection and return the result as a DataFrame. Parameters ---------- sql SQL query to run. con A RdsDataApi connection instance database Database to run query on - defaults to the database specified by `con`. parameters A list of named parameters e.g. [{"name": "col", "value": {"stringValue": "val1"}}]. Returns ------- A Pandas DataFrame containing the query results. Examples -------- >>> import awswrangler as wr >>> df = wr.data_api.rds.read_sql_query( >>> sql="SELECT * FROM public.my_table", >>> con=con, >>> ) >>> import awswrangler as wr >>> df = wr.data_api.rds.read_sql_query( >>> sql="SELECT * FROM public.my_table WHERE col = :name", >>> con=con, >>> parameters=[ >>> {"name": "col1", "value": {"stringValue": "val1"}} >>> ], >>> ) """ return con.execute(sql, database=database, parameters=parameters) def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None: sql = f"DROP TABLE IF EXISTS {identifier(table, sql_mode=sql_mode)}" _logger.debug("Drop table query:\n%s", sql) con.execute(sql, database=database, transaction_id=transaction_id) def _does_table_exist(con: RdsDataApi, table: str, database: str, transaction_id: str) -> bool: res = con.execute( "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = :table", parameters=[ { "name": "table", "value": {"stringValue": table}, }, ], ) return not res.empty def _create_table( df: pd.DataFrame, con: RdsDataApi, table: str, database: str, transaction_id: str, mode: str, index: bool, dtype: dict[str, str] | None, varchar_lengths: dict[str, int] | None, sql_mode: str, ) -> None: if mode == "overwrite": _drop_table(con=con, table=table, database=database, transaction_id=transaction_id, sql_mode=sql_mode) elif _does_table_exist(con=con, table=table, database=database, transaction_id=transaction_id): return mysql_types: dict[str, str] = _data_types.database_types_from_pandas( df=df, index=index, dtype=dtype, varchar_lengths_default="TEXT", varchar_lengths=varchar_lengths, converter_func=_data_types.pyarrow2mysql, ) cols_str: str = "".join([f"{identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2] sql = f"CREATE TABLE IF NOT EXISTS {identifier(table, sql_mode=sql_mode)} (\n{cols_str})" _logger.debug("Create table query:\n%s", sql) con.execute(sql, database=database, transaction_id=transaction_id) def _create_value_dict( # noqa: PLR0911 value: Any, ) -> tuple[dict[str, Any], str | None]: if value is None or pd.isnull(value): return {"isNull": True}, None if isinstance(value, bool): return {"booleanValue": value}, None if isinstance(value, int): return {"longValue": value}, None if isinstance(value, float): return {"doubleValue": value}, None if isinstance(value, str): return {"stringValue": value}, None if isinstance(value, bytes): return {"blobValue": value}, None if isinstance(value, dt.datetime): return {"stringValue": str(value)}, "TIMESTAMP" if isinstance(value, dt.date): return {"stringValue": str(value)}, "DATE" if isinstance(value, dt.time): return {"stringValue": str(value)}, "TIME" if isinstance(value, Decimal): return {"stringValue": str(value)}, "DECIMAL" if isinstance(value, uuid.UUID): return {"stringValue": str(value)}, "UUID" raise exceptions.InvalidArgumentType(f"Value {value} not supported.") def _generate_parameters(columns: list[str], values: list[Any]) -> list[dict[str, Any]]: parameter_list = [] for col, value in zip(columns, values): value, type_hint = _create_value_dict(value) # noqa: PLW2901 parameter = { "name": col, "value": value, } if type_hint: parameter["typeHint"] = type_hint parameter_list.append(parameter) return parameter_list def _generate_parameter_sets(df: pd.DataFrame) -> list[list[dict[str, Any]]]: parameter_sets = [] columns = df.columns.tolist() for values in df.values.tolist(): parameter_sets.append(_generate_parameters(columns, values)) return parameter_sets def to_sql( df: pd.DataFrame, con: RdsDataApi, table: str, database: str, mode: Literal["append", "overwrite"] = "append", index: bool = False, dtype: dict[str, str] | None = None, varchar_lengths: dict[str, int] | None = None, use_column_names: bool = False, chunksize: int = 200, sql_mode: str = "mysql", ) -> None: """ Insert data using an SQL query on a Data API connection. Parameters ---------- df `Pandas DataFrame <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html>`_ con A RdsDataApi connection instance database Database to run query on - defaults to the database specified by `con`. table Table name mode `append` (inserts new records into table), `overwrite` (drops table and recreates) index True to store the DataFrame index as a column in the table, otherwise False to ignore it. dtype Dictionary of columns names and MySQL types to be casted. Useful when you have columns with undetermined or mixed data types. (e.g. ```{'col name': 'TEXT', 'col2 name': 'FLOAT'}```) varchar_lengths Dict of VARCHAR length by columns. (e.g. ```{"col1": 10, "col5": 200}```). use_column_names If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query. E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be inserted into the database columns `col1` and `col3`. chunksize Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query. sql_mode "mysql" for default MySQL identifiers (backticks) or "ansi" for ANSI-compatible identifiers (double quotes). """ if df.empty is True: raise exceptions.EmptyDataFrame("DataFrame cannot be empty.") _databases.validate_mode(mode=mode, allowed_modes=["append", "overwrite"]) transaction_id: str | None = None try: transaction_id = con.begin_transaction(database=database) _create_table( df=df, con=con, table=table, database=database, transaction_id=transaction_id, mode=mode, index=index, dtype=dtype, varchar_lengths=varchar_lengths, sql_mode=sql_mode, ) if index: df = df.reset_index(level=df.index.names) if use_column_names: insertion_columns = "(" + ", ".join([f"{identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")" else: insertion_columns = "" placeholders = ", ".join([f":{col}" for col in df.columns]) sql = f"INSERT INTO {identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})" parameter_sets = _generate_parameter_sets(df) for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize): con.batch_execute( sql=sql, database=database, transaction_id=transaction_id, parameter_sets=parameter_sets_chunk, ) except Exception as ex: if transaction_id: con.rollback_transaction(transaction_id=transaction_id) _logger.error(ex) raise else: con.commit_transaction(transaction_id=transaction_id)