# awswrangler/dynamodb/_utils.py
"""Amazon DynamoDB Utils Module (PRIVATE)."""
from __future__ import annotations
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterator, Mapping, TypedDict
import boto3
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError
from typing_extensions import NotRequired, Required
from awswrangler import _utils, exceptions
from awswrangler._config import apply_configs
from awswrangler.annotations import Deprecated
if TYPE_CHECKING:
from mypy_boto3_dynamodb.client import DynamoDBClient
from mypy_boto3_dynamodb.service_resource import Table
from mypy_boto3_dynamodb.type_defs import (
AttributeValueTypeDef,
ExecuteStatementOutputTypeDef,
KeySchemaElementTypeDef,
TableAttributeValueTypeDef,
WriteRequestOutputTypeDef,
)
_logger: logging.Logger = logging.getLogger(__name__)
@apply_configs
@Deprecated
def get_table(
    table_name: str,
    boto3_session: boto3.Session | None = None,
) -> "Table":
    """Return the boto3 DynamoDB.Table object for a given table name.

    Parameters
    ----------
    table_name
        Name of the Amazon DynamoDB table.
    boto3_session
        The default boto3 session will be used if **boto3_session** is ``None``.

    Returns
    -------
    Boto3 DynamoDB.Table object.
    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table
    """
    resource = _utils.resource(service_name="dynamodb", session=boto3_session)
    return resource.Table(table_name)
def _serialize_item(
item: Mapping[str, "TableAttributeValueTypeDef"], serializer: TypeSerializer | None = None
) -> dict[str, "AttributeValueTypeDef"]:
serializer = serializer if serializer else TypeSerializer()
return {k: serializer.serialize(v) for k, v in item.items()}
def _deserialize_item(
item: Mapping[str, "AttributeValueTypeDef"], deserializer: TypeDeserializer | None = None
) -> dict[str, "TableAttributeValueTypeDef"]:
deserializer = deserializer if deserializer else TypeDeserializer()
return {k: deserializer.deserialize(v) for k, v in item.items()}
class _ReadExecuteStatementKwargs(TypedDict):
    """Keyword arguments forwarded to the DynamoDB ``execute_statement`` API call."""

    # The PartiQL statement to run.
    Statement: Required[str]
    # Whether the read should be strongly consistent.
    ConsistentRead: Required[bool]
    # Serialized PartiQL parameters; present only when the caller supplied any.
    Parameters: NotRequired[list["AttributeValueTypeDef"]]
    # Pagination token from a previous response; set while paging through results.
    NextToken: NotRequired[str]
def _execute_statement(
kwargs: _ReadExecuteStatementKwargs,
dynamodb_client: "DynamoDBClient",
) -> "ExecuteStatementOutputTypeDef":
try:
response = dynamodb_client.execute_statement(**kwargs)
except ClientError as err:
if err.response["Error"]["Code"] == "ResourceNotFoundException":
_logger.error("Couldn't execute PartiQL: '%s' because the table does not exist.", kwargs["Statement"])
else:
_logger.error(
"Couldn't execute PartiQL: '%s'. %s: %s",
kwargs["Statement"],
err.response["Error"]["Code"],
err.response["Error"]["Message"],
)
raise
return response
def _read_execute_statement(
    kwargs: _ReadExecuteStatementKwargs,
    dynamodb_client: "DynamoDBClient",
) -> Iterator[list[dict[str, Any]]]:
    """Yield deserialized pages of items for a PartiQL read, following pagination tokens."""
    deserializer = TypeDeserializer()
    while True:
        response = _execute_statement(kwargs=kwargs, dynamodb_client=dynamodb_client)
        yield [_deserialize_item(raw_item, deserializer) for raw_item in response["Items"]]
        token = response.get("NextToken")
        if not token:
            return
        # Carry the token into the next request to fetch the following page.
        kwargs["NextToken"] = token
def execute_statement(
    statement: str,
    parameters: list[Any] | None = None,
    consistent_read: bool = False,
    boto3_session: boto3.Session | None = None,
) -> Iterator[list[dict[str, Any]]] | None:
    """Run a PartiQL statement against a DynamoDB table.

    Parameters
    ----------
    statement
        The PartiQL statement to run.
    parameters
        A list of PartiQL parameters, applied to the statement in the order given.
    consistent_read
        The consistency of a read operation. If `True`, then a strongly consistent read is used. False by default.
    boto3_session
        Boto3 Session. If None, the default boto3 Session is used.

    Returns
    -------
    An iterator over the pages of items returned by the statement, or ``None``
    for statements that do not read data.

    Examples
    --------
    Insert an item

    >>> import awswrangler as wr
    >>> wr.dynamodb.execute_statement(
    ...     statement="INSERT INTO movies VALUE {'title': ?, 'year': ?, 'info': ?}",
    ...     parameters=[title, year, {"plot": plot, "rating": rating}],
    ... )

    Select items

    >>> wr.dynamodb.execute_statement(
    ...     statement="SELECT * FROM movies WHERE title=? AND year=?",
    ...     parameters=[title, year],
    ... )

    Update items

    >>> wr.dynamodb.execute_statement(
    ...     statement="UPDATE movies SET info.rating=? WHERE title=? AND year=?",
    ...     parameters=[rating, title, year],
    ... )

    Delete items

    >>> wr.dynamodb.execute_statement(
    ...     statement="DELETE FROM movies WHERE title=? AND year=?",
    ...     parameters=[title, year],
    ... )
    """
    request: _ReadExecuteStatementKwargs = {
        "Statement": statement,
        "ConsistentRead": consistent_read,
    }
    if parameters:
        type_serializer = TypeSerializer()
        request["Parameters"] = [type_serializer.serialize(param) for param in parameters]
    client = _utils.client(service_name="dynamodb", session=boto3_session)

    # Only SELECT statements produce result pages; everything else is
    # fire-and-forget from the caller's perspective.
    if statement.strip().upper().startswith("SELECT"):
        return _read_execute_statement(kwargs=request, dynamodb_client=client)
    _execute_statement(kwargs=request, dynamodb_client=client)
    return None
def _validate_items(
items: list[dict[str, Any]] | list[Mapping[str, Any]], key_schema: list["KeySchemaElementTypeDef"]
) -> None:
"""
Validate if all items have the required keys for the Amazon DynamoDB table.
Parameters
----------
items: Union[List[Dict[str, Any]], List[Mapping[str, Any]]]
List which contains the items that will be validated.
key_schema: List[KeySchemaElementTableTypeDef]
The primary key structure for the table.
Each element consists of the attribute name and it's type (HASH or RANGE).
"""
table_keys = [schema["AttributeName"] for schema in key_schema]
if not all(key in item for item in items for key in table_keys):
raise exceptions.InvalidArgumentValue("All items need to contain the required keys for the table.")
# Based on https://github.com/boto/boto3/blob/fcc24f39cc0a923fa578587fcd1f781e820488a1/boto3/dynamodb/table.py#L63
class _TableBatchWriter:
"""Automatically handle batch writes to DynamoDB for a single table."""
def __init__(
self,
table_name: str,
client: "DynamoDBClient",
flush_amount: int = 25,
overwrite_by_pkeys: list[str] | None = None,
):
self._table_name = table_name
self._client = client
self._items_buffer: list["WriteRequestOutputTypeDef"] = []
self._flush_amount = flush_amount
self._overwrite_by_pkeys = overwrite_by_pkeys
def put_item(self, item: dict[str, "AttributeValueTypeDef"]) -> None:
"""
Add a new put item request to the batch.
Parameters
----------
item: Dict[str, AttributeValueTypeDef]
The item to add.
"""
self._add_request_and_process({"PutRequest": {"Item": item}})
def delete_item(self, key: dict[str, "AttributeValueTypeDef"]) -> None:
"""
Add a new delete request to the batch.
Parameters
----------
key: Dict[str, AttributeValueTypeDef]
The key of the item to delete.
"""
self._add_request_and_process({"DeleteRequest": {"Key": key}})
def _add_request_and_process(self, request: "WriteRequestOutputTypeDef") -> None:
if self._overwrite_by_pkeys:
self._remove_dup_pkeys_request_if_any(request, self._overwrite_by_pkeys)
self._items_buffer.append(request)
self._flush_if_needed()
def _remove_dup_pkeys_request_if_any(
self, request: "WriteRequestOutputTypeDef", overwrite_by_pkeys: list[str]
) -> None:
pkey_values_new = self._extract_pkey_values(request, overwrite_by_pkeys)
for item in self._items_buffer:
if self._extract_pkey_values(item, overwrite_by_pkeys) == pkey_values_new:
self._items_buffer.remove(item)
_logger.debug(
"With overwrite_by_pkeys enabled, skipping request:%s",
item,
)
def _extract_pkey_values(
self, request: "WriteRequestOutputTypeDef", overwrite_by_pkeys: list[str]
) -> list[Any] | None:
if request.get("PutRequest"):
return [request["PutRequest"]["Item"][key] for key in overwrite_by_pkeys]
if request.get("DeleteRequest"):
return [request["DeleteRequest"]["Key"][key] for key in overwrite_by_pkeys]
return None
def _flush_if_needed(self) -> None:
if len(self._items_buffer) >= self._flush_amount:
self._flush()
def _flush(self) -> None:
items_to_send = self._items_buffer[: self._flush_amount]
self._items_buffer = self._items_buffer[self._flush_amount :]
response = self._client.batch_write_item(RequestItems={self._table_name: items_to_send})
unprocessed_items = response["UnprocessedItems"]
if not unprocessed_items:
unprocessed_items = {}
item_list = unprocessed_items.get(self._table_name, [])
# Any unprocessed_items are immediately added to the
# next batch we send.
self._items_buffer.extend(item_list)
_logger.debug(
"Batch write sent %s, unprocessed: %s",
len(items_to_send),
len(self._items_buffer),
)
def __enter__(self) -> "_TableBatchWriter":
return self
def __exit__(
self,
exception_type: type[BaseException] | None,
exception_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
# When we exit, we need to keep flushing whatever's left
# until there's nothing left in our items buffer.
while self._items_buffer:
self._flush()
return None