def CanHandle()

in tfx_bsl/tfxio/tensor_adapter.py [0:0]


  def CanHandle(arrow_schema: pa.Schema,
                tensor_representation: schema_pb2.TensorRepresentation) -> bool:
    """Returns whether `tensor_representation` can be handled.

    The tensor_representation cannot be handled when:
    1. A wrong column name / field name is requested.
    2. A non-leaf field is requested (for StructTypes).
    3. There does not exist a list-like type along the value path.
    4. A requested `row_length` partition path does not exist or does not
       hold lists of integer values, or a `uniform_row_length` partition is
       not strictly positive.

    Args:
      arrow_schema: The pyarrow schema.
      tensor_representation: The TensorRepresentation proto. Only its
        `ragged_tensor` field is inspected.

    Returns:
      True if all of the checks above pass, False otherwise.
    """
    ragged_tensor = tensor_representation.ragged_tensor
    if len(ragged_tensor.feature_path.step) < 1:
      # An empty feature path cannot point at any column.
      return False

    value_path = path.ColumnPath.from_proto(ragged_tensor.feature_path)

    # Checking the outer dimensions represented by the value feature path.
    contains_list = False
    try:
      arrow_type = None
      for arrow_type in _EnumerateTypesAlongPath(arrow_schema, value_path):
        if _IsListLike(arrow_type):
          contains_list = True
      # `arrow_type` is still None if nothing was enumerated; skip the struct
      # check in that case and let the `contains_list` check below reject it.
      if arrow_type is not None and pa.types.is_struct(arrow_type):
        # The path is depleted, but the last arrow_type is a struct. This means
        # the path is a Non-leaf field.
        return False
    except (KeyError, ValueError):
      # Either exception signifies a wrong column name / field name requested.
      # KeyError is what `pa.Schema.field` raises for an unknown column name.
      return False
    if not contains_list:
      return False

    # Check the auxiliary features that need to be accessed to form the inner
    # dimensions partitions. They are siblings of the value feature, hence
    # they are resolved relative to its parent path.
    parent_path = value_path.parent()

    # Check the columns exists and have correct depth and type.
    for partition in ragged_tensor.partition:
      if partition.HasField("row_length"):
        try:
          field_path = parent_path.child(partition.row_length)
          # Pre-assign so that the name is bound even if the loop below does
          # not iterate (avoids the loop undefined variable lint error).
          partition_type = arrow_schema.field(field_path.initial_step()).type
          for partition_type in _EnumerateTypesAlongPath(
              arrow_schema, field_path, stop_at_path_end=True):
            # Iterate through them all. Only interested on the last type.
            pass
          # Row lengths must be stored as a list-like column of integers.
          if not _IsListLike(partition_type) or not pa.types.is_integer(
              partition_type.value_type):
            return False
        except (KeyError, ValueError):
          # Either exception signifies a wrong column name / field name
          # requested. `arrow_schema.field(<name>)` raises KeyError when the
          # column is missing, which must result in False rather than an
          # exception escaping from this predicate.
          return False

      elif partition.HasField("uniform_row_length"):
        # A uniform row length must be strictly positive.
        if partition.uniform_row_length <= 0:
          return False
      else:
        # Neither supported partition kind is set.
        return False

    # All checks passed successfully.
    return True