def _create_type_verifier()

in flink-python/pyflink/table/types.py [0:0]


def _create_type_verifier(data_type: DataType, name: str = None):
    """
    Creates a verifier that checks the type of obj against data_type and raises a TypeError if they
    do not match.

    This verifier also checks the value of obj against data_type and raises a ValueError if it's
    not within the allowed range, e.g. using 128 as TinyIntType will overflow. Note that, Python
    float is not checked, so it will become infinity when cast to Java float if it overflows.

    Nullability is taken from ``data_type`` itself (e.g. ``TinyIntType(False)``); ``None`` is
    only accepted when the type is nullable.

    :param data_type: the :class:`DataType` that objects must conform to.
    :param name: optional name of the field being verified; it is used as a prefix in error
                 messages and to build the names of nested fields/elements.
    :return: a callable taking a single object; it returns None on success and raises
             TypeError (wrong Python type) or ValueError (bad value / length / nullability).

    >>> _create_type_verifier(RowType([]))(None)
    >>> _create_type_verifier(VarCharType(100))("")
    >>> _create_type_verifier(BigIntType())(0)
    >>> _create_type_verifier(BigIntType())(2 ** 63) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(ArrayType(SmallIntType()))(list(range(3)))
    >>> _create_type_verifier(ArrayType(VarCharType(10)))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _create_type_verifier(MapType(VarCharType(100), IntType()))({})
    >>> _create_type_verifier(RowType([]))(())
    >>> _create_type_verifier(RowType([]))([])
    >>> _create_type_verifier(RowType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _create_type_verifier(TinyIntType())(12)
    >>> _create_type_verifier(TinyIntType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check nullability, which is a property of the data type.
    >>> _create_type_verifier(TinyIntType(False))(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(
    ...     ArrayType(SmallIntType(False)))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> map_verifier = _create_type_verifier(MapType(VarCharType(100, False), IntType()))
    >>> map_verifier({None: 1}) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = RowType().add("a", IntType()).add("b", VarCharType(100, False))
    >>> _create_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """

    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    def nested_name(prefix):
        # Name for an array element / map key / map value verifier. Avoids rendering a
        # literal "None" into error messages when this verifier itself is unnamed.
        return prefix if name is None else "%s %s" % (prefix, name)

    def verify_nullability(obj):
        # Returns True when obj is an accepted None (nothing else to check), False when obj
        # is not None (the value still has to be verified), and raises for a disallowed None.
        if obj is None:
            if data_type._nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False

    _type = type(data_type)

    assert _type in _acceptable_types or isinstance(data_type, UserDefinedType),\
        new_msg("unknown datatype: %s" % data_type)

    def verify_acceptable_types(obj):
        # An exact type match is required (not isinstance): subclass of them can not be
        # from_sql_type in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (data_type, obj, type(obj))))

    if isinstance(data_type, CharType):
        def verify_char(obj):
            verify_acceptable_types(obj)
            # CHAR is fixed-length: the object must have exactly the declared length.
            if len(obj) != data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of CharType is not: %d" % (obj, data_type.length)))

        verify_value = verify_char

    elif isinstance(data_type, VarCharType):
        def verify_varchar(obj):
            verify_acceptable_types(obj)
            if len(obj) > data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of VarCharType exceeds: %d" % (obj, data_type.length)))

        verify_value = verify_varchar

    elif isinstance(data_type, BinaryType):
        def verify_binary(obj):
            verify_acceptable_types(obj)
            # BINARY is fixed-length: the object must have exactly the declared length.
            if len(obj) != data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of BinaryType is not: %d" % (obj, data_type.length)))

        verify_value = verify_binary

    elif isinstance(data_type, VarBinaryType):
        def verify_varbinary(obj):
            verify_acceptable_types(obj)
            if len(obj) > data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of VarBinaryType exceeds: %d"
                    % (obj, data_type.length)))

        verify_value = verify_varbinary

    elif isinstance(data_type, UserDefinedType):
        # A UDT is verified by converting it to its underlying SQL representation and
        # verifying that against the UDT's sql_type().
        sql_type = data_type.sql_type()
        verifier = _create_type_verifier(sql_type, name=name)

        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == data_type):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, data_type)))
            data = data_type.to_sql_type(obj)
            if isinstance(sql_type, RowType):
                # remove the RowKind value in the first position.
                data = data[1:]
            verifier(data)

        verify_value = verify_udf

    elif isinstance(data_type, TinyIntType):
        def verify_tiny_int(obj):
            verify_acceptable_types(obj)
            # Signed 8-bit range.
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of TinyIntType out of range, got: %s" % obj))

        verify_value = verify_tiny_int

    elif isinstance(data_type, SmallIntType):
        def verify_small_int(obj):
            verify_acceptable_types(obj)
            # Signed 16-bit range.
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of SmallIntType out of range, got: %s" % obj))

        verify_value = verify_small_int

    elif isinstance(data_type, IntType):
        def verify_integer(obj):
            verify_acceptable_types(obj)
            # Signed 32-bit range.
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntType out of range, got: %s" % obj))

        verify_value = verify_integer

    elif isinstance(data_type, BigIntType):
        def verify_big_int(obj):
            verify_acceptable_types(obj)
            # Signed 64-bit range (Java long); Python ints are unbounded and would
            # otherwise overflow silently on the JVM side.
            if obj < -9223372036854775808 or obj > 9223372036854775807:
                raise ValueError(
                    new_msg("object of BigIntType out of range, got: %s" % obj))

        verify_value = verify_big_int

    elif isinstance(data_type, ArrayType):
        element_verifier = _create_type_verifier(
            data_type.element_type, name=nested_name("element in array"))

        def verify_array(obj):
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)

        verify_value = verify_array

    elif isinstance(data_type, MapType):
        key_verifier = _create_type_verifier(data_type.key_type, name=nested_name("key of map"))
        value_verifier = _create_type_verifier(
            data_type.value_type, name=nested_name("value of map"))

        def verify_map(obj):
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)

        verify_value = verify_map

    elif isinstance(data_type, RowType):
        # One (field name, verifier) pair per field, in declaration order.
        verifiers = []
        for f in data_type.fields:
            verifier = _create_type_verifier(f.data_type, name=new_name(f.name))
            verifiers.append((f.name, verifier))

        def verify_row_field(obj):
            if isinstance(obj, dict):
                # Missing keys are verified as None (and thus fail for non-nullable fields).
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row):
                if obj._from_dict:
                    # Since the row was created with field names, use the verifier
                    # associated with the field name
                    for f, verifier in verifiers:
                        verifier(obj[f])
                else:
                    # If the row was created with positional arguments, use the verifier
                    # in the same position.
                    for idx, (_, verifier) in enumerate(verifiers):
                        verifier(obj[idx])
            elif isinstance(obj, (tuple, list)):
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                # Arbitrary objects are verified through their instance attributes.
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("RowType can not accept object %r in type %s"
                                        % (obj, type(obj))))

        verify_value = verify_row_field

    else:
        def verify_default(obj):
            verify_acceptable_types(obj)

        verify_value = verify_default

    def verify(obj):
        if not verify_nullability(obj):
            verify_value(obj)

    return verify