def span_overlaps()

in tensorflow_text/python/ops/pointer_ops.py [0:0]


def span_overlaps(source_start,
                  source_limit,
                  target_start,
                  target_limit,
                  contains=False,
                  contained_by=False,
                  partial_overlap=False,
                  name=None):
  """Returns a boolean tensor indicating which source and target spans overlap.

  The source and target spans are specified using B+1 dimensional tensors,
  with `B>=0` batch dimensions followed by a final dimension that lists the
  span offsets for each span in the batch:

  * The `i`th source span in batch `b1...bB` starts at
    `source_start[b1...bB, i]` (inclusive), and extends to just before
    `source_limit[b1...bB, i]` (exclusive).
  * The `j`th target span in batch `b1...bB` starts at
    `target_start[b1...bB, j]` (inclusive), and extends to just before
    `target_limit[b1...bB, j]` (exclusive).

  `result[b1...bB, i, j]` is true if the `i`th source span overlaps with the
  `j`th target span in batch `b1...bB`, where a source span overlaps a target
  span if any of the following are true:

    * The spans are identical.
    * `contains` is true, and the source span contains the target span.
    * `contained_by` is true, and the source span is contained by the target
      span.
    * `partial_overlap` is true, and there is a non-zero overlap between the
      source span and the target span.

  #### Example:
    Given the following source and target spans (with no batch dimensions):

    >>>  #         0    5    10   15   20   25   30   35   40
    >>>  #         |====|====|====|====|====|====|====|====|
    >>>  # Source: [-0-]     [-1-] [2] [-3-][-4-][-5-]
    >>>  # Target: [-0-][-1-]     [-2-] [3]   [-4-][-5-]
    >>>  #         |====|====|====|====|====|====|====|====|
    >>> source_start = [0, 10, 16, 20, 25, 30]
    >>> source_limit = [5, 15, 19, 25, 30, 35]
    >>> target_start = [0,  5, 15, 21, 27, 31]
    >>> target_limit = [5, 10, 20, 24, 32, 37]

    (Target spans 4 `[27, 32)` and 5 `[31, 37)` overlap each other, which a
    single-line diagram cannot show; target span 5 is drawn starting one
    column later than its true start offset of 31.)

    `result[i, j]` will be true at the following locations:

      * `[0, 0]` (always)
      * `[2, 2]` (if contained_by=True or partial_overlap=True)
      * `[3, 3]` (if contains=True or partial_overlap=True)
      * `[4, 4]` (if partial_overlap=True)
      * `[5, 4]` (if partial_overlap=True)
      * `[5, 5]` (if partial_overlap=True)

  Args:
    source_start: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, source_size]`: the start offset of each source span.
    source_limit: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, source_size]`: the limit offset of each source span.
    target_start: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, target_size]`: the start offset of each target span.
    target_limit: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, target_size]`: the limit offset of each target span.
    contains: If true, then a source span is considered to overlap a target span
      when the source span contains the target span.
    contained_by: If true, then a source span is considered to overlap a target
      span when the source span is contained by the target span.
    partial_overlap: If true, then a source span is considered to overlap a
      target span when the source span partially overlaps the target span.
    name: A name for the operation (optional).

  Returns:
    A B+2 dimensional potentially ragged boolean tensor with shape
    `[D1...DB, source_size, target_size]`.

  Raises:
    ValueError: If the span tensors have incompatible static shapes or ranks;
      if dense and ragged span tensors are mixed; or if all span tensors are
      ragged but none of them has a statically known rank.
    TypeError: If the span tensors do not all have the same dtype.
  """
  # Validate the flag arguments before building any ops.
  _check_type(contains, 'contains', bool)
  _check_type(contained_by, 'contained_by', bool)
  _check_type(partial_overlap, 'partial_overlap', bool)

  scope_tensors = [source_start, source_limit, target_start, target_limit]
  with ops.name_scope(name, 'SpanOverlaps', scope_tensors):
    # Convert input tensors.
    source_start = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        source_start, name='source_start')
    source_limit = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        source_limit, name='source_limit')
    target_start = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        target_start, name='target_start')
    target_limit = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        target_limit, name='target_limit')
    span_tensors = [source_start, source_limit, target_start, target_limit]

    # Verify input tensor shapes and types.  Each side's start/limit pair must
    # have compatible shapes; source and target need only agree on rank, since
    # source_size and target_size may differ.
    source_start.shape.assert_is_compatible_with(source_limit.shape)
    target_start.shape.assert_is_compatible_with(target_limit.shape)
    source_start.shape.assert_same_rank(target_start.shape)
    source_start.shape.assert_same_rank(target_limit.shape)
    source_limit.shape.assert_same_rank(target_start.shape)
    source_limit.shape.assert_same_rank(target_limit.shape)
    if not (source_start.dtype == target_start.dtype == source_limit.dtype ==
            target_limit.dtype):
      raise TypeError('source_start, source_limit, target_start, and '
                      'target_limit must all have the same dtype')
    ndims = {t.shape.ndims for t in span_tensors if t.shape.ndims is not None}
    assert len(ndims) <= 1  # because of assert_same_rank statements above.

    is_ragged = [isinstance(t, ragged_tensor.RaggedTensor) for t in span_tensors]

    # All dense: compute the overlaps directly.
    if not any(is_ragged):
      return _span_overlaps(source_start, source_limit, target_start,
                            target_limit, contains, contained_by,
                            partial_overlap)

    # Mix of dense and ragged tensors.
    if not all(is_ragged):
      raise ValueError('Span tensors must all have the same ragged_rank')

    # All ragged from here on.
    if not ndims:
      raise ValueError('For ragged inputs, the shape.ndims of at least one '
                       'span tensor must be statically known.')
    if next(iter(ndims)) == 2:
      # A single batch dimension plus the span dimension.
      return _span_overlaps(source_start, source_limit, target_start,
                            target_limit, contains, contained_by,
                            partial_overlap)

    # Handle the outermost ragged batch dimension by recursion on values.  All
    # four tensors must share the same outer row partition; this is checked
    # at runtime, and the recursive result is re-wrapped with those splits.
    row_splits = source_start.row_splits
    shape_checks = [
        check_ops.assert_equal(
            t.row_splits,
            row_splits,
            message='Mismatched ragged shapes for batch dimensions')
        for t in span_tensors[1:]
    ]
    with ops.control_dependencies(shape_checks):
      return ragged_tensor.RaggedTensor.from_row_splits(
          span_overlaps(source_start.values, source_limit.values,
                        target_start.values, target_limit.values, contains,
                        contained_by, partial_overlap), row_splits)