awswrangler/_sql_formatter.py (130 lines of code) (raw):
"""Formatting logic for SQL parameters."""
from __future__ import annotations

import datetime
import decimal
import math
import re
from abc import ABC, abstractmethod
from typing import Any, Callable, Sequence

from typing_extensions import Literal

from awswrangler import exceptions
# Engine identifiers accepted by ``_create_engine`` and ``_process_sql_params``.
_EngineTypeLiteral = Literal["presto", "hive", "partiql"]
class _Engine(ABC):
    """Render Python values as SQL literal text for a particular query engine.

    :meth:`format` dispatches on the value's type to one of the ``format_*`` hooks.
    Subclasses must implement :meth:`format_string` and may override any other hook
    to match their engine's literal syntax.
    """

    def __init__(self, engine_name: _EngineTypeLiteral) -> None:
        # Identifier of the engine this formatter targets.
        self.engine_name = engine_name

    def format_null(self, value: None = None) -> str:
        """Return the SQL NULL literal."""
        return "NULL"

    @abstractmethod
    def format_string(self, value: str) -> str:
        """Return ``value`` as a quoted, escaped string literal (engine-specific)."""

    def format_bool(self, value: bool) -> str:
        """Return ``TRUE`` or ``FALSE``."""
        return str(value).upper()

    def format_integer(self, value: int) -> str:
        """Return ``value`` as a plain integer literal."""
        return str(value)

    def format_float(self, value: float) -> str:
        """Render a float as a fixed-point numeric literal without losing precision.

        The six-decimal ``:f`` rendering is kept whenever it represents ``value``
        exactly, so output is unchanged for such values. Otherwise the shortest
        round-tripping representation is expanded to fixed point. For example,
        ``1e-10`` becomes ``0.0000000001`` instead of the lossy ``0.000000``.

        Raises:
            TypeError: If ``value`` is NaN or infinite, which have no numeric literal form.
        """
        if not math.isfinite(value):
            raise TypeError(f"Supports only finite float values, got {value}.")
        text = f"{value:f}"
        if float(text) == value:
            return text
        # repr() of a builtin float is the shortest string that round-trips. Decimal's
        # "f" format then expands any exponent, so the literal stays in plain notation.
        # float() first so float subclasses with a custom repr are handled too.
        return f"{decimal.Decimal(repr(float(value))):f}"

    def format_decimal(self, value: decimal.Decimal) -> str:
        """Return a ``DECIMAL '<fixed-point>'`` literal."""
        return f"DECIMAL '{value:f}'"

    def format_timestamp(self, value: datetime.datetime) -> str:
        """Return a ``TIMESTAMP`` literal with millisecond precision.

        Digits beyond milliseconds are truncated by ``isoformat(timespec='milliseconds')``.

        Raises:
            TypeError: If ``value`` is timezone-aware.
        """
        if value.tzinfo is not None:
            raise TypeError(f"Supports only timezone naive datatype, got {value}.")
        return f"TIMESTAMP '{value.isoformat(sep=' ', timespec='milliseconds')}'"

    def format_date(self, value: datetime.date) -> str:
        """Return a ``DATE 'YYYY-MM-DD'`` literal."""
        return f"DATE '{value.isoformat()}'"

    def format_array(self, value: Sequence[Any]) -> str:
        """Return an ``ARRAY [...]`` literal, formatting each element through :meth:`format`."""
        return f"ARRAY [{', '.join(map(self.format, value))}]"

    def format_dict(self, value: dict[Any, Any]) -> str:
        """Return a ``MAP(ARRAY [keys], ARRAY [values])`` literal (``MAP()`` when empty).

        Raises:
            TypeError: If a key is ``None`` or a key is not an instance of the first
                key's type.
        """
        if not value:
            return "MAP()"
        map_keys = list(value.keys())
        key_type = type(map_keys[0])
        for key in map_keys:
            if key is None:
                raise TypeError("Map key cannot be null.")
            if not isinstance(key, key_type):
                raise TypeError("All Map key elements must be the same type.")
        map_values = list(value.values())
        keys_sql = ", ".join(map(self.format, map_keys))
        values_sql = ", ".join(map(self.format, map_values))
        return f"MAP(ARRAY [{keys_sql}], ARRAY [{values_sql}])"

    def format(self, data: Any) -> str:
        """Format ``data`` as an SQL literal by dispatching on its type.

        Raises:
            TypeError: If no formatter matches the type of ``data``.
        """
        # Dict order is significant. isinstance() matches subclasses, so bool must be
        # tried before int, and datetime.datetime before datetime.date.
        formats_dict: dict[type[Any], Callable[[Any], str]] = {
            bool: self.format_bool,
            str: self.format_string,
            int: self.format_integer,
            datetime.datetime: self.format_timestamp,
            datetime.date: self.format_date,
            decimal.Decimal: self.format_decimal,
            float: self.format_float,
            list: self.format_array,
            tuple: self.format_array,
            set: self.format_array,
            dict: self.format_dict,
        }
        if data is None:
            return self.format_null()
        for python_type, format_func in formats_dict.items():
            if isinstance(data, python_type):
                return format_func(data)
        raise TypeError(f"Unsupported type {type(data)} in parameter.")
class _PrestoEngine(_Engine):
    """Formatter for the ``"presto"`` engine type."""

    def __init__(self) -> None:
        super().__init__(engine_name="presto")

    def format_string(self, value: str) -> str:
        """Quote ``value`` as a string literal, escaping single quotes by doubling them."""
        escaped = value.replace("'", "''")
        return "'" + escaped + "'"
class _HiveEngine(_Engine):
    """Formatter for the ``"hive"`` engine type.

    String literals use backslash escapes for backslash, single quote, carriage
    return, newline and tab.
    """

    # Per-character escape table; translate() makes one pass, so the backslashes it
    # inserts are never re-escaped (same result as escaping "\\" first, then the rest).
    _STRING_ESCAPES = str.maketrans(
        {
            "\\": "\\\\",
            "'": "\\'",
            "\r": "\\r",
            "\n": "\\n",
            "\t": "\\t",
        }
    )

    def __init__(self) -> None:
        super().__init__(engine_name="hive")

    def format_string(self, value: str) -> str:
        """Quote ``value`` as a single-quoted, backslash-escaped string literal."""
        return f"'{value.translate(self._STRING_ESCAPES)}'"
class _PartiQLEngine(_Engine):
    """Formatter for the ``"partiql"`` engine type.

    Differs from the base formatter in several ways:

    - ``null`` is lowercase.
    - Booleans render as ``1``/``0``.
    - Decimals, timestamps and dates are emitted as plain quoted strings.
    - Arrays and maps are not supported.
    """

    def __init__(self) -> None:
        super().__init__(engine_name="partiql")

    def format_null(self, value: None = None) -> str:
        """Return the lowercase ``null`` literal."""
        return "null"

    def format_string(self, value: str) -> str:
        """Quote ``value``, escaping single quotes by doubling them."""
        doubled = value.replace("'", "''")
        return "'" + doubled + "'"

    def format_bool(self, value: bool) -> str:
        """Return ``"1"`` for truthy values and ``"0"`` otherwise."""
        return ("0", "1")[bool(value)]

    def format_decimal(self, value: decimal.Decimal) -> str:
        """Return the decimal's text wrapped in single quotes."""
        return f"'{value}'"

    def format_timestamp(self, value: datetime.datetime) -> str:
        """Return the ISO-8601 timestamp wrapped in single quotes.

        Raises:
            TypeError: If ``value`` is timezone-aware.
        """
        if value.tzinfo is None:
            return f"'{value.isoformat()}'"
        raise TypeError(f"Supports only timezone naive datatype, got {value}.")

    def format_date(self, value: datetime.date) -> str:
        """Return the ISO-8601 date wrapped in single quotes."""
        iso_date = value.isoformat()
        return f"'{iso_date}'"

    def format_array(self, value: Sequence[Any]) -> str:
        """Arrays are unsupported for this engine; always raises NotImplementedError."""
        raise NotImplementedError(f"format_array not implemented for engine={self.engine_name}.")

    def format_dict(self, value: dict[Any, Any]) -> str:
        """Maps are unsupported for this engine; always raises NotImplementedError."""
        raise NotImplementedError(f"format_dict not implemented for engine={self.engine_name}.")
def _format_parameters(params: dict[str, Any], engine: _Engine) -> dict[str, Any]:
    """Return a new dict mapping each parameter name to ``engine.format`` of its value.

    ``params`` itself is not modified.
    """
    return {name: engine.format(data=value) for name, value in params.items()}
# Matches ``:name`` placeholders made of letters, digits and underscores; group(1) is
# the parameter name. NOTE(review): colons inside string literals or ``::`` casts
# (e.g. ``'12:30'``) also match. Such matches are only substituted when a parameter
# with that name is supplied.
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")
def _create_engine(engine_type: _EngineTypeLiteral) -> _Engine:
    """Instantiate the formatter registered for ``engine_type``.

    Raises:
        exceptions.InvalidArgumentValue: If ``engine_type`` is not a known engine.
    """
    candidates = (
        ("hive", _HiveEngine),
        ("presto", _PrestoEngine),
        ("partiql", _PartiQLEngine),
    )
    for name, engine_cls in candidates:
        if engine_type == name:
            return engine_cls()
    raise exceptions.InvalidArgumentValue(f"Unknown engine type: {engine_type}")
def _process_sql_params(sql: str, params: dict[str, Any] | None, engine_type: _EngineTypeLiteral = "presto") -> str:
    """Substitute ``:name`` placeholders in ``sql`` with formatted literal values.

    Args:
        sql: Query text that may contain ``:name`` placeholders.
        params: Mapping from placeholder name to Python value; ``None`` means no parameters.
        engine_type: Engine whose literal syntax is used to format the values.

    Returns:
        The query text in which every placeholder whose name is a key of ``params``
        is replaced by the formatted value. Other placeholders are left unchanged.
    """
    supplied = {} if params is None else params
    engine = _create_engine(engine_type)
    formatted = _format_parameters(supplied, engine=engine)

    def _substitute(match: re.Match) -> str:  # type: ignore[type-arg]
        name = match.group(1)
        if name in formatted:
            return str(formatted[name])
        # Unknown placeholder: keep the original text verbatim.
        return str(match.group(0))

    return _PATTERN.sub(_substitute, sql)