in tfx_bsl/tfxio/tensor_adapter.py [0:0]
def CanHandle(arrow_schema: pa.Schema,
              tensor_representation: schema_pb2.TensorRepresentation) -> bool:
  """Returns whether `tensor_representation` can be handled.

  The tensor_representation cannot be handled when any of these holds:
    1. The value feature path is empty, or a wrong column name / field name
       is requested.
    2. A non-leaf field is requested (the value path ends at a StructType).
    3. There is no list-like type along the value feature path.
    4. A partition is invalid:
       - a `row_length` path does not exist, or is not a list of integers;
       - a `uniform_row_length` is not positive; or
       - the partition sets neither `row_length` nor `uniform_row_length`.

  Args:
    arrow_schema: The pyarrow schema.
    tensor_representation: The TensorRepresentation proto.

  Returns:
    True if all of the checks above pass, False otherwise.
  """
  ragged_tensor = tensor_representation.ragged_tensor
  if len(ragged_tensor.feature_path.step) < 1:
    return False

  value_path = path.ColumnPath.from_proto(ragged_tensor.feature_path)

  # Check the outer dimensions, which are represented by the value feature
  # path.
  contains_list = False
  arrow_type = None
  try:
    for arrow_type in _EnumerateTypesAlongPath(arrow_schema, value_path):
      if _IsListLike(arrow_type):
        contains_list = True
    if arrow_type is None or pa.types.is_struct(arrow_type):
      # Either nothing was enumerated, or the path is depleted but the last
      # arrow_type is a struct. A struct last type means the path is a
      # non-leaf field.
      return False
  except ValueError:
    # ValueError signifies wrong column name / field name requested.
    return False
  if not contains_list:
    return False

  # Check the auxiliary columns that must be accessed to form the inner
  # dimensions' partitions. They are resolved relative to the parent of the
  # value path.
  parent_path = value_path.parent()

  # Check that the columns exist and have the correct depth and type.
  for partition in ragged_tensor.partition:
    if partition.HasField("row_length"):
      try:
        field_path = parent_path.child(partition.row_length)
        # Only the innermost type along the path matters.
        # Do not pre-seed it with `arrow_schema.field(...)`: for a top-level
        # row_length column that does not exist, pyarrow raises KeyError
        # there, which would escape this function instead of returning False.
        partition_type = None
        for partition_type in _EnumerateTypesAlongPath(
            arrow_schema, field_path, stop_at_path_end=True):
          pass
        if (partition_type is None or not _IsListLike(partition_type) or
            not pa.types.is_integer(partition_type.value_type)):
          return False
      except ValueError:
        # ValueError signifies wrong column name / field name requested.
        return False
    elif partition.HasField("uniform_row_length"):
      if partition.uniform_row_length <= 0:
        return False
    else:
      # Neither partition kind is set.
      return False

  # All checks passed successfully.
  return True