hasher-matcher-actioner/hmalib/common/aws_dataclass.py:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Helpers for converting python dataclasses to/from aws-friendly formats.

There is likely an existing library somewhere that does this much better,
but after spending 5 minutes looking, we just wrote something of our own.

How to use:

    import typing as t
    from hmalib.common.aws_dataclass import HasAWSSerialization
    from dataclasses import dataclass

    @dataclass
    class Item(HasAWSSerialization):
        a: int
        b: str
        c: t.List[str]

    item = Item(5, "5", ["five"])
    more_portable_dict = item.to_aws()
    # Can write to dynamodb, or
    same_item = Item.from_aws(more_portable_dict)
"""

from decimal import Decimal
from dataclasses import dataclass, field, fields, is_dataclass
import json
import typing as t

T = t.TypeVar("T")


class AWSSerializationFailure(ValueError):
    pass


def py_to_aws(py_field: t.Any, in_type: t.Optional[t.Type[T]] = None) -> T:
    """
    Convert a py item into its AWS equivalent.

    Should be the exact inverse of aws_to_py.
    """
    if in_type is None:
        in_type = type(py_field)
    origin = t.get_origin(in_type)
    args = t.get_args(in_type)
    check_type = origin or in_type
    if isinstance(check_type, t.ForwardRef):
        raise AWSSerializationFailure(
            "Serialization error: "
            f"Expected no forward refs, but detected {check_type}. "
            "Rework your dataclasses to avoid forward references."
        )
    if not isinstance(py_field, check_type):
        raise AWSSerializationFailure(
            "Serialization error: "
            f"Expected {check_type} got {type(py_field)} ({py_field!r})"
        )
    if in_type == int:  # N
        # Technically, this also needs to be converted to Decimal,
        # but the boto3 translator seems to handle it fine
        return py_field  # type: ignore # mypy/issues/10003
    if in_type == float:  # N
        # WARNING WARNING
        # floating point is not truly supported in dynamodb
        # We can fake it for numbers without too much precision,
        # but Decimal("3.4") != float(3.4)
        return Decimal(str(py_field))  # type: ignore # mypy/issues/10003
    if in_type == Decimal:  # N
        return py_field  # type: ignore # mypy/issues/10003
    if in_type == str:  # S
        return py_field  # type: ignore # mypy/issues/10003
    if in_type == bool:  # BOOL
        return py_field  # type: ignore # mypy/issues/10003
    if in_type == t.Set[str]:  # SS
        return py_field  # type: ignore # mypy/issues/10003
    if in_type == t.Set[int]:  # SN
        return {i for i in py_field}  # type: ignore # mypy/issues/10003
    if in_type == t.Set[float]:  # SN
        # WARNING WARNING
        # floating point is not truly supported in dynamodb
        # See note above
        return {Decimal(str(s)) for s in py_field}  # type: ignore # mypy/issues/10003
    if origin is list:  # L
        return [py_to_aws(v, args[0]) for v in py_field]  # type: ignore # mypy/issues/10003
    # Various simple collections that don't fit the special cases above
    # can likely be coerced into a list.
    if origin is set:  # L - Special case
        return [py_to_aws(v, args[0]) for v in py_field]  # type: ignore # mypy/issues/10003
    if origin is dict and args[0] is str:  # M
        return {k: py_to_aws(v, args[1]) for k, v in py_field.items()}  # type: ignore # mypy/issues/10003
    if is_dataclass(in_type):
        return {
            field.name: py_to_aws(getattr(py_field, field.name), field.type)
            for field in fields(in_type)
        }  # type: ignore # mypy/issues/10003
    raise AWSSerializationFailure(f"Missing Serialization logic for {in_type!r}")


def aws_to_py(in_type: t.Type[T], aws_field: t.Any) -> T:
    """
    Convert an AWS item back into its py equivalent.

    This might not even be strictly required, but we check that all the types
    are roughly what we expect, and convert Decimals back into ints/floats.
    """
    origin = t.get_origin(in_type)
    args = t.get_args(in_type)
    check_type = origin
    if in_type is float:
        check_type = Decimal
    elif in_type is int:
        check_type = (int, Decimal)
    elif is_dataclass(in_type):
        check_type = dict
    elif check_type is set and args:
        if args[0] not in (str, float, int, Decimal):
            check_type = list
    if not isinstance(aws_field, check_type or in_type):
        # If you are getting random deserialization errors in tests that you
        # did not touch, have a look at
        # https://github.com/facebook/ThreatExchange/issues/697
        raise AWSSerializationFailure(
            "Deserialization error: "
            f"Expected {in_type} got {type(aws_field)} ({aws_field!r})"
        )
    if in_type is int:  # N
        return int(aws_field)  # type: ignore # mypy/issues/10003
    if in_type is float:  # N
        return float(aws_field)  # type: ignore # mypy/issues/10003
    if in_type is Decimal:  # N
        return aws_field  # type: ignore # mypy/issues/10003
    if in_type is str:  # S
        return aws_field  # type: ignore # mypy/issues/10003
    if in_type is bool:  # BOOL
        return aws_field  # type: ignore # mypy/issues/10003
    if in_type is t.Set[str]:  # SS
        return aws_field  # type: ignore # mypy/issues/10003
    if in_type is t.Set[int]:  # SN
        return {int(s) for s in aws_field}  # type: ignore # mypy/issues/10003
    if in_type is t.Set[float]:  # SN
        return {float(s) for s in aws_field}  # type: ignore # mypy/issues/10003
    if origin is set:  # L - special case
        return {aws_to_py(args[0], v) for v in aws_field}  # type: ignore # mypy/issues/10003
    if origin is list:  # L
        return [aws_to_py(args[0], v) for v in aws_field]  # type: ignore # mypy/issues/10003
    # It would be possible to add support for nested dataclasses here, which
    # just become maps with the keys as their attributes.
    # Another option would be adding a new class that adds methods to convert
    # to an AWS-friendly struct and back.
    if origin is dict and args[0] is str:  # M
        # check if value type of map origin is explicitly set
        return {k: aws_to_py(args[1], v) for k, v in aws_field.items()}  # type: ignore # mypy/issues/10003
    if is_dataclass(in_type):
        kwargs = {}
        for field in fields(in_type):
            if not field.init:
                continue
            val = aws_field.get(field.name)
            if val is None:
                continue  # Hopefully missing b/c default or version difference
            kwargs[field.name] = aws_to_py(field.type, val)
        return in_type(**kwargs)  # type: ignore # No idea how to correctly type this
    raise AWSSerializationFailure(f"Missing deserialization logic for {in_type!r}")


class HasAWSSerialization:
    """Convenience mixin to add serialization to a class"""

    def to_aws(self):
        return py_to_aws(self)

    def to_aws_json(self):
        return json.dumps(self.to_aws())

    @classmethod
    def from_aws(cls: t.Type[T], val: t.Dict[str, t.Any]) -> T:
        return aws_to_py(cls, val)

    @classmethod
    def from_aws_json(cls: t.Type[T], val: str) -> T:
        return aws_to_py(cls, json.loads(val))
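
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original module).
# It shows a nested dataclass round-tripping through to_aws()/from_aws() and
# the float -> Decimal conversion noted in py_to_aws. The Inner/Outer names
# are hypothetical and exist only for this demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    @dataclass
    class Inner(HasAWSSerialization):
        name: str
        score: float

    @dataclass
    class Outer(HasAWSSerialization):
        label: str
        parts: t.List[Inner]
        tags: t.Set[str]

    outer = Outer(label="demo", parts=[Inner("a", 0.5)], tags={"x", "y"})

    as_aws = outer.to_aws()
    # Nested dataclasses become plain dicts keyed by field name; floats become
    # Decimal because the boto3 DynamoDB layer does not accept Python floats.
    assert isinstance(as_aws["parts"][0]["score"], Decimal)

    # Deserializing restores the original dataclass instances and floats.
    assert Outer.from_aws(as_aws) == outer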