# awswrangler/redshift/_utils.py
# mypy: disable-error-code=name-defined
"""Amazon Redshift Utils Module (PRIVATE)."""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Literal
import boto3
import botocore
import pandas as pd
from awswrangler import _data_types, _sql_utils, _utils, exceptions, s3
if TYPE_CHECKING:
try:
import redshift_connector
except ImportError:
pass
else:
redshift_connector = _utils.import_optional_dependency("redshift_connector")
_logger: logging.Logger = logging.getLogger(__name__)
_RS_DISTSTYLES: list[str] = ["AUTO", "EVEN", "ALL", "KEY"]
_RS_SORTSTYLES: list[str] = ["COMPOUND", "INTERLEAVED"]
def _identifier(sql: str) -> str:
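    """Validate and quote ``sql`` as an ANSI (double-quoted) SQL identifier."""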
return _sql_utils.identifier(sql, sql_mode="ansi")
def _make_s3_auth_string(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
iam_role: str | None = None,
boto3_session: boto3.Session | None = None,
) -> str:
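    """Build the S3 authorization clause used by Redshift COPY/UNLOAD statements.

    Credential precedence: an explicit access key pair, then an IAM role,
    then whatever credentials the boto3 session resolves.
    """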
if aws_access_key_id is not None and aws_secret_access_key is not None:
auth_str: str = f"ACCESS_KEY_ID '{aws_access_key_id}'\nSECRET_ACCESS_KEY '{aws_secret_access_key}'\n"
if aws_session_token is not None:
auth_str += f"SESSION_TOKEN '{aws_session_token}'\n"
elif iam_role is not None:
auth_str = f"IAM_ROLE '{iam_role}'\n"
else:
_logger.debug("Attempting to get S3 authorization credentials from boto3 session.")
credentials: botocore.credentials.ReadOnlyCredentials
credentials = _utils.get_credentials_from_session(boto3_session=boto3_session)
if credentials.access_key is None or credentials.secret_key is None:
raise exceptions.InvalidArgument(
"One of IAM Role or AWS ACCESS_KEY_ID and SECRET_ACCESS_KEY must be "
"given. Unable to find ACCESS_KEY_ID and SECRET_ACCESS_KEY in boto3 "
"session."
)
auth_str = f"ACCESS_KEY_ID '{credentials.access_key}'\nSECRET_ACCESS_KEY '{credentials.secret_key}'\n"
if credentials.token is not None:
auth_str += f"SESSION_TOKEN '{credentials.token}'\n"
return auth_str
def _begin_transaction(cursor: "redshift_connector.Cursor") -> None:
sql = "BEGIN TRANSACTION"
_logger.debug("Executing begin transaction query:\n%s", sql)
cursor.execute(sql)
def _drop_table(cursor: "redshift_connector.Cursor", schema: str | None, table: str, cascade: bool = False) -> None:
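    """DROP TABLE IF EXISTS, optionally with CASCADE to also drop dependent objects."""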
schema_str = f'"{schema}".' if schema else ""
cascade_str = " CASCADE" if cascade else ""
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"{cascade_str}'
_logger.debug("Executing drop table query:\n%s", sql)
cursor.execute(sql)
def _truncate_table(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> None:
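    """TRUNCATE the table. In Redshift, TRUNCATE implicitly commits the current transaction."""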
if schema:
sql = f"TRUNCATE TABLE {_identifier(schema)}.{_identifier(table)}"
else:
sql = f"TRUNCATE TABLE {_identifier(table)}"
_logger.debug("Executing truncate table query:\n%s", sql)
cursor.execute(sql)
def _delete_all(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> None:
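    """Delete every row with DELETE: slower than TRUNCATE, but transactional."""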
if schema:
sql = f"DELETE FROM {_identifier(schema)}.{_identifier(table)}"
else:
sql = f"DELETE FROM {_identifier(table)}"
_logger.debug("Executing delete query:\n%s", sql)
cursor.execute(sql)
def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
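    """Return the table's primary key columns, parsed from the index definition in pg_indexes."""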
sql = f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'"
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
    rows = cursor.fetchall()
    if not rows:
        return []
    # The column list is the parenthesized part of the index definition.
    result: str = rows[0][0]
    rfields: list[str] = result.split("(")[1].strip(")").split(",")
    fields: list[str] = [field.strip().strip('"') for field in rfields]
return fields
def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
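    """Return the table's column names, read from svv_columns."""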
sql = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'"
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
    result: tuple[list[str], ...] = cursor.fetchall()
columns = ["".join(lst) for lst in result]
return columns
def _add_table_columns(
cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
) -> None:
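    # Redshift allows only one ADD COLUMN per ALTER TABLE statement,
    # so each new column is added with its own query.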
for column_name, column_type in new_columns.items():
sql = (
f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}"
f"\nADD COLUMN {_identifier(column_name)} {column_type};"
)
_logger.debug("Executing alter query:\n%s", sql)
cursor.execute(sql)
def _does_table_exist(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> bool:
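    """Return True if the (optionally schema-qualified) table exists in INFORMATION_SCHEMA.TABLES."""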
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
sql = (
f"SELECT true WHERE EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = '{table}');"
)
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
return len(cursor.fetchall()) > 0
def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = None) -> list[str]:
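    """Read a COPY manifest file from S3 and return the object URLs listed under "entries"."""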
client_s3 = _utils.client(service_name="s3", session=boto3_session)
bucket, key = _utils.parse_path(path)
manifest_content = json.loads(client_s3.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8"))
    paths = [entry["url"] for entry in manifest_content["entries"]]
_logger.debug("Read %d paths from manifest file in: %s", len(paths), path)
return paths
def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str:
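    """Return the current value of a Redshift configuration parameter via SHOW."""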
sql = f"SHOW {parameter_name}"
_logger.debug("Executing select query:\n%s", sql)
cursor.execute(sql)
result = cursor.fetchall()
status = str(result[0][0])
_logger.debug(f"{parameter_name}='{status}'")
return status
def _lock(
cursor: "redshift_connector.Cursor",
table_names: list[str],
schema: str | None = None,
) -> None:
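    """Take an explicit LOCK on the given tables.

    The lock serializes concurrent writers and is held until the enclosing
    transaction commits or rolls back.
    """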
tables = ", ".join(
[(f"{_identifier(schema)}.{_identifier(table)}" if schema else _identifier(table)) for table in table_names]
)
sql: str = f"LOCK {tables};\n"
_logger.debug("Executing lock query:\n%s", sql)
cursor.execute(sql)
def _upsert(
cursor: "redshift_connector.Cursor",
table: str,
temp_table: str,
schema: str,
primary_keys: list[str] | None = None,
precombine_key: str | None = None,
column_names: list[str] | None = None,
) -> None:
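    """Merge ``temp_table`` into ``table`` using a delete-then-insert upsert.

    Target rows that match staged rows on the primary keys are deleted, then
    the staged rows are inserted and the staging table is dropped. When
    ``precombine_key`` is given, only the row with the greater precombine
    value survives, with staged rows winning ties. Illustrative shape of the
    generated statements for a single primary key "id" (identifiers shown
    unquoted):

        DELETE FROM schema.table USING temp WHERE table.id = temp.id
        INSERT INTO schema.table SELECT * FROM temp
    """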
if not primary_keys:
primary_keys = _get_primary_keys(cursor=cursor, schema=schema, table=table)
_logger.debug("primary_keys: %s", primary_keys)
if not primary_keys:
raise exceptions.InvalidRedshiftPrimaryKeys()
equals_clause: str = f"{_identifier(table)}.%s = {_identifier(temp_table)}.%s"
join_clause: str = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
if precombine_key:
delete_from_target_filter: str = (
f"AND {_identifier(table)}.{precombine_key} <= {_identifier(temp_table)}.{precombine_key}"
)
delete_from_temp_filter: str = (
f"AND {_identifier(table)}.{precombine_key} > {_identifier(temp_table)}.{precombine_key}"
)
target_del_sql: str = f"DELETE FROM {_identifier(schema)}.{_identifier(table)} USING {_identifier(temp_table)} WHERE {join_clause} {delete_from_target_filter}"
_logger.debug("Executing delete query:\n%s", target_del_sql)
cursor.execute(target_del_sql)
source_del_sql: str = f"DELETE FROM {_identifier(temp_table)} USING {_identifier(schema)}.{_identifier(table)} WHERE {join_clause} {delete_from_temp_filter}"
_logger.debug("Executing delete query:\n%s", source_del_sql)
cursor.execute(source_del_sql)
else:
sql: str = f"DELETE FROM {_identifier(schema)}.{_identifier(table)} USING {_identifier(temp_table)} WHERE {join_clause}"
_logger.debug("Executing delete query:\n%s", sql)
cursor.execute(sql)
if column_names:
column_names_str = ",".join(column_names)
insert_sql = f"INSERT INTO {_identifier(schema)}.{_identifier(table)}({column_names_str}) SELECT {column_names_str} FROM {_identifier(temp_table)}"
else:
insert_sql = f"INSERT INTO {_identifier(schema)}.{_identifier(table)} SELECT * FROM {_identifier(temp_table)}"
_logger.debug("Executing insert query:\n%s", insert_sql)
cursor.execute(insert_sql)
_drop_table(cursor=cursor, schema=schema, table=temp_table)
def _validate_parameters(
redshift_types: dict[str, str],
diststyle: str,
distkey: str | None,
sortstyle: str,
sortkey: list[str] | None,
primary_keys: list[str] | None,
) -> None:
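    """Validate diststyle, distkey, sortstyle, sortkey and primary_keys against the resolved columns."""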
if diststyle not in _RS_DISTSTYLES:
raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")
cols = list(redshift_types.keys())
_logger.debug("Redshift columns: %s", cols)
if (diststyle == "KEY") and (not distkey):
raise exceptions.InvalidRedshiftDistkey("You must pass a distkey if you intend to use KEY diststyle")
if distkey and distkey not in cols:
        raise exceptions.InvalidRedshiftDistkey(f"distkey ({distkey}) must be in the columns list: {cols}")
if sortstyle and sortstyle not in _RS_SORTSTYLES:
raise exceptions.InvalidRedshiftSortstyle(f"sortstyle must be in {_RS_SORTSTYLES}")
if sortkey:
if not isinstance(sortkey, list):
            raise exceptions.InvalidRedshiftSortkey(
                f"sortkey must be a list of items in the columns list: {cols}. Current value: {sortkey}"
            )
for key in sortkey:
if key not in cols:
                raise exceptions.InvalidRedshiftSortkey(
                    f"sortkey must be a list of items in the columns list: {cols}. Current value: {key}"
                )
if primary_keys:
if not isinstance(primary_keys, list):
            raise exceptions.InvalidArgumentType(
                "primary_keys should be of type list[str]. "
                f"Current value: {primary_keys} is of type {type(primary_keys)}"
            )
def _redshift_types_from_path(
path: str | list[str],
data_format: Literal["parquet", "orc"],
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
parquet_infer_sampling: float,
path_suffix: str | None,
path_ignore_suffix: str | list[str] | None,
use_threads: bool | int,
boto3_session: boto3.Session | None,
s3_additional_kwargs: dict[str, str] | None,
) -> dict[str, str]:
"""Extract Redshift data types from a Pandas DataFrame."""
_varchar_lengths: dict[str, int] = {} if varchar_lengths is None else varchar_lengths
_logger.debug("Scanning parquet schemas in S3 path: %s", path)
if data_format == "orc":
athena_types, _ = s3.read_orc_metadata(
path=path,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
dataset=False,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
athena_types, _ = s3.read_parquet_metadata(
path=path,
sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
dataset=False,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
_logger.debug("Parquet metadata types: %s", athena_types)
redshift_types: dict[str, str] = {}
for col_name, col_type in athena_types.items():
        length: int = _varchar_lengths.get(col_name, varchar_lengths_default)
redshift_types[col_name] = _data_types.athena2redshift(dtype=col_type, varchar_length=length)
_logger.debug("Converted redshift types: %s", redshift_types)
return redshift_types
def _get_rsh_columns_types(
df: pd.DataFrame | None,
path: str | list[str] | None,
index: bool,
dtype: dict[str, str] | None,
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
data_format: Literal["parquet", "orc", "csv"] = "parquet",
redshift_column_types: dict[str, str] | None = None,
parquet_infer_sampling: float = 1.0,
path_suffix: str | None = None,
path_ignore_suffix: str | list[str] | None = None,
manifest: bool | None = False,
use_threads: bool | int = True,
boto3_session: boto3.Session | None = None,
s3_additional_kwargs: dict[str, str] | None = None,
) -> dict[str, str]:
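    """Resolve the Redshift column types for the load.

    Types are inferred from ``df`` when given, otherwise from the Parquet/ORC
    metadata under ``path`` (expanding a manifest file first, if present).
    Any other file format requires explicit ``redshift_column_types``.
    """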
if df is not None:
redshift_types: dict[str, str] = _data_types.database_types_from_pandas(
df=df,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2redshift,
)
_logger.debug("Converted redshift types from pandas: %s", redshift_types)
elif path is not None:
if manifest:
if not isinstance(path, str):
                raise TypeError(
                    f"type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True; "
                    "must be a string"
                )
path = _get_paths_from_manifest(
path=path,
boto3_session=boto3_session,
)
if data_format in ["parquet", "orc"]:
redshift_types = _redshift_types_from_path(
path=path,
data_format=data_format, # type: ignore[arg-type]
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
else:
if redshift_column_types is None:
raise ValueError(
"redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
)
redshift_types = redshift_column_types
else:
raise ValueError("df and path are None. You MUST pass at least one.")
return redshift_types
def _add_new_table_columns(
cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str]
) -> None:
# Check if Redshift is configured as case sensitive or not
is_case_sensitive = False
if _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier").lower() in [
"on",
"true",
]:
is_case_sensitive = True
    # If it is case-insensitive, convert all the DataFrame column names to lowercase before performing the comparison
if is_case_sensitive is False:
redshift_columns_types = {key.lower(): value for key, value in redshift_columns_types.items()}
actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
new_df_columns = {key: value for key, value in redshift_columns_types.items() if key not in actual_table_columns}
_add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns)
def _create_table( # noqa: PLR0913
df: pd.DataFrame | None,
path: str | list[str] | None,
con: "redshift_connector.Connection",
cursor: "redshift_connector.Cursor",
table: str,
schema: str,
mode: str,
overwrite_method: str,
index: bool,
dtype: dict[str, str] | None,
diststyle: str,
sortstyle: str,
distkey: str | None,
sortkey: list[str] | None,
primary_keys: list[str] | None,
varchar_lengths_default: int,
varchar_lengths: dict[str, int] | None,
data_format: Literal["parquet", "orc", "csv"] = "parquet",
redshift_column_types: dict[str, str] | None = None,
parquet_infer_sampling: float = 1.0,
path_suffix: str | None = None,
path_ignore_suffix: str | list[str] | None = None,
manifest: bool | None = False,
use_threads: bool | int = True,
boto3_session: boto3.Session | None = None,
s3_additional_kwargs: dict[str, str] | None = None,
lock: bool = False,
) -> tuple[str, str | None]:
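    """Create (or prepare) the target table and return ``(table, schema)``.

    ``mode="overwrite"`` empties or recreates an existing table according to
    ``overwrite_method`` (truncate, delete, drop or cascade). ``mode="upsert"``
    against an existing table instead creates a temporary staging table and
    returns ``(temp_table, None)`` so the caller can load into it and merge.
    """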
_logger.debug("Creating table %s with mode %s, and overwrite method %s", table, mode, overwrite_method)
if mode == "overwrite":
if overwrite_method == "truncate":
try:
# Truncate commits current transaction, if successful.
# Fast, but not atomic.
_truncate_table(cursor=cursor, schema=schema, table=table)
except redshift_connector.error.ProgrammingError as e:
                # Tolerate "relation does not exist" (SQLSTATE 42P01); re-raise anything else.
                if e.args[0]["C"] != "42P01":
                    raise
_logger.debug(str(e))
con.rollback()
_begin_transaction(cursor=cursor)
if lock:
_lock(cursor, [table], schema=schema)
elif overwrite_method == "delete":
if _does_table_exist(cursor=cursor, schema=schema, table=table):
if lock:
_lock(cursor, [table], schema=schema)
# Atomic, but slow.
_delete_all(cursor=cursor, schema=schema, table=table)
else:
# Fast, atomic, but either fails if there are any dependent views or, in cascade mode, deletes them.
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
# No point in locking here, the oid will change.
    elif _does_table_exist(cursor=cursor, schema=schema, table=table):
_logger.debug("Table %s exists", table)
if lock:
_lock(cursor, [table], schema=schema)
if mode == "upsert":
guid: str = uuid.uuid4().hex
temp_table: str = f"temp_redshift_{guid}"
sql: str = f"CREATE TEMPORARY TABLE {temp_table} (LIKE {_identifier(schema)}.{_identifier(table)})"
_logger.debug("Executing create temporary table query:\n%s", sql)
cursor.execute(sql)
return temp_table, None
return table, schema
diststyle = diststyle.upper() if diststyle else "AUTO"
sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"
redshift_types = _get_rsh_columns_types(
df=df,
path=path,
index=index,
dtype=dtype,
varchar_lengths_default=varchar_lengths_default,
varchar_lengths=varchar_lengths,
parquet_infer_sampling=parquet_infer_sampling,
path_suffix=path_suffix,
path_ignore_suffix=path_ignore_suffix,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
data_format=data_format,
redshift_column_types=redshift_column_types,
manifest=manifest,
)
_validate_parameters(
redshift_types=redshift_types,
diststyle=diststyle,
distkey=distkey,
sortstyle=sortstyle,
sortkey=sortkey,
primary_keys=primary_keys,
)
    cols_str: str = ",\n".join(f'"{k}" {v}' for k, v in redshift_types.items())
    primary_keys_str: str = (
        ",\nPRIMARY KEY ({})".format(", ".join(f'"{pk}"' for pk in primary_keys)) if primary_keys else ""
    )
distkey_str: str = f"\nDISTKEY({distkey})" if distkey and diststyle == "KEY" else ""
sortkey_str: str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})" if sortkey else ""
sql = (
f"CREATE TABLE IF NOT EXISTS {_identifier(schema)}.{_identifier(table)} (\n"
f"{cols_str}"
f"{primary_keys_str}"
f")\nDISTSTYLE {diststyle}"
f"{distkey_str}"
f"{sortkey_str}"
)
_logger.debug("Executing create table query:\n%s", sql)
cursor.execute(sql)
_logger.info("Created table %s", table)
if lock:
_lock(cursor, [table], schema=schema)
return table, schema