pyre_extensions/safe_json.py (84 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
import sys
from typing import Any, Dict, List, Tuple, Type, TypeVar, cast, IO, Union

if sys.version_info[:2] >= (3, 9):
    # pyre-fixme[21]: Could not find name `_TypedDictMeta` in `typing`.
    from typing import _TypedDictMeta
else:
    # pyre-fixme[21]: Could not find name `_TypedDictMeta` in `typing_extensions`.
    from typing_extensions import _TypedDictMeta

from typing_inspect import get_origin, is_optional_type

if sys.version_info[:2] < (3, 7):
    from typing_inspect import get_last_args as get_args
else:
    from typing_inspect import get_args as get_args


class InvalidJson(json.JSONDecodeError):
    """Error raised when JSON cannot be parsed or does not match a target type.

    Subclasses `json.JSONDecodeError`, but is built from a message alone: the
    document is passed to the base class as `""` and the position as `0`.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message, "", 0)

    def __reduce__(self) -> Tuple[Type["InvalidJson"], Tuple[str]]:
        # `JSONDecodeError.__reduce__` rebuilds the exception by calling the
        # class with `(msg, doc, pos)`, which does not match the one-argument
        # `__init__` above and breaks pickling/copying. Rebuild from the
        # message alone instead.
        return self.__class__, (self.msg,)


def _is_primitive(target_type: Type[object]) -> bool:
    """Return True if `target_type` is one of `int`, `float`, `str`, `bool`."""
    return target_type in (int, float, str, bool)


def _is_list(target_type: Type[object]) -> bool:
    """Return True if the origin of `target_type` is `List` or `list`."""
    return get_origin(target_type) in (List, list)


def _is_dictionary(target_type: Type[object]) -> bool:
    """Return True if the origin of `target_type` is `Dict` or `dict`."""
    return get_origin(target_type) in (Dict, dict)


def _is_typed_dictionary(target_type: Type[object]) -> bool:
    """Return True if `target_type` is an instance of `_TypedDictMeta`."""
    return isinstance(target_type, _TypedDictMeta)  # pyre-ignore: private API.
def _validate_list(value: object, target_type: Type[List[object]]) -> None:
    """Check that `value` is a list whose elements all match the element type.

    Raises:
        InvalidJson: if `value` is not a list, or an element fails validation.
    """
    if not isinstance(value, list):
        raise InvalidJson(f"`{value}` is not a list")
    # A list type carries exactly one type argument; unpacking enforces that.
    (element_type,) = get_args(target_type)
    for element in value:
        _validate_value(element, element_type)


def _validate_dictionary(
    value: object, target_type: Type[Dict[object, object]]
) -> None:
    """Check that `value` is a dict whose keys and values match `target_type`.

    Raises:
        InvalidJson: if `value` is not a dict, or a key or value fails
            validation.
    """
    if not isinstance(value, dict):
        raise InvalidJson(f"`{value}` is not a dictionary")
    key_type, value_type = get_args(target_type)
    # The loop variable is deliberately not called `value`, so it does not
    # shadow the parameter being iterated.
    for key, item in value.items():
        _validate_value(key, key_type)
        _validate_value(item, value_type)


def _validate_typed_dictionary(value: object, target_type: Type[object]) -> None:
    """Check that `value` is a dict matching the annotations of a TypedDict.

    Every annotated key that is present is validated against its annotation.
    A missing key is an error only when the TypedDict lists it in
    `__required_keys__`; if that attribute is absent, every annotated key is
    treated as required. Keys in `value` that are not annotated are ignored.

    Raises:
        InvalidJson: if `value` is not a dict, a required key is missing, or
            a present value fails validation.
    """
    if not isinstance(value, dict):
        raise InvalidJson(f"`{value}` is not a dictionary")
    annotations = target_type.__annotations__
    required_keys = getattr(target_type, "__required_keys__", annotations.keys())
    for key, value_type in annotations.items():
        if key not in value:
            if key in required_keys:
                raise InvalidJson(
                    f"{value} of TypedDict {target_type} must contain key {key}"
                )
            # Non-required key (e.g. `total=False`) that is simply absent.
            continue
        _validate_value(value[key], value_type)


def _validate_optional(value: object, target_type: Type[object]) -> None:
    """Check `value` against an optional type (a union that admits `None`).

    `None` is always accepted. Any other value must validate against at least
    one of the non-`None` members of `target_type`.

    Raises:
        InvalidJson: if `value` matches no member. The error from the first
            member is the one raised.
    """
    if value is None:
        return
    none_type = type(None)
    members = [member for member in get_args(target_type) if member is not none_type]
    if not members:
        # `target_type` admits nothing but `None`.
        raise InvalidJson(f"`{value}` is not a {target_type}")
    errors: List[InvalidJson] = []
    for member in members:
        try:
            _validate_value(value, member)
            return
        except InvalidJson as error:
            errors.append(error)
    raise errors[0]


def _validate_value(value: object, target_type: Type[object]) -> None:
    """Check that `value` conforms to `target_type`, recursing into containers.

    Handles `Any`, list types, dictionary types, TypedDicts, optional types,
    and the primitives `int`, `float`, `str` and `bool`, tested in that order.

    Raises:
        InvalidJson: if `target_type` is not one of the handled kinds, or
            `value` does not conform to it.
    """
    if target_type is Any:
        return
    if _is_list(target_type):
        _validate_list(value, cast(Type[List[object]], target_type))
        return
    if _is_dictionary(target_type):
        _validate_dictionary(value, cast(Type[Dict[object, object]], target_type))
        return
    if _is_typed_dictionary(target_type):
        _validate_typed_dictionary(value, target_type)
        return
    if is_optional_type(target_type):
        _validate_optional(value, target_type)
        return
    if not _is_primitive(target_type):
        raise InvalidJson(f"Invalid value type {target_type}")
    # NOTE: this is a plain `isinstance` check, so a `bool` value satisfies an
    # `int` target (bool subclasses int), while an `int` value does not
    # satisfy a `float` target.
    if not isinstance(value, target_type):
        raise InvalidJson(f"`{value}` is not a {target_type}")


T = TypeVar("T")


def load(
    input_: Union[IO[str], IO[bytes]],
    target: Type[T],
    *,
    validate: bool = True,
) -> T:
    """Read all of `input_` and parse it with `loads`.

    Raises:
        InvalidJson: under the same conditions as `loads`.
    """
    return loads(input_.read(), target, validate=validate)


def loads(input_: Union[str, bytes], target: Type[T], *, validate: bool = True) -> T:
    """Parse `input_` as JSON and, unless disabled, validate it against `target`.

    Args:
        input_: the JSON document.
        target: the type the parsed value is expected to conform to.
        validate: when False, the parsed value is returned without checking.

    Returns:
        The value produced by `json.loads`.

    Raises:
        InvalidJson: if parsing or validation fails. Any other exception
            raised along the way is converted to `InvalidJson` and chained
            as its cause.
    """
    try:
        parsed = json.loads(input_)
        if validate:
            _validate_value(parsed, target)
        return parsed
    except InvalidJson:
        # Already the right type; re-wrapping would duplicate the position
        # suffix that `JSONDecodeError` appends to the message.
        raise
    except Exception as exception:
        raise InvalidJson(str(exception)) from exception


def validate(input: object, target: Type[T]) -> T:
    """Validate an already-parsed value against `target` and return it.

    Raises:
        InvalidJson: if `input` does not conform to `target`.
    """
    _validate_value(input, target)
    return cast(T, input)