"""Amazon Timestream Module."""
from __future__ import annotations
import itertools
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, cast
import boto3
from botocore.config import Config
import awswrangler.pandas as pd
from awswrangler import _data_types, _utils, exceptions, s3
from awswrangler._config import apply_configs
from awswrangler._distributed import engine
from awswrangler._executor import _BaseExecutor, _get_executor
from awswrangler.distributed.ray import ray_get
from awswrangler.typing import TimestreamBatchLoadReportS3Configuration
if TYPE_CHECKING:
from mypy_boto3_timestream_write.client import TimestreamWriteClient
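# Batch load task statuses at which `wait_batch_load_task` stops polling.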
_BATCH_LOAD_FINAL_STATES: list[str] = ["SUCCEEDED", "FAILED", "PROGRESS_STOPPED", "PENDING_RESUME"]
_BATCH_LOAD_WAIT_POLLING_DELAY: float = 2 # SECONDS
_TIME_UNITS_MAPPING = {
"SECONDS": (9, 0),
"MILLISECONDS": (6, 3),
"MICROSECONDS": (3, 6),
"NANOSECONDS": (0, 9),
}
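# Each pair holds base-10 exponents used by `_format_timestamp` below: an integer
# timestamp is treated as epoch nanoseconds and divided by 10**pair[0], while a
# datetime (epoch seconds) is multiplied by 10**pair[1]. For example, with
# "MILLISECONDS" -> (6, 3):
#   _format_timestamp(1_577_836_800_000_000_000, "MILLISECONDS") -> "1577836800000"
#   _format_timestamp(datetime(2020, 1, 1, tzinfo=timezone.utc), "MILLISECONDS") -> "1577836800000"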
_logger: logging.Logger = logging.getLogger(__name__)
_TimeUnitLiteral = Literal["MILLISECONDS", "SECONDS", "MICROSECONDS", "NANOSECONDS"]
def _df2list(df: pd.DataFrame) -> list[list[Any]]:
"""Extract Parameters."""
parameters: list[list[Any]] = df.values.tolist()
for i, row in enumerate(parameters):
for j, value in enumerate(row):
if pd.isna(value):
parameters[i][j] = None
elif hasattr(value, "to_pydatetime"):
parameters[i][j] = value.to_pydatetime()
return parameters
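# Example: a row (pd.Timestamp("2020-01-01"), "foo", np.nan) comes back as
# [datetime(2020, 1, 1, 0, 0), "foo", None].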
def _check_time_unit(time_unit: _TimeUnitLiteral) -> str:
time_unit = time_unit if time_unit else "MILLISECONDS"
    if time_unit not in _TIME_UNITS_MAPPING:
        raise exceptions.InvalidArgumentValue(
            f"Invalid time unit: {time_unit}. Must be one of {list(_TIME_UNITS_MAPPING)}."
        )
return time_unit
def _format_timestamp(timestamp: int | datetime, time_unit: _TimeUnitLiteral) -> str:
if isinstance(timestamp, int):
return str(round(timestamp / pow(10, _TIME_UNITS_MAPPING[time_unit][0])))
if isinstance(timestamp, datetime):
return str(round(timestamp.timestamp() * pow(10, _TIME_UNITS_MAPPING[time_unit][1])))
raise exceptions.InvalidArgumentType("`time_col` must be of type timestamp.")
def _format_measure(
measure_name: str, measure_value: Any, measure_type: str, time_unit: _TimeUnitLiteral
) -> dict[str, str]:
return {
"Name": measure_name,
"Value": _format_timestamp(measure_value, time_unit) if measure_type == "TIMESTAMP" else str(measure_value),
"Type": measure_type,
}
def _sanitize_common_attributes(
common_attributes: dict[str, Any] | None,
version: int,
time_unit: _TimeUnitLiteral,
measure_name: str | None,
) -> dict[str, Any]:
    # Copy to avoid mutating the caller's dictionary
    common_attributes = dict(common_attributes) if common_attributes else {}
# Values in common_attributes take precedence
common_attributes.setdefault("Version", version)
common_attributes.setdefault("TimeUnit", _check_time_unit(common_attributes.get("TimeUnit", time_unit)))
if "Time" not in common_attributes and common_attributes["TimeUnit"] == "NANOSECONDS":
raise exceptions.InvalidArgumentValue("Python datetime objects do not support nanoseconds precision.")
if "MeasureValue" in common_attributes and "MeasureValueType" not in common_attributes:
raise exceptions.InvalidArgumentCombination(
"MeasureValueType must be supplied alongside MeasureValue in common_attributes."
)
if measure_name:
common_attributes.setdefault("MeasureName", measure_name)
elif "MeasureName" not in common_attributes:
raise exceptions.InvalidArgumentCombination(
"MeasureName must be supplied with the `measure_name` argument or in common_attributes."
)
return common_attributes
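# Illustrative (not executed here): caller-supplied keys win and defaults fill the rest, e.g.
#   _sanitize_common_attributes({"MeasureName": "cpu"}, version=1, time_unit="MILLISECONDS", measure_name=None)
#   -> {"MeasureName": "cpu", "Version": 1, "TimeUnit": "MILLISECONDS"}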
@engine.dispatch_on_engine
def _write_batch(
timestream_client: "TimestreamWriteClient" | None,
database: str,
table: str,
common_attributes: dict[str, Any],
cols_names: list[str | None],
measure_cols: list[str | None],
measure_types: list[str],
dimensions_cols: list[str | None],
batch: list[Any],
) -> list[dict[str, str]]:
client_timestream = timestream_client if timestream_client else _utils.client(service_name="timestream-write")
records: list[dict[str, Any]] = []
    scalar = len(measure_cols) == 1 and "MeasureValues" not in common_attributes
time_loc = 0
measure_cols_loc = 1 if cols_names[0] else 0
dimensions_cols_loc = 1 if len(measure_cols) == 1 else 1 + len(measure_cols)
if all(cols_names):
# Time and Measures are supplied in the data frame
dimensions_cols_loc = 1 + len(measure_cols)
elif all(v is None for v in cols_names[:2]):
# Time and Measures are supplied in common_attributes
dimensions_cols_loc = 0
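    # Example: cols_names=["time", "m1", "m2", "dim0"] gives time_loc=0,
    # measure_cols_loc=1 and dimensions_cols_loc=3, so a row [t, v1, v2, d0]
    # splits into time=t, measures=(v1, v2) and dimensions=(d0,).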
time_unit = common_attributes["TimeUnit"]
for row in batch:
record: dict[str, Any] = {}
if "Time" not in common_attributes:
record["Time"] = _format_timestamp(row[time_loc], time_unit)
if scalar and "MeasureValue" not in common_attributes:
measure_value = row[measure_cols_loc]
if pd.isnull(measure_value):
continue
record["MeasureValue"] = str(measure_value)
elif not scalar and "MeasureValues" not in common_attributes:
record["MeasureValues"] = [
_format_measure(measure_name, measure_value, measure_value_type, time_unit) # type: ignore[arg-type]
for measure_name, measure_value, measure_value_type in zip(
measure_cols, row[measure_cols_loc:dimensions_cols_loc], measure_types
)
if not pd.isnull(measure_value)
]
if len(record["MeasureValues"]) == 0:
continue
if "MeasureValueType" not in common_attributes:
record["MeasureValueType"] = measure_types[0] if scalar else "MULTI"
# Dimensions can be specified in both common_attributes and the data frame
dimensions = (
[
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
for name, value in zip(dimensions_cols, row[dimensions_cols_loc:])
]
if all(dimensions_cols)
else []
)
if dimensions:
record["Dimensions"] = dimensions
if record:
records.append(record)
try:
if records:
_utils.try_it(
f=client_timestream.write_records,
ex=(
client_timestream.exceptions.ThrottlingException,
client_timestream.exceptions.InternalServerException,
),
max_num_tries=5,
DatabaseName=database,
TableName=table,
CommonAttributes=common_attributes,
Records=records,
)
except client_timestream.exceptions.RejectedRecordsException as ex:
return cast(List[Dict[str, str]], ex.response["RejectedRecords"]) # type: ignore[typeddict-item]
return []
@engine.dispatch_on_engine
def _write_df(
df: pd.DataFrame,
executor: _BaseExecutor,
database: str,
table: str,
common_attributes: dict[str, Any],
cols_names: list[str | None],
measure_cols: list[str | None],
measure_types: list[str],
dimensions_cols: list[str | None],
boto3_session: boto3.Session | None,
) -> list[dict[str, str]]:
timestream_client = _utils.client(
service_name="timestream-write",
session=boto3_session,
botocore_config=Config(read_timeout=20, max_pool_connections=5000, retries={"max_attempts": 10}),
)
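    # The Timestream WriteRecords API accepts at most 100 records per call, hence the batch size.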
batches: list[list[Any]] = _utils.chunkify(lst=_df2list(df=df), max_length=100)
_logger.debug("Writing %d batches of data", len(batches))
return executor.map(
_write_batch, # type: ignore[arg-type]
timestream_client,
itertools.repeat(database),
itertools.repeat(table),
itertools.repeat(common_attributes),
itertools.repeat(cols_names),
itertools.repeat(measure_cols),
itertools.repeat(measure_types),
itertools.repeat(dimensions_cols),
batches,
)
@_utils.validate_distributed_kwargs(
unsupported_kwargs=["boto3_session"],
)
def write(
df: pd.DataFrame,
database: str,
table: str,
time_col: str | None = None,
measure_col: str | list[str | None] | None = None,
dimensions_cols: list[str | None] | None = None,
version: int = 1,
time_unit: _TimeUnitLiteral = "MILLISECONDS",
use_threads: bool | int = True,
measure_name: str | None = None,
common_attributes: dict[str, Any] | None = None,
boto3_session: boto3.Session | None = None,
) -> list[dict[str, str]]:
"""Store a Pandas DataFrame into an Amazon Timestream table.

    Note
    ----
    When ``use_threads=True``, the number of threads is obtained from ``os.cpu_count()``.

    If the Timestream service rejects one or more records,
    this function does not raise a Python exception.
    Instead, it returns the rejection information.

    Note
    ----
    If the ``time_col`` column is supplied, it must be of type timestamp. ``time_unit`` is set to MILLISECONDS by default.
    NANOSECONDS is not supported, as Python datetime objects are limited to microsecond precision.

Parameters
----------
df
Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
database
Amazon Timestream database name.
table
Amazon Timestream table name.
time_col
DataFrame column name to be used as time. MUST be a timestamp column.
measure_col
DataFrame column name(s) to be used as measure.
dimensions_cols
List of DataFrame column names to be used as dimensions.
version
Version number used for upserts.
Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
time_unit
Time unit for the time column. MILLISECONDS by default.
use_threads
True to enable concurrent writing, False to disable multiple threads.
If enabled, os.cpu_count() is used as the number of threads.
If integer is provided, specified number is used.
measure_name
Name that represents the data attribute of the time series.
Overrides ``measure_col`` if specified.
common_attributes
Dictionary of attributes shared across all records in the request.
Using common attributes can optimize the cost of writes by reducing the size of request payloads.
Values in ``common_attributes`` take precedence over all other arguments and data frame values.
Dimension attributes are merged with attributes in record objects.
Example: ``{"Dimensions": [{"Name": "device_id", "Value": "12345"}], "MeasureValueType": "DOUBLE"}``.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
Rejected records.
Possible reasons for rejection are described here:
https://docs.aws.amazon.com/timestream/latest/developerguide/API_RejectedRecord.html
Examples
--------
    Store a Pandas DataFrame into an Amazon Timestream table.

    >>> import awswrangler as wr
    >>> import pandas as pd
    >>> from datetime import datetime
    >>> df = pd.DataFrame(
    >>>     {
    >>>         "time": [datetime.now(), datetime.now(), datetime.now()],
    >>>         "dim0": ["foo", "boo", "bar"],
    >>>         "dim1": [1, 2, 3],
    >>>         "measure": [1.0, 1.1, 1.2],
    >>>     }
    >>> )
    >>> rejected_records = wr.timestream.write(
    >>>     df=df,
    >>>     database="sampleDB",
    >>>     table="sampleTable",
    >>>     time_col="time",
    >>>     measure_col="measure",
    >>>     dimensions_cols=["dim0", "dim1"],
    >>> )
    >>> assert len(rejected_records) == 0
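
    Write with ``common_attributes`` to share a dimension across all records
    (a minimal sketch of the precedence rules described above; the dimension
    name and value are illustrative):

    >>> rejected_records = wr.timestream.write(
    >>>     df=df,
    >>>     database="sampleDB",
    >>>     table="sampleTable",
    >>>     time_col="time",
    >>>     measure_col="measure",
    >>>     common_attributes={"Dimensions": [{"Name": "device_id", "Value": "12345"}]},
    >>> )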

    If some records are rejected, the returned list has the following shape:

    >>> [
    >>>     {
    >>>         'ExistingVersion': 2,
    >>>         'Reason': 'The record version 1 is lower than the existing version 2. A '
    >>>                   'higher version is required to update the measure value.',
    >>>         'RecordIndex': 0
    >>>     }
    >>> ]
"""
measure_cols = measure_col if isinstance(measure_col, list) else [measure_col]
measure_types: list[str] = (
_data_types.timestream_type_from_pandas(df.loc[:, measure_cols]) if all(measure_cols) else []
)
dimensions_cols = dimensions_cols if dimensions_cols else [dimensions_cols] # type: ignore[list-item]
cols_names: list[str | None] = [time_col] + measure_cols + dimensions_cols
measure_name = measure_name if measure_name else measure_cols[0]
common_attributes = _sanitize_common_attributes(common_attributes, version, time_unit, measure_name)
_logger.debug(
"Writing to Timestream table %s in database %s\ncommon_attributes: %s\n, cols_names: %s\n, measure_types: %s",
table,
database,
common_attributes,
cols_names,
measure_types,
)
# User can supply arguments in one of two ways:
# 1. With the `common_attributes` dictionary which takes precedence
# 2. With data frame columns
# However, the data frame cannot be completely empty.
# So if all values in `cols_names` are None, an exception is raised.
if any(cols_names):
dfs = _utils.split_pandas_frame(
df.loc[:, [c for c in cols_names if c]], _utils.ensure_cpu_count(use_threads=use_threads)
)
else:
raise exceptions.InvalidArgumentCombination(
"At least one of `time_col`, `measure_col` or `dimensions_cols` must be specified."
)
_logger.debug("Writing %d dataframes to Timestream table", len(dfs))
executor: _BaseExecutor = _get_executor(use_threads=use_threads)
errors = list(
itertools.chain(
*ray_get(
[
_write_df(
df=df,
executor=executor,
database=database,
table=table,
common_attributes=common_attributes,
cols_names=cols_names,
measure_cols=measure_cols,
measure_types=measure_types,
dimensions_cols=dimensions_cols,
boto3_session=boto3_session,
)
for df in dfs
]
)
)
)
return list(itertools.chain(*ray_get(errors)))
@apply_configs
def wait_batch_load_task(
task_id: str,
timestream_batch_load_wait_polling_delay: float = _BATCH_LOAD_WAIT_POLLING_DELAY,
boto3_session: boto3.Session | None = None,
) -> dict[str, Any]:
"""
Wait for the Timestream batch load task to complete.
Parameters
----------
task_id
The ID of the batch load task.
timestream_batch_load_wait_polling_delay
Time to wait between two polling attempts.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
Dictionary with the describe_batch_load_task response.
Examples
--------
>>> import awswrangler as wr
>>> res = wr.timestream.wait_batch_load_task(task_id='task-id')
Raises
------
exceptions.TimestreamLoadError
        Raised when the task does not complete successfully; carries the task's error message.
"""
timestream_client = _utils.client(service_name="timestream-write", session=boto3_session)
response = timestream_client.describe_batch_load_task(TaskId=task_id)
status = response["BatchLoadTaskDescription"]["TaskStatus"]
while status not in _BATCH_LOAD_FINAL_STATES:
time.sleep(timestream_batch_load_wait_polling_delay)
response = timestream_client.describe_batch_load_task(TaskId=task_id)
status = response["BatchLoadTaskDescription"]["TaskStatus"]
_logger.debug("Task status: %s", status)
if status != "SUCCEEDED":
_logger.debug("Task response: %s", response)
raise exceptions.TimestreamLoadError(response.get("ErrorMessage"))
return response # type: ignore[return-value]
@apply_configs
@_utils.validate_distributed_kwargs(
unsupported_kwargs=["boto3_session", "s3_additional_kwargs"],
)
def batch_load(
df: pd.DataFrame,
path: str,
database: str,
table: str,
time_col: str,
dimensions_cols: list[str],
measure_cols: list[str],
measure_name_col: str,
report_s3_configuration: TimestreamBatchLoadReportS3Configuration,
time_unit: _TimeUnitLiteral = "MILLISECONDS",
record_version: int = 1,
timestream_batch_load_wait_polling_delay: float = _BATCH_LOAD_WAIT_POLLING_DELAY,
keep_files: bool = False,
use_threads: bool | int = True,
boto3_session: boto3.Session | None = None,
s3_additional_kwargs: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Batch load a Pandas DataFrame into a Amazon Timestream table.
Note
----
The supplied column names (time, dimension, measure) MUST match those in the Timestream table.
Note
----
Only ``MultiMeasureMappings`` is supported.
See https://docs.aws.amazon.com/timestream/latest/developerguide/batch-load-data-model-mappings.html
Parameters
----------
df
Pandas DataFrame.
path
S3 prefix to write the data.
database
Amazon Timestream database name.
table
Amazon Timestream table name.
time_col
Column name with the time data. It must be a long data type that represents the time since the Unix epoch.
dimensions_cols
List of column names with the dimensions data.
measure_cols
List of column names with the measure data.
measure_name_col
Column name with the measure name.
report_s3_configuration
Dictionary of the configuration for the S3 bucket where the error report is stored.
https://docs.aws.amazon.com/timestream/latest/developerguide/API_ReportS3Configuration.html
Example: {"BucketName": 'error-report-bucket-name'}
time_unit
Time unit for the time column. MILLISECONDS by default.
record_version
Record version.
timestream_batch_load_wait_polling_delay
Time to wait between two polling attempts.
keep_files
Whether to keep the files after the operation.
use_threads
True to enable concurrent requests, False to disable multiple threads.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
s3_additional_kwargs
Forwarded to S3 botocore requests.
Returns
-------
A dictionary of the batch load task response.
Examples
--------
    >>> import awswrangler as wr
    >>> response = wr.timestream.batch_load(
    >>>     df=df,
    >>>     path='s3://bucket/path/',
    >>>     database='sample_db',
    >>>     table='sample_table',
    >>>     time_col='time',
    >>>     dimensions_cols=['region', 'location'],
    >>>     measure_cols=['memory_utilization', 'cpu_utilization'],
    >>>     measure_name_col='measure_name',
    >>>     report_s3_configuration={'BucketName': 'error-report-bucket-name'},
    >>> )
"""
path = path if path.endswith("/") else f"{path}/"
if s3.list_objects(path=path, boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs):
raise exceptions.InvalidArgument(
f"The received S3 path ({path}) is not empty. "
"Please, provide a different path or use wr.s3.delete_objects() to clean up the current one."
)
columns = [time_col, *dimensions_cols, *measure_cols, measure_name_col]
try:
s3.to_csv(
df=df.loc[:, columns],
path=path,
index=False,
dataset=True,
mode="append",
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
measure_types: list[str] = _data_types.timestream_type_from_pandas(df.loc[:, measure_cols])
return batch_load_from_files(
path=path,
database=database,
table=table,
time_col=time_col,
dimensions_cols=dimensions_cols,
measure_cols=measure_cols,
measure_types=measure_types,
report_s3_configuration=report_s3_configuration,
time_unit=time_unit,
measure_name_col=measure_name_col,
record_version=record_version,
timestream_batch_load_wait_polling_delay=timestream_batch_load_wait_polling_delay,
boto3_session=boto3_session,
)
finally:
if not keep_files:
_logger.debug("Deleting objects in S3 path: %s", path)
s3.delete_objects(
path=path,
use_threads=use_threads,
boto3_session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs,
)
@apply_configs
def batch_load_from_files(
path: str,
database: str,
table: str,
time_col: str,
dimensions_cols: list[str],
measure_cols: list[str],
measure_types: list[str],
measure_name_col: str,
report_s3_configuration: TimestreamBatchLoadReportS3Configuration,
time_unit: _TimeUnitLiteral = "MILLISECONDS",
record_version: int = 1,
data_source_csv_configuration: dict[str, str | bool] | None = None,
timestream_batch_load_wait_polling_delay: float = _BATCH_LOAD_WAIT_POLLING_DELAY,
boto3_session: boto3.Session | None = None,
) -> dict[str, Any]:
"""Batch load files from S3 into a Amazon Timestream table.
Note
----
The supplied column names (time, dimension, measure) MUST match those in the Timestream table.
Note
----
Only ``MultiMeasureMappings`` is supported.
See https://docs.aws.amazon.com/timestream/latest/developerguide/batch-load-data-model-mappings.html
Parameters
----------
path
S3 prefix to write the data.
database
Amazon Timestream database name.
table
Amazon Timestream table name.
time_col
Column name with the time data. It must be a long data type that represents the time since the Unix epoch.
dimensions_cols
List of column names with the dimensions data.
    measure_cols
        List of column names with the measure data.
    measure_types
        List of Timestream measure value types matching ``measure_cols``, e.g. ``["DOUBLE", "VARCHAR"]``.
    measure_name_col
        Column name with the measure name.
report_s3_configuration
Dictionary of the configuration for the S3 bucket where the error report is stored.
https://docs.aws.amazon.com/timestream/latest/developerguide/API_ReportS3Configuration.html
Example: {"BucketName": 'error-report-bucket-name'}
time_unit
Time unit for the time column. MILLISECONDS by default.
record_version
Record version.
    data_source_csv_configuration
        Dictionary of the data source CSV configuration, e.g. ``{"ColumnSeparator": ","}``.
        https://docs.aws.amazon.com/timestream/latest/developerguide/API_CsvConfiguration.html
timestream_batch_load_wait_polling_delay
Time to wait between two polling attempts.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
A dictionary of the batch load task response.
Examples
--------
    >>> import awswrangler as wr
    >>> response = wr.timestream.batch_load_from_files(
    >>>     path='s3://bucket/path/',
    >>>     database='sample_db',
    >>>     table='sample_table',
    >>>     time_col='time',
    >>>     dimensions_cols=['region', 'location'],
    >>>     measure_cols=['memory_utilization', 'cpu_utilization'],
    >>>     measure_types=['DOUBLE', 'DOUBLE'],
    >>>     measure_name_col='measure_name',
    >>>     report_s3_configuration={'BucketName': 'error-report-bucket-name'},
    >>> )
"""
timestream_client = _utils.client(service_name="timestream-write", session=boto3_session)
bucket, prefix = _utils.parse_path(path=path)
kwargs: dict[str, Any] = {
"TargetDatabaseName": database,
"TargetTableName": table,
"DataModelConfiguration": {
"DataModel": {
"TimeColumn": time_col,
"TimeUnit": _check_time_unit(time_unit),
"DimensionMappings": [{"SourceColumn": c} for c in dimensions_cols],
"MeasureNameColumn": measure_name_col,
"MultiMeasureMappings": {
"MultiMeasureAttributeMappings": [
{"SourceColumn": c, "MeasureValueType": t} for c, t in zip(measure_cols, measure_types)
],
},
}
},
"DataSourceConfiguration": {
"DataSourceS3Configuration": {"BucketName": bucket, "ObjectKeyPrefix": prefix},
"DataFormat": "CSV",
"CsvConfiguration": data_source_csv_configuration if data_source_csv_configuration else {},
},
"ReportConfiguration": {"ReportS3Configuration": report_s3_configuration},
"RecordVersion": record_version,
}
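    # create_batch_load_task is asynchronous; block below until the task reaches a final state.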
task_id = timestream_client.create_batch_load_task(**kwargs)["TaskId"]
return wait_batch_load_task(
task_id=task_id,
timestream_batch_load_wait_polling_delay=timestream_batch_load_wait_polling_delay,
boto3_session=boto3_session,
)