pyre_extensions/safe_json.py (84 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import json
import sys
from typing import Any, Dict, List, Type, TypeVar, cast, IO, Union
if sys.version_info[:2] >= (3, 9):
# pyre-fixme[21]: Could not find name `_TypedDictMeta` in `typing`.
from typing import _TypedDictMeta
else:
# pyre-fixme[21]: Could not find name `_TypedDictMeta` in `typing_extensions`.
from typing_extensions import _TypedDictMeta
from typing_inspect import get_origin, is_optional_type
if sys.version_info[:2] < (3, 7):
from typing_inspect import get_last_args as get_args
else:
from typing_inspect import get_args as get_args
class InvalidJson(json.JSONDecodeError):
    """Raised when input cannot be parsed as JSON or does not match the target type.

    Subclasses `json.JSONDecodeError` (and therefore `ValueError`), so callers
    that already catch JSON decoding errors also catch validation failures.
    Normally only a message is supplied; the document defaults to "" and the
    position to 0.
    """

    def __init__(self, message: str, doc: str = "", pos: int = 0) -> None:
        # `doc` and `pos` are optional so that `JSONDecodeError.__reduce__`,
        # which re-invokes the class as `cls(msg, doc, pos)`, can rebuild the
        # exception when it is unpickled (e.g. across process boundaries).
        super().__init__(message, doc, pos)
def _is_primitive(target_type: Type[object]) -> bool:
    """Tell whether `target_type` is one of the scalar types `int`, `float`, `str`, `bool`."""
    scalar_types = (int, float, str, bool)
    return target_type in scalar_types
def _is_list(target_type: Type[object]) -> bool:
    """Tell whether the origin of `target_type` (as reported by `get_origin`) is `List`/`list`, e.g. `List[int]`."""
    origin = get_origin(target_type)
    return origin in (List, list)
def _is_dictionary(target_type: Type[object]) -> bool:
    """Tell whether the origin of `target_type` (as reported by `get_origin`) is `Dict`/`dict`, e.g. `Dict[str, int]`."""
    origin = get_origin(target_type)
    return origin in (Dict, dict)
def _is_typed_dictionary(target_type: Type[object]) -> bool:
    """Tell whether `target_type` is an instance of the private `_TypedDictMeta` metaclass.

    `_TypedDictMeta` is imported at the top of this module, from `typing` on
    Python 3.9+ and from `typing_extensions` otherwise.
    """
    typed_dict_metaclass = _TypedDictMeta  # pyre-ignore: private API.
    return isinstance(target_type, typed_dict_metaclass)
def _validate_list(value: object, target_type: Type[List[object]]) -> None:
    """Check that `value` is a list and validate each item against the element type.

    The element type is the single type argument of `target_type` (e.g. `int`
    for `List[int]`); each item is checked recursively with `_validate_value`.

    Raises:
        InvalidJson: if `value` is not a list, or an item fails validation.
    """
    if isinstance(value, list):
        [element_type] = get_args(target_type)
        for item in value:
            _validate_value(item, element_type)
        return
    raise InvalidJson(f"`{value}` is not a list")
def _validate_dictionary(
    value: object, target_type: Type[Dict[object, object]]
) -> None:
    """Check that `value` is a dict whose keys and values match `target_type`.

    The key and value types are the two type arguments of `target_type`
    (e.g. `str` and `int` for `Dict[str, int]`); every entry is checked
    recursively with `_validate_value`.

    Raises:
        InvalidJson: if `value` is not a dict, or any key or value fails validation.
    """
    if not isinstance(value, dict):
        raise InvalidJson(f"`{value}` is not a dictionary")
    key_type, value_type = get_args(target_type)
    # Distinct loop names keep the `value` parameter (the dictionary being
    # validated) from being rebound to each entry while iterating.
    for entry_key, entry_value in value.items():
        _validate_value(entry_key, key_type)
        _validate_value(entry_value, value_type)
def _validate_typed_dictionary(value: object, target_type: Type[object]) -> None:
    """Check that `value` is a dict matching the TypedDict class `target_type`.

    Every declared key that the TypedDict marks as required must be present.
    Keys that are not required (e.g. those of a `total=False` TypedDict) are
    validated only when present. Present keys are checked recursively against
    their annotated type with `_validate_value`.

    Raises:
        InvalidJson: if `value` is not a dict, a required key is missing, or a
            present key's value fails validation.
    """
    if not isinstance(value, dict):
        raise InvalidJson(f"`{value}` is not a dictionary")
    annotations = target_type.__annotations__
    required_keys = getattr(target_type, "__required_keys__", None)
    if required_keys is None:
        # TypedDict implementations predating `__required_keys__` (added in
        # Python 3.9) only expose `__total__`; a total TypedDict requires every key.
        required_keys = (
            annotations.keys() if getattr(target_type, "__total__", True) else ()
        )
    for key, value_type in annotations.items():
        if key not in value:
            if key in required_keys:
                raise InvalidJson(
                    f"{value} of TypedDict {target_type} must contain key {key}"
                )
            continue
        _validate_value(value[key], value_type)
def _validate_value(value: object, target_type: Type[object]) -> None:
    """Recursively check that a decoded JSON `value` conforms to `target_type`.

    Handled targets: `Any` (accepts anything), list types, dict types,
    TypedDict classes, unions that include `None` (such as `Optional[X]`), and
    the primitives `int`, `float`, `str` and `bool`.

    Raises:
        InvalidJson: if `value` does not conform, or `target_type` is not a
            supported type.
    """
    if target_type is Any:
        return
    elif _is_list(target_type):
        _validate_list(value, cast(Type[List[object]], target_type))
    elif _is_dictionary(target_type):
        _validate_dictionary(value, cast(Type[Dict[object, object]], target_type))
    elif _is_typed_dictionary(target_type):
        _validate_typed_dictionary(value, target_type)
    elif is_optional_type(target_type):
        _validate_optional(value, target_type)
    else:
        _validate_primitive(value, target_type)


def _validate_optional(value: object, target_type: Type[object]) -> None:
    """Validate `value` against a union type that includes `None`."""
    if value is None:
        return
    # `None` may appear at any position among the union members (e.g.
    # `Union[None, int]`), and there may be several non-None members (e.g.
    # `Union[int, str, None]`), so consider every non-None member.
    alternatives = [
        argument for argument in get_args(target_type) if argument is not type(None)
    ]
    if len(alternatives) == 1:
        # Plain `Optional[X]`: delegate directly so nested error messages are kept.
        _validate_value(value, alternatives[0])
        return
    for alternative in alternatives:
        try:
            _validate_value(value, alternative)
        except InvalidJson:
            continue
        return
    raise InvalidJson(f"`{value}` is not a {target_type}")


def _validate_primitive(value: object, target_type: Type[object]) -> None:
    """Validate `value` against one of the primitive types `int`, `float`, `str`, `bool`."""
    if not _is_primitive(target_type):
        raise InvalidJson(f"Invalid value type {target_type}")
    if isinstance(value, bool):
        # `bool` subclasses `int`, but JSON `true`/`false` are not numbers.
        matches = target_type is bool
    elif target_type is float:
        # JSON does not distinguish integers from reals: `1` decodes to an
        # `int`, which is an acceptable `float` under the numeric tower.
        matches = isinstance(value, (int, float))
    else:
        matches = isinstance(value, target_type)
    if not matches:
        raise InvalidJson(f"`{value}` is not a {target_type}")
# Type variable linking the `target` type argument to the returned value.
T = TypeVar("T")


def load(
    input_: Union[IO[str], IO[bytes]],
    target: Type[T],
    *,
    validate: bool = True,
) -> T:
    """Read all of `input_` via `.read()` and decode the result with `loads`.

    `target` and `validate` are forwarded unchanged; see `loads` for the
    parsing, validation and error behavior.
    """
    contents = input_.read()
    return loads(contents, target, validate=validate)
def loads(input_: Union[str, bytes], target: Type[T], *, validate: bool = True) -> T:
    """Parse a JSON document and, optionally, validate it against `target`.

    Args:
        input_: the JSON document as `str` or `bytes` (passed to `json.loads`).
        target: the type the parsed document is expected to conform to.
        validate: when False, the parsed value is returned without validation.

    Returns:
        The parsed value.

    Raises:
        InvalidJson: if parsing or validation fails. Any other exception raised
            while parsing or validating is converted to `InvalidJson`, with the
            original exception chained as its cause.
    """
    try:
        parsed = json.loads(input_)
        if validate:
            _validate_value(parsed, target)
    except InvalidJson:
        # Validation errors are already `InvalidJson`; re-wrapping them would
        # append a second "line 1 column 1 (char 0)" suffix to the message.
        raise
    except Exception as exception:
        raise InvalidJson(str(exception)) from exception
    return parsed
def validate(input: object, target: Type[T]) -> T:
    """Check an already-decoded value against `target` and return it typed as `T`.

    The value itself is returned unchanged; only its static type is narrowed.
    The parameter name `input` shadows the builtin but is kept for callers that
    pass it by keyword.

    Raises:
        InvalidJson: if `_validate_value` rejects `input` for `target`.
    """
    _validate_value(input, target)
    checked = cast(T, input)
    return checked