awswrangler/data_api/_connector.py (99 lines of code) (raw):
"""Data API Connector base class."""
from __future__ import annotations
import datetime as dt
from abc import ABC, abstractmethod
from dataclasses import dataclass
from decimal import Decimal
from types import TracebackType
from typing import Any
import awswrangler.pandas as pd
class DataApiConnector(ABC):
    """Base class for Data API (RDS, Redshift, etc.) connectors.

    Subclasses implement the service-specific plumbing
    (``_execute_statement``, ``_batch_execute_statement``,
    ``_get_statement_result``) and transaction management, while this base
    class provides the public ``execute``/``batch_execute`` entry points and
    context-manager support (``with`` closes the underlying connection).
    """

    def execute(
        self,
        sql: str,
        database: str | None = None,
        transaction_id: str | None = None,
        parameters: list[dict[str, Any]] | None = None,
    ) -> pd.DataFrame:
        """Execute SQL statement against a Data API Service.

        Parameters
        ----------
        sql
            SQL statement to execute.
        database
            Database to run the statement against; service default when ``None``.
        transaction_id
            Identifier of an open transaction to run the statement inside.
        parameters
            Data API parameter dictionaries for the statement.

        Returns
        -------
        A Pandas DataFrame containing the execution results.
        """
        request_id: str = self._execute_statement(
            sql, database=database, transaction_id=transaction_id, parameters=parameters
        )
        return self._get_statement_result(request_id)

    def batch_execute(
        self,
        sql: str | list[str],
        database: str | None = None,
        transaction_id: str | None = None,
        parameter_sets: list[list[dict[str, Any]]] | None = None,
    ) -> None:
        """Batch execute SQL statements against a Data API Service.

        Parameters
        ----------
        sql
            SQL statement(s) to execute.
        database
            Database to run the statements against; service default when ``None``.
        transaction_id
            Identifier of an open transaction to run the statements inside.
        parameter_sets
            One list of Data API parameter dictionaries per statement execution.
        """
        self._batch_execute_statement(
            sql, database=database, transaction_id=transaction_id, parameter_sets=parameter_sets
        )

    def __enter__(self) -> "DataApiConnector":
        """Enter the runtime context; the connector itself is the target."""
        return self

    @abstractmethod
    def close(self) -> None:
        """Close underlying endpoint connections."""
        pass

    def __exit__(
        self,
        exception_type: type[BaseException] | None,
        exception_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Exit the runtime context, closing connections; never suppresses exceptions."""
        self.close()
        return None

    @abstractmethod
    def begin_transaction(self, database: str | None = None, schema: str | None = None) -> str:
        """Start a transaction and return its identifier."""
        pass

    @abstractmethod
    def commit_transaction(self, transaction_id: str) -> str:
        """Commit the transaction identified by ``transaction_id``."""
        pass

    @abstractmethod
    def rollback_transaction(self, transaction_id: str) -> str:
        """Roll back the transaction identified by ``transaction_id``."""
        pass

    @abstractmethod
    def _execute_statement(
        self,
        sql: str,
        database: str | None = None,
        transaction_id: str | None = None,
        parameters: list[dict[str, Any]] | None = None,
    ) -> str:
        """Submit a single statement and return the service request id."""
        pass

    @abstractmethod
    def _batch_execute_statement(
        self,
        sql: str | list[str],
        database: str | None = None,
        transaction_id: str | None = None,
        parameter_sets: list[list[dict[str, Any]]] | None = None,
    ) -> str:
        """Submit a batch of statements and return the service request id."""
        pass

    @abstractmethod
    def _get_statement_result(self, request_id: str) -> pd.DataFrame:
        """Fetch the result of a previously submitted statement as a DataFrame."""
        pass

    @staticmethod
    def _get_column_value(  # noqa: PLR0911
        column_value: dict[str, Any], col_type: str | None = None
    ) -> Any:
        """Return the first non-null key value for a given dictionary.

        The key names for a given record depend on the column type: stringValue, longValue, etc.
        Therefore, a record in the response does not have consistent key names. The ColumnMetadata
        typeName information could be used to infer the key, but there is no direct mapping here
        that could be easily parsed with creating a static dictionary:
            varchar -> stringValue
            int2 -> longValue
            timestamp -> stringValue
        What has been observed is that each record appears to have a single key, so this function
        iterates over the keys and returns the first non-null value. If none are found, None is
        returned.

        ``stringValue`` entries are further coerced using ``col_type``
        (``DATETIME``/``DATE``/``TIME``/``DECIMAL``); any other type returns the raw string.

        Documentation:
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-data.html#RedshiftDataAPIService.Client.get_statement_result
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds-data.html#RDSDataService.Client.execute_statement
        """
        for key, value in column_value.items():
            if value is None:
                continue
            if key == "isNull":
                if value:
                    return None
                # BUG FIX: a falsy-but-present "isNull" key previously fell
                # through to the generic return and yielded False instead of
                # letting the loop reach the actual value key.
                continue
            if key == "arrayValue":
                raise ValueError(f"arrayValue not supported yet - could not extract {column_value[key]}")
            if key == "stringValue":
                if col_type == "DATETIME":
                    return dt.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
                if col_type == "DATE":
                    return dt.datetime.strptime(value, "%Y-%m-%d").date()
                if col_type == "TIME":
                    return dt.datetime.strptime(value, "%H:%M:%S").time()
                if col_type == "DECIMAL":
                    return Decimal(value)
            return value
        return None
@dataclass
class WaitConfig:
    """Holds standard wait configuration values.

    Bundles the three knobs of a retry/poll loop into one value object.
    """

    # Delay between attempts, in seconds — presumably the initial value
    # before backoff is applied; confirm against the polling caller.
    sleep: float
    # Multiplier applied to the delay on each retry (assumed; verify usage).
    backoff: float
    # Maximum number of attempts before giving up.
    retries: int