awswrangler/postgresql.py
# mypy: disable-error-code=name-defined
"""Amazon PostgreSQL Module."""
from __future__ import annotations
import logging
import uuid
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Iterator, Literal, cast, overload
import boto3
import pyarrow as pa
import awswrangler.pandas as pd
from awswrangler import _data_types, _sql_utils, _utils, exceptions
from awswrangler import _databases as _db_utils
from awswrangler._config import apply_configs
if TYPE_CHECKING:
try:
import pg8000
from pg8000 import native as pg8000_native
except ImportError:
pass
else:
pg8000 = _utils.import_optional_dependency("pg8000")
pg8000_native = _utils.import_optional_dependency("pg8000.native")
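# pg8000 is an optional dependency; the public functions below are wrapped with
# _utils.check_optional_dependency so they raise a clear error when it is missing.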
_logger: logging.Logger = logging.getLogger(__name__)
def _identifier(sql: str) -> str:
return _sql_utils.identifier(sql, sql_mode="ansi")
def _validate_connection(con: "pg8000.Connection") -> None:
if not isinstance(con, pg8000.Connection):
raise exceptions.InvalidConnection(
"Invalid 'conn' argument, please pass a "
"pg8000.Connection object. Use pg8000.connect() to use "
"credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog."
)
def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
schema_str = f"{_identifier(schema)}." if schema else ""
cascade_str = "CASCADE" if cascade else "RESTRICT"
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)} {cascade_str}"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)
def _truncate_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
schema_str = f"{_identifier(schema)}." if schema else ""
cascade_str = "CASCADE" if cascade else "RESTRICT"
sql = f"TRUNCATE TABLE {schema_str}{_identifier(table)} {cascade_str}"
_logger.debug("Truncate table query:\n%s", sql)
cursor.execute(sql)
def _does_table_exist(cursor: "pg8000.Cursor", schema: str | None, table: str) -> bool:
schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else ""
cursor.execute(
f"SELECT true WHERE EXISTS ("
f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE "
f"{schema_str} TABLE_NAME = {pg8000_native.literal(table)}"
f");"
)
return len(cursor.fetchall()) > 0
def _create_table(
df: pd.DataFrame,
cursor: "pg8000.Cursor",
table: str,
schema: str,
mode: str,
overwrite_method: _ToSqlOverwriteModeLiteral,
index: bool,
dtype: dict[str, str] | None,
varchar_lengths: dict[str, int] | None,
unique_keys: list[str] | None = None,
) -> None:
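# In overwrite mode the existing table is dropped or truncated first; in any other mode
# an existing table is left untouched and no DDL is issued.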
if mode == "overwrite":
if overwrite_method in ["drop", "cascade"]:
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
elif overwrite_method in ["truncate", "truncate cascade"]:
if _does_table_exist(cursor=cursor, schema=schema, table=table):
_truncate_table(
cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "truncate cascade")
)
else:
raise exceptions.InvalidArgumentValue(f"Invalid overwrite_method: {overwrite_method}")
elif _does_table_exist(cursor=cursor, schema=schema, table=table):
return
postgresql_types: dict[str, str] = _data_types.database_types_from_pandas(
df=df,
index=index,
dtype=dtype,
varchar_lengths_default="TEXT",
varchar_lengths=varchar_lengths,
converter_func=_data_types.pyarrow2postgresql,
)
cols_str: str = "".join([f"{_identifier(k)} {v},\n" for k, v in postgresql_types.items()])[:-2]
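# A UNIQUE constraint over the conflict columns is required for the ON CONFLICT clauses
# generated by to_sql() in upsert/append modes.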
if unique_keys:
cols_str += f",\nUNIQUE ({', '.join([_identifier(k) for k in unique_keys])})"
sql = f"CREATE TABLE IF NOT EXISTS {_identifier(schema)}.{_identifier(table)} (\n{cols_str})"
_logger.debug("Create table query:\n%s", sql)
cursor.execute(sql)
def _iterate_server_side_cursor(
sql: str,
con: "pg8000.Connection",
chunksize: int,
index_col: str | list[str] | None,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None,
safe: bool,
dtype: dict[str, pa.DataType] | None,
timestamp_as_object: bool,
dtype_backend: Literal["numpy_nullable", "pyarrow"],
) -> Iterator[pd.DataFrame]:
"""
Iterate through the results using a server-side cursor.
Note: pg8000 is not fully DB API 2.0 compliant: fetchmany() fetches the entire result set. Using a server-side
cursor allows fetching only a specific number of rows at a time, reducing the memory impact. Ultimately we'd like
pg8000 to add full support for fetchmany() or an SSCursor implementation similar to MySQL's, and then revise this
implementation.
"""
with con.cursor() as cursor:
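# Declare a uniquely named server-side cursor so rows can be pulled in batches with
# FETCH FORWARD instead of materializing the whole result set on the client.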
sscursor_name: str = f"c_{uuid.uuid4().hex}"
cursor_args = _db_utils._convert_params(f"DECLARE {_identifier(sscursor_name)} CURSOR FOR {sql}", params)
cursor.execute(*cursor_args)
try:
while True:
cursor.execute(f"FETCH FORWARD {pg8000_native.literal(chunksize)} FROM {_identifier(sscursor_name)}")
records = cursor.fetchall()
if not records:
break
yield _db_utils._records2df(
records=records,
cols_names=_db_utils._get_cols_names(cursor.description),
index=index_col,
safe=safe,
dtype=dtype,
timestamp_as_object=timestamp_as_object,
dtype_backend=dtype_backend,
)
finally:
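# Always release the server-side cursor, even if the consumer stops iterating early.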
cursor.execute(f"CLOSE {_identifier(sscursor_name)}")
@_utils.check_optional_dependency(pg8000, "pg8000")
def connect(
connection: str | None = None,
secret_id: str | None = None,
catalog_id: str | None = None,
dbname: str | None = None,
boto3_session: boto3.Session | None = None,
ssl_context: bool | SSLContext | None = None,
timeout: int | None = None,
tcp_keepalive: bool = True,
) -> "pg8000.Connection":
"""Return a pg8000 connection from a Glue Catalog Connection.
https://github.com/tlocke/pg8000
Note
----
You MUST pass a `connection` OR `secret_id`.
Here is an example of the secret structure in Secrets Manager:
{
"host":"postgresql-instance-wrangler.dr8vkeyrb9m1.us-east-1.rds.amazonaws.com",
"username":"test",
"password":"test",
"engine":"postgresql",
"port":"3306",
"dbname": "mydb" # Optional
}
Parameters
----------
connection
Glue Catalog Connection name.
secret_id
Specifies the secret containing the connection details that you want to retrieve.
You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
catalog_id
The ID of the Data Catalog.
If none is provided, the AWS account ID is used by default.
dbname
Optional database name to overwrite the stored one.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
ssl_context
This governs SSL encryption for TCP/IP sockets.
This parameter is forwarded to pg8000.
https://github.com/tlocke/pg8000#functions
timeout
This is the time in seconds before the connection to the server will time out.
The default is ``None``, which means no timeout.
This parameter is forwarded to pg8000.
https://github.com/tlocke/pg8000#functions
tcp_keepalive
If ``True`` then use TCP keepalive. The default is ``True``.
This parameter is forwarded to pg8000.
https://github.com/tlocke/pg8000#functions
Returns
-------
pg8000 connection.
Examples
--------
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
... with con.cursor() as cursor:
... cursor.execute("SELECT 1")
... print(cursor.fetchall())
"""
attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
)
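# Accept both "postgresql" and "postgres" as engine identifiers.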
if attrs.kind not in ("postgresql", "postgres"):
raise exceptions.InvalidDatabaseType(
f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)"
)
return pg8000.connect(
user=attrs.user,
database=attrs.database,
password=attrs.password,
port=attrs.port,
host=attrs.host,
ssl_context=ssl_context,
timeout=timeout,
tcp_keepalive=tcp_keepalive,
)
@overload
def read_sql_query(
sql: str,
con: "pg8000.Connection",
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: None = ...,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> pd.DataFrame: ...
@overload
def read_sql_query(
sql: str,
con: "pg8000.Connection",
*,
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: int,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_query(
sql: str,
con: "pg8000.Connection",
*,
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: int | None,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
@_utils.check_optional_dependency(pg8000, "pg8000")
def read_sql_query(
sql: str,
con: "pg8000.Connection",
index_col: str | list[str] | None = None,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
chunksize: int | None = None,
dtype: dict[str, pa.DataType] | None = None,
safe: bool = True,
timestamp_as_object: bool = False,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> pd.DataFrame | Iterator[pd.DataFrame]:
"""Return a DataFrame corresponding to the result set of the query string.
Parameters
----------
sql
SQL query.
con
Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
index_col
Column(s) to set as index (MultiIndex).
params
List of parameters to pass to execute method.
The syntax used to pass parameters is database driver dependent.
Check your database driver documentation for which of the five syntax styles,
described in PEP 249’s paramstyle, is supported.
chunksize
If specified, return an iterator where chunksize is the number of rows to include in each chunk.
dtype
Specifying the datatype for columns.
The keys should be the column names and the values should be the PyArrow types.
safe
Check for overflows or other unsafe data type conversions.
timestamp_as_object
Cast non-nanosecond timestamps (np.datetime64) to objects.
dtype_backend
Which dtype backend to use. When "numpy_nullable" is set, nullable dtypes are used for all
dtypes that have a nullable implementation; when "pyarrow" is set, pyarrow-backed dtypes are
used for all columns.
The dtype backends are still experimental. The "pyarrow" backend is only supported with pandas 2.0 or above.
Returns
-------
Result as Pandas DataFrame(s).
Examples
--------
Reading from PostgreSQL using a Glue Catalog Connection
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
... df = wr.postgresql.read_sql_query(
... sql="SELECT * FROM public.my_table",
... con=con,
... )
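Reading from PostgreSQL in chunks (illustrative example; the chunk size is arbitrary)
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
...     for chunk in wr.postgresql.read_sql_query(
...         sql="SELECT * FROM public.my_table",
...         con=con,
...         chunksize=1000,
...     ):
...         print(len(chunk))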
"""
_validate_connection(con=con)
if chunksize is not None:
return _iterate_server_side_cursor(
sql=sql,
con=con,
chunksize=chunksize,
index_col=index_col,
params=params,
safe=safe,
dtype=dtype,
timestamp_as_object=timestamp_as_object,
dtype_backend=dtype_backend,
)
return _db_utils.read_sql_query(
sql=sql,
con=con,
index_col=index_col,
params=params,
chunksize=None,
dtype=dtype,
safe=safe,
timestamp_as_object=timestamp_as_object,
dtype_backend=dtype_backend,
)
@overload
def read_sql_table(
table: str,
con: "pg8000.Connection",
schema: str | None = ...,
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: None = ...,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> pd.DataFrame: ...
@overload
def read_sql_table(
table: str,
con: "pg8000.Connection",
*,
schema: str | None = ...,
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: int,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_table(
table: str,
con: "pg8000.Connection",
*,
schema: str | None = ...,
index_col: str | list[str] | None = ...,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = ...,
chunksize: int | None,
dtype: dict[str, pa.DataType] | None = ...,
safe: bool = ...,
timestamp_as_object: bool = ...,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
@_utils.check_optional_dependency(pg8000, "pg8000")
def read_sql_table(
table: str,
con: "pg8000.Connection",
schema: str | None = None,
index_col: str | list[str] | None = None,
params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
chunksize: int | None = None,
dtype: dict[str, pa.DataType] | None = None,
safe: bool = True,
timestamp_as_object: bool = False,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> pd.DataFrame | Iterator[pd.DataFrame]:
"""Return a DataFrame corresponding the table.
Parameters
----------
table
Table name.
con
Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
schema
Name of SQL schema in database to query (if database flavor supports this).
Uses default schema if None (default).
index_col
Column(s) to set as index (MultiIndex).
params
List of parameters to pass to execute method.
The syntax used to pass parameters is database driver dependent.
Check your database driver documentation for which of the five syntax styles,
described in PEP 249’s paramstyle, is supported.
chunksize
If specified, return an iterator where chunksize is the number of rows to include in each chunk.
dtype
Specifying the datatype for columns.
The keys should be the column names and the values should be the PyArrow types.
safe
Check for overflows or other unsafe data type conversions.
timestamp_as_object
Cast non-nanosecond timestamps (np.datetime64) to objects.
dtype_backend
Which dtype backend to use. When "numpy_nullable" is set, nullable dtypes are used for all
dtypes that have a nullable implementation; when "pyarrow" is set, pyarrow-backed dtypes are
used for all columns.
The dtype backends are still experimental. The "pyarrow" backend is only supported with pandas 2.0 or above.
Returns
-------
Result as Pandas DataFrame(s).
Examples
--------
Reading from PostgreSQL using a Glue Catalog Connection
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
... df = wr.postgresql.read_sql_table(
... table="my_table",
... schema="public",
... con=con,
... )
"""
sql: str = (
f"SELECT * FROM {_identifier(table)}"
if schema is None
else f"SELECT * FROM {_identifier(schema)}.{_identifier(table)}"
)
return read_sql_query(
sql=sql,
con=con,
index_col=index_col,
params=params,
chunksize=chunksize,
dtype=dtype,
safe=safe,
timestamp_as_object=timestamp_as_object,
dtype_backend=dtype_backend,
)
_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "truncate cascade"]
@_utils.check_optional_dependency(pg8000, "pg8000")
@apply_configs
def to_sql(
df: pd.DataFrame,
con: "pg8000.Connection",
table: str,
schema: str,
mode: _ToSqlModeLiteral = "append",
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
index: bool = False,
dtype: dict[str, str] | None = None,
varchar_lengths: dict[str, int] | None = None,
use_column_names: bool = False,
chunksize: int = 200,
upsert_conflict_columns: list[str] | None = None,
insert_conflict_columns: list[str] | None = None,
commit_transaction: bool = True,
) -> None:
"""Write records stored in a DataFrame into PostgreSQL.
Parameters
----------
df
`Pandas DataFrame <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html>`_
con
Use ``pg8000.connect()`` to use credentials directly or ``wr.postgresql.connect()`` to fetch it from the Glue Catalog.
table
Table name
schema
Schema name
mode
Append, overwrite or upsert.
- append: Inserts new records into table.
- overwrite: Drops table and recreates.
- upsert: Perform an upsert which checks for conflicts on columns given by ``upsert_conflict_columns`` and
sets the new values on conflicts. Note that ``upsert_conflict_columns`` is required for this mode.
overwrite_method
Drop, cascade, truncate, or truncate cascade. Only applicable in overwrite mode.
- "drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
- "cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
- "truncate" - ``TRUNCATE ... RESTRICT`` - truncates the table.
Fails if any of the tables have foreign-key references from tables that are not listed in the command.
- "truncate cascade" - ``TRUNCATE ... CASCADE`` - truncates the table, and all tables that have
foreign-key references to any of the named tables.
index
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
dtype
Dictionary of column names and PostgreSQL types to be cast.
Useful when you have columns with undetermined or mixed data types.
(e.g. ``{'col name': 'TEXT', 'col2 name': 'FLOAT'}``)
varchar_lengths
Dict of VARCHAR length by columns. (e.g. ``{"col1": 10, "col5": 200}``).
use_column_names
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
E.g. if the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
inserted into the database columns `col1` and `col3`.
chunksize
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
upsert_conflict_columns
This parameter is only supported if `mode` is set to `upsert`. In this case conflicts for the given columns are
checked for evaluating the upsert.
insert_conflict_columns
This parameter is only supported if `mode` is set to `append`. In this case conflicts for the given columns are
checked for evaluating the insert 'ON CONFLICT DO NOTHING'.
commit_transaction
Whether to commit the transaction. True by default.
Examples
--------
Writing to PostgreSQL using a Glue Catalog Connection
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
... wr.postgresql.to_sql(
... df=df,
... table="my_table",
... schema="public",
... con=con
... )
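Upserting rows (illustrative example; assumes conflicts are identified by an ``id`` column)
>>> import awswrangler as wr
>>> with wr.postgresql.connect("MY_GLUE_CONNECTION") as con:
...     wr.postgresql.to_sql(
...         df=df,
...         table="my_table",
...         schema="public",
...         con=con,
...         mode="upsert",
...         upsert_conflict_columns=["id"],
...     )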
"""
if df.empty:
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
mode = cast(_ToSqlModeLiteral, mode.strip().lower())
allowed_modes = ["append", "overwrite", "upsert"]
_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
if mode == "upsert" and not upsert_conflict_columns:
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> needs to be set when using upsert mode.")
_validate_connection(con=con)
try:
with con.cursor() as cursor:
_create_table(
df=df,
cursor=cursor,
table=table,
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
index=index,
dtype=dtype,
varchar_lengths=varchar_lengths,
unique_keys=upsert_conflict_columns or insert_conflict_columns,
)
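# Promote the index to regular columns so it is inserted like any other column.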
if index:
df.reset_index(level=df.index.names, inplace=True)
column_placeholders: str = ", ".join(["%s"] * len(df.columns))
column_names = [_identifier(column) for column in df.columns]
insertion_columns = ""
upsert_str = ""
if use_column_names:
insertion_columns = f"({', '.join(column_names)})"
if mode == "upsert":
upsert_columns = ", ".join(f"{column}=EXCLUDED.{column}" for column in column_names)
conflict_columns = ", ".join(upsert_conflict_columns) # type: ignore[arg-type]
upsert_str = f" ON CONFLICT ({conflict_columns}) DO UPDATE SET {upsert_columns}"
if mode == "append" and insert_conflict_columns:
conflict_columns = ", ".join(insert_conflict_columns)
upsert_str = f" ON CONFLICT ({conflict_columns}) DO NOTHING"
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
)
for placeholders, parameters in placeholder_parameter_pair_generator:
sql: str = f"INSERT INTO {_identifier(schema)}.{_identifier(table)} {insertion_columns} VALUES {placeholders}{upsert_str}"
_logger.debug("sql: %s", sql)
cursor.executemany(sql, (parameters,))
if commit_transaction:
con.commit()
except Exception as ex:
con.rollback()
_logger.error(ex)
raise