# awswrangler/data_api/redshift.py
"""Redshift Data API Connector."""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING, Any
import boto3
import awswrangler.pandas as pd
from awswrangler import _utils, exceptions
from awswrangler.data_api import _connector
if TYPE_CHECKING:
from mypy_boto3_redshift_data.client import RedshiftDataAPIServiceClient
from mypy_boto3_redshift_data.type_defs import ColumnMetadataTypeDef
_logger = logging.getLogger(__name__)
class RedshiftDataApi(_connector.DataApiConnector):
"""Provides access to a Redshift cluster via the Data API.
Note
----
When connecting to a standard Redshift cluster, `cluster_id` is used.
When connecting to Redshift Serverless, `workgroup_name` is used. These two arguments are mutually exclusive.
Parameters
----------
cluster_id
Id for the target Redshift cluster - only required if `workgroup_name` not provided.
database
Target database name.
workgroup_name
Name for the target serverless Redshift workgroup - only required if `cluster_id` not provided.
secret_arn
The ARN for the secret to be used for authentication - only required if `db_user` not provided.
db_user
The database user to generate temporary credentials for - only required if `secret_arn` not provided.
    sleep
Number of seconds to sleep between result fetch attempts - defaults to 0.25.
backoff
Factor by which to increase the sleep between result fetch attempts - defaults to 1.5.
retries
Maximum number of result fetch attempts - defaults to 15.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
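    Examples
    --------
    A minimal sketch - the cluster, workgroup, and secret identifiers below are
    placeholders, not real resources.

    >>> import awswrangler as wr
    >>> con = wr.data_api.redshift.RedshiftDataApi(
    ...     cluster_id="my-cluster",
    ...     database="dev",
    ...     secret_arn="arn:aws:secretsmanager:us-east-1:123456789012:secret:my-secret",
    ... )

    Or, for a Redshift Serverless workgroup:

    >>> con = wr.data_api.redshift.RedshiftDataApi(
    ...     workgroup_name="my-workgroup",
    ...     database="dev",
    ... )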
"""
def __init__(
self,
cluster_id: str = "",
database: str = "",
workgroup_name: str = "",
secret_arn: str = "",
db_user: str = "",
sleep: float = 0.25,
backoff: float = 1.5,
retries: int = 15,
boto3_session: boto3.Session | None = None,
) -> None:
super().__init__()
self.client = _utils.client(service_name="redshift-data", session=boto3_session)
self.cluster_id = cluster_id
self.database = database
self.workgroup_name = workgroup_name
self.secret_arn = secret_arn
self.db_user = db_user
self.waiter = RedshiftDataApiWaiter(self.client, sleep, backoff, retries)
def close(self) -> None:
"""Close underlying endpoint connections."""
self.client.close()
def begin_transaction(self, database: str | None = None, schema: str | None = None) -> str:
"""Start an SQL transaction."""
raise NotImplementedError("Redshift Data API does not support transactions.")
def commit_transaction(self, transaction_id: str) -> str:
"""Commit an SQL transaction."""
raise NotImplementedError("Redshift Data API does not support transactions.")
def rollback_transaction(self, transaction_id: str) -> str:
"""Roll back an SQL transaction."""
raise NotImplementedError("Redshift Data API does not support transactions.")
def _validate_redshift_target(self) -> None:
if not self.database:
raise ValueError("`database` must be set for connection")
if not self.cluster_id and not self.workgroup_name:
raise ValueError("Either `cluster_id` or `workgroup_name`(Redshift Serverless) must be set for connection")
def _validate_auth_method(self) -> None:
if not self.workgroup_name and not self.secret_arn and not self.db_user and not self.cluster_id:
raise exceptions.InvalidArgumentCombination(
"Either `secret_arn`, `workgroup_name`, `db_user`, or `cluster_id` must be set for authentication."
)
if self.db_user and self.secret_arn:
raise exceptions.InvalidArgumentCombination("Only one of `secret_arn` or `db_user` is allowed.")
def _execute_statement(
self,
sql: str,
database: str | None = None,
transaction_id: str | None = None,
parameters: list[dict[str, Any]] | None = None,
) -> str:
if transaction_id:
raise exceptions.InvalidArgument("`transaction_id` not supported for Redshift Data API")
self._validate_redshift_target()
self._validate_auth_method()
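        # Assemble the optional ExecuteStatement arguments: at most one auth
        # method (SecretArn or DbUser temporary credentials) and one target
        # (provisioned cluster or serverless workgroup).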
args = {}
if self.secret_arn:
args["SecretArn"] = self.secret_arn
if self.db_user:
args["DbUser"] = self.db_user
if database is None:
database = self.database
if self.cluster_id:
args["ClusterIdentifier"] = self.cluster_id
if self.workgroup_name:
args["WorkgroupName"] = self.workgroup_name
if parameters:
args["Parameters"] = parameters # type: ignore[assignment]
_logger.debug("Executing %s", sql)
response = self.client.execute_statement(
Database=database,
Sql=sql,
**args, # type: ignore[arg-type]
)
return response["Id"]
def _batch_execute_statement(
self,
sql: str | list[str],
database: str | None = None,
transaction_id: str | None = None,
parameter_sets: list[list[dict[str, Any]]] | None = None,
) -> str:
raise NotImplementedError("Batch execute statement not support for Redshift Data API.")
def _get_statement_result(self, request_id: str) -> pd.DataFrame:
self.waiter.wait(request_id)
describe_response = self.client.describe_statement(Id=request_id)
if not describe_response["HasResultSet"]:
return pd.DataFrame()
paginator = self.client.get_paginator("get_statement_result")
response_iterator = paginator.paginate(Id=request_id)
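        # Every result page repeats the column metadata; the copy from the last
        # page read is used below to name the DataFrame columns.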
rows: list[list[Any]] = []
column_metadata: list["ColumnMetadataTypeDef"]
for response in response_iterator:
column_metadata = response["ColumnMetadata"]
for record in response["Records"]:
row: list[Any] = [
_connector.DataApiConnector._get_column_value(column) # type: ignore[arg-type]
for column in record
]
rows.append(row)
column_names: list[str] = [column["name"] for column in column_metadata]
dataframe = pd.DataFrame(rows, columns=column_names)
return dataframe
class RedshiftDataApiWaiter:
"""Waits for a DescribeStatement call to return a completed status.
Parameters
----------
    client
        A boto3 "redshift-data" client, or any client exposing a `describe_statement` method.
sleep
Number of seconds to sleep between tries.
backoff
Factor by which to increase the sleep between tries.
retries
Maximum number of tries.
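    Examples
    --------
    A minimal sketch with placeholder values; `RedshiftDataApi` constructs this
    waiter internally, so direct use is rarely needed.

    >>> import boto3
    >>> client = boto3.client("redshift-data")
    >>> waiter = RedshiftDataApiWaiter(client, sleep=0.25, backoff=1.5, retries=15)
    >>> waiter.wait("<statement-id>")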
"""
def __init__(self, client: "RedshiftDataAPIServiceClient", sleep: float, backoff: float, retries: int) -> None:
self.client = client
self.wait_config = _connector.WaitConfig(sleep, backoff, retries)
def wait(self, request_id: str) -> bool:
"""Wait for the `describe_statement` function of self.client to return a completed status.
Parameters
----------
request_id
The execution id to check the status for.
        Returns
        -------
        True if the execution finished without error.

        Raises
        ------
        RedshiftDataApiFailedException
            If the statement finishes with status FAILED or ABORTED.
        RedshiftDataApiTimeoutException
            If retries are exhausted before the statement completes.
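
        Notes
        -----
        The sleep interval grows geometrically, so the worst-case total wait is
        roughly ``sleep * (backoff ** (retries + 1) - 1) / (backoff - 1)`` seconds -
        about 328 seconds with the default configuration (0.25, 1.5, 15).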
"""
sleep: float = self.wait_config.sleep
total_sleep: float = 0
total_tries: int = 0
while total_tries <= self.wait_config.retries:
response = self.client.describe_statement(Id=request_id)
status: str = response["Status"]
if status == "FINISHED":
return True
if status in ["ABORTED", "FAILED"]:
error = response["Error"]
raise RedshiftDataApiFailedException(
f"Request {request_id} failed with status {status} and error {error}"
)
_logger.debug("Statement execution status %s - sleeping for %s seconds", status, sleep)
time.sleep(sleep)
sleep = sleep * self.wait_config.backoff
total_tries += 1
total_sleep += sleep
raise RedshiftDataApiTimeoutException(
f"Request {request_id} timed out after {total_tries} tries and {total_sleep}s total sleep"
)
class RedshiftDataApiFailedException(Exception):
"""Indicates a statement execution was aborted or failed."""
class RedshiftDataApiTimeoutException(Exception):
"""Indicates a statement execution did not complete in the expected wait time."""
def connect(
cluster_id: str = "",
database: str = "",
workgroup_name: str = "",
secret_arn: str = "",
db_user: str = "",
boto3_session: boto3.Session | None = None,
**kwargs: Any,
) -> RedshiftDataApi:
"""Create a Redshift Data API connection.
Note
----
When connecting to a standard Redshift cluster, `cluster_id` is used.
When connecting to Redshift Serverless, `workgroup_name` is used. These two arguments are mutually exclusive.
Parameters
----------
cluster_id
Id for the target Redshift cluster - only required if `workgroup_name` not provided.
database
Target database name.
workgroup_name
Name for the target serverless Redshift workgroup - only required if `cluster_id` not provided.
secret_arn
The ARN for the secret to be used for authentication - only required if `db_user` not provided.
db_user
The database user to generate temporary credentials for - only required if `secret_arn` not provided.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
**kwargs
Any additional kwargs are passed to the underlying RedshiftDataApi class.
Returns
-------
    A RedshiftDataApi connection instance that can be used with `wr.data_api.redshift.read_sql_query`.
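
    Examples
    --------
    Identifiers below are placeholders - substitute your own resources.

    >>> import awswrangler as wr
    >>> con = wr.data_api.redshift.connect(
    ...     cluster_id="my-cluster",
    ...     database="dev",
    ...     db_user="my-user",
    ... )

    Or, for a Redshift Serverless workgroup:

    >>> con = wr.data_api.redshift.connect(
    ...     workgroup_name="my-workgroup",
    ...     database="dev",
    ... )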
"""
return RedshiftDataApi(
cluster_id=cluster_id,
database=database,
workgroup_name=workgroup_name,
secret_arn=secret_arn,
db_user=db_user,
boto3_session=boto3_session,
**kwargs,
)
def read_sql_query(
sql: str,
con: RedshiftDataApi,
database: str | None = None,
parameters: list[dict[str, Any]] | None = None,
) -> pd.DataFrame:
"""Run an SQL query on a RedshiftDataApi connection and return the result as a DataFrame.
Parameters
----------
sql
SQL query to run.
con
        A RedshiftDataApi connection instance.
database
Database to run query on - defaults to the database specified by `con`.
parameters
A list of named parameters e.g. [{"name": "id", "value": "42"}].
Returns
-------
A Pandas DataFrame containing the query results.
    Examples
    --------
    >>> import awswrangler as wr
    >>> df = wr.data_api.redshift.read_sql_query(
    ...     sql="SELECT * FROM public.my_table",
    ...     con=con,
    ... )

    Run a query with named parameters:

    >>> df = wr.data_api.redshift.read_sql_query(
    ...     sql="SELECT * FROM public.my_table WHERE id >= :id",
    ...     con=con,
    ...     parameters=[
    ...         {"name": "id", "value": "42"},
    ...     ],
    ... )
"""
return con.execute(sql, database=database, parameters=parameters)