in python/pyspark/sql/pandas/types.py [0:0]
def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
    """Build a per-value conversion function for the data type ``dt``.

    The function recurses into element, key, value and field types, and
    returns ``None`` when ``dt`` is none of the types handled below, which
    lets callers (including the recursive calls here) skip the conversion.

    Result of the returned callable, by type:

    * ``ArrayType``: a ``list`` with each non-``None`` element converted.
    * ``MapType``: a ``list`` of ``(key, value)`` tuples built from
      ``value.items()``, with non-``None`` keys/values converted.
    * ``StructType``: a ``dict`` keyed by the names produced by
      ``_dedup_names``; accepts a ``dict`` (looked up by the original field
      names, missing keys become ``None``) or a tuple (matched by position).
    * ``TimestampType``: a ``datetime`` from ``pd.Timestamp.to_pydatetime``;
      inputs that are not tz-aware ``datetime`` objects are localized to
      ``timezone`` first.
    * ``UserDefinedType``: ``udt.serialize(value)``, further converted with
      the converter of ``udt.sqlType()`` when there is one.
    * ``VariantType``: a ``dict`` with ``"value"`` and ``"metadata"`` entries.

    Several names are read from the enclosing scope rather than passed in:
    ``ignore_unexpected_complex_type_values``, ``error_on_duplicated_field_names``,
    ``timezone`` and ``_dedup_names``. When ``ignore_unexpected_complex_type_values``
    is truthy, array/map/struct converters return values of an unexpected
    Python type unchanged instead of attempting to convert them.

    Raises ``UnsupportedOperationException`` for a struct with repeated field
    names when ``error_on_duplicated_field_names`` is truthy.
    """
    # A dedicated closure is defined for every combination of "has nested
    # converter" and "lenient mode", so those checks are made once here and
    # not repeated for every converted value.
    if isinstance(dt, ArrayType):
        element_fn = _converter(dt.elementType)

        if element_fn is None:
            if ignore_unexpected_complex_type_values:

                def convert_array(value: Any) -> Any:
                    if not isinstance(value, Iterable):
                        return value
                    return list(value)

            else:

                def convert_array(value: Any) -> Any:
                    return list(value)

        elif ignore_unexpected_complex_type_values:

            def convert_array(value: Any) -> Any:
                if not isinstance(value, Iterable):
                    return value
                return [None if item is None else element_fn(item) for item in value]

        else:

            def convert_array(value: Any) -> Any:
                # value is expected to be iterable here
                return [None if item is None else element_fn(item) for item in value]

        return convert_array

    if isinstance(dt, MapType):
        key_fn = _converter(dt.keyType)
        value_fn = _converter(dt.valueType)
        entries_need_no_conversion = key_fn is None and value_fn is None

        if entries_need_no_conversion:
            if ignore_unexpected_complex_type_values:

                def convert_map(value: Any) -> Any:
                    if not isinstance(value, dict):
                        return value
                    return list(value.items())

            else:

                def convert_map(value: Any) -> Any:
                    # value is expected to be a dict here
                    return list(value.items())

        elif ignore_unexpected_complex_type_values:

            def convert_map(value: Any) -> Any:
                if not isinstance(value, dict):
                    return value
                return [
                    (
                        k if key_fn is None or k is None else key_fn(k),
                        v if value_fn is None or v is None else value_fn(v),
                    )
                    for k, v in value.items()
                ]

        else:

            def convert_map(value: Any) -> Any:
                # value is expected to be a dict here
                return [
                    (
                        k if key_fn is None or k is None else key_fn(k),
                        v if value_fn is None or v is None else value_fn(v),
                    )
                    for k, v in value.items()
                ]

        return convert_map

    if isinstance(dt, StructType):
        field_names = dt.names
        if error_on_duplicated_field_names and len(set(field_names)) != len(field_names):
            raise UnsupportedOperationException(
                errorClass="DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
                messageParameters={"field_names": str(field_names)},
            )

        dedup_field_names = _dedup_names(field_names)
        field_convs = [_converter(f.dataType) for f in dt.fields]
        any_field_converted = any(fn is not None for fn in field_convs)

        if not any_field_converted:
            if ignore_unexpected_complex_type_values:

                def convert_struct(value: Any) -> Any:
                    if isinstance(value, dict):
                        # Output uses the deduplicated name; lookup uses the original.
                        return {
                            out_name: value.get(src_name, None)
                            for out_name, src_name in zip(dedup_field_names, field_names)
                        }
                    if isinstance(value, tuple):
                        return dict(zip(dedup_field_names, value))
                    return value

            else:

                def convert_struct(value: Any) -> Any:
                    if isinstance(value, dict):
                        return {
                            out_name: value.get(src_name, None)
                            for out_name, src_name in zip(dedup_field_names, field_names)
                        }
                    # anything else is treated as a tuple
                    return dict(zip(dedup_field_names, value))

        elif ignore_unexpected_complex_type_values:

            def convert_struct(value: Any) -> Any:
                if isinstance(value, dict):
                    looked_up = (value.get(src_name, None) for src_name in field_names)
                    return {
                        out_name: raw if fn is None or raw is None else fn(raw)
                        for out_name, fn, raw in zip(dedup_field_names, field_convs, looked_up)
                    }
                if isinstance(value, tuple):
                    return {
                        out_name: raw if fn is None or raw is None else fn(raw)
                        for out_name, fn, raw in zip(dedup_field_names, field_convs, value)
                    }
                return value

        else:

            def convert_struct(value: Any) -> Any:
                if isinstance(value, dict):
                    looked_up = (value.get(src_name, None) for src_name in field_names)
                    return {
                        out_name: raw if fn is None or raw is None else fn(raw)
                        for out_name, fn, raw in zip(dedup_field_names, field_convs, looked_up)
                    }
                # anything else is treated as a tuple
                return {
                    out_name: raw if fn is None or raw is None else fn(raw)
                    for out_name, fn, raw in zip(dedup_field_names, field_convs, value)
                }

        return convert_struct

    if isinstance(dt, TimestampType):
        assert timezone is not None

        def convert_timestamp(value: Any) -> Any:
            already_aware = isinstance(value, datetime.datetime) and value.tzinfo is not None
            stamp = pd.Timestamp(value)
            if not already_aware:
                # Everything except a tz-aware datetime gets `timezone` attached.
                stamp = stamp.tz_localize(timezone)
            return stamp.to_pydatetime()

        return convert_timestamp

    if isinstance(dt, UserDefinedType):
        udt: UserDefinedType = dt
        sql_conv = _converter(udt.sqlType())

        if sql_conv is None:

            def convert_udt(value: Any) -> Any:
                return udt.serialize(value)

        else:

            def convert_udt(value: Any) -> Any:
                serialized = udt.serialize(value)
                return sql_conv(serialized)

        return convert_udt

    if isinstance(dt, VariantType):

        def convert_variant(variant: Any) -> Any:
            assert isinstance(variant, VariantVal)
            return {"value": variant.value, "metadata": variant.metadata}

        return convert_variant

    # No conversion is needed for any other type.
    return None