in tensorflow_text/python/ops/pointer_ops.py [0:0]
def span_overlaps(source_start,
                  source_limit,
                  target_start,
                  target_limit,
                  contains=False,
                  contained_by=False,
                  partial_overlap=False,
                  name=None):
  """Returns a boolean tensor indicating which source and target spans overlap.

  The source and target spans are specified using B+1 dimensional tensors,
  with `B>=0` batch dimensions followed by a final dimension that lists the
  span offsets for each span in the batch:

  * The `i`th source span in batch `b1...bB` starts at
    `source_start[b1...bB, i]` (inclusive), and extends to just before
    `source_limit[b1...bB, i]` (exclusive).
  * The `j`th target span in batch `b1...bB` starts at
    `target_start[b1...bB, j]` (inclusive), and extends to just before
    `target_limit[b1...bB, j]` (exclusive).

  `result[b1...bB, i, j]` is true if the `i`th source span overlaps with the
  `j`th target span in batch `b1...bB`, where a source span overlaps a target
  span if any of the following are true:

  * The spans are identical.
  * `contains` is true, and the source span contains the target span.
  * `contained_by` is true, and the source span is contained by the target
    span.
  * `partial_overlap` is true, and there is a non-zero overlap between the
    source span and the target span.

  #### Example:
    Given the following source and target spans (with no batch dimensions):

    >>> #         0    5    10   15   20   25   30   35   40
    >>> #         |====|====|====|====|====|====|====|====|
    >>> # Source: [-0-]     [-1-] [2] [-3-][-4-][-5-]
    >>> # Target: [-0-][-1-]     [-2-] [3]   [-4-][-5-]
    >>> #         |====|====|====|====|====|====|====|====|
    >>> source_start = [0, 10, 16, 20, 25, 30]
    >>> source_limit = [5, 15, 19, 25, 30, 35]
    >>> target_start = [0,  5, 15, 21, 27, 31]
    >>> target_limit = [5, 10, 20, 24, 32, 37]

    (The diagram is approximate for target spans 4 and 5: `[27, 32)` and
    `[31, 37)` overlap at offset 31, so they cannot be drawn side by side.)

    `result[i, j]` will be true at the following locations:

      * `[0, 0]` (always)
      * `[2, 2]` (if contained_by=True or partial_overlap=True)
      * `[3, 3]` (if contains=True or partial_overlap=True)
      * `[4, 4]` (if partial_overlap=True)
      * `[5, 4]` (if partial_overlap=True)
      * `[5, 5]` (if partial_overlap=True)

  Args:
    source_start: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, source_size]`: the start offset of each source span.
    source_limit: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, source_size]`: the limit offset of each source span.
    target_start: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, target_size]`: the start offset of each target span.
    target_limit: A B+1 dimensional potentially ragged tensor with shape
      `[D1...DB, target_size]`: the limit offset of each target span.
    contains: If true, then a source span is considered to overlap a target
      span when the source span contains the target span.
    contained_by: If true, then a source span is considered to overlap a
      target span when the source span is contained by the target span.
    partial_overlap: If true, then a source span is considered to overlap a
      target span when the source span partially overlaps the target span.
    name: A name for the operation (optional).

  Returns:
    A B+2 dimensional potentially ragged boolean tensor with shape
    `[D1...DB, source_size, target_size]`.

  Raises:
    ValueError: If the span tensors are incompatible (mismatched shapes or
      ranks, mixed dense/ragged inputs, or ragged inputs whose rank is not
      statically known).
    TypeError: If `source_start`, `source_limit`, `target_start`, and
      `target_limit` do not all have the same dtype.
  """
  _check_type(contains, 'contains', bool)
  _check_type(contained_by, 'contained_by', bool)
  _check_type(partial_overlap, 'partial_overlap', bool)

  scope_tensors = [source_start, source_limit, target_start, target_limit]
  with ops.name_scope(name, 'SpanOverlaps', scope_tensors):
    # Convert input tensors.
    source_start = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        source_start, name='source_start')
    source_limit = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        source_limit, name='source_limit')
    target_start = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        target_start, name='target_start')
    target_limit = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        target_limit, name='target_limit')
    span_tensors = [source_start, source_limit, target_start, target_limit]

    # Verify input tensor shapes and types: each (start, limit) pair must have
    # compatible shapes, and source and target tensors must share a rank.
    source_start.shape.assert_is_compatible_with(source_limit.shape)
    target_start.shape.assert_is_compatible_with(target_limit.shape)
    source_start.shape.assert_same_rank(target_start.shape)
    source_start.shape.assert_same_rank(target_limit.shape)
    source_limit.shape.assert_same_rank(target_start.shape)
    source_limit.shape.assert_same_rank(target_limit.shape)
    if not (source_start.dtype == target_start.dtype == source_limit.dtype ==
            target_limit.dtype):
      raise TypeError('source_start, source_limit, target_start, and '
                      'target_limit must all have the same dtype')

    # The set of statically known ranks; empty if no rank is known.
    ndims = {t.shape.ndims for t in span_tensors if t.shape.ndims is not None}
    assert len(ndims) <= 1  # because of assert_same_rank statements above.

    if all(not isinstance(t, ragged_tensor.RaggedTensor) for t in span_tensors):
      # All inputs are dense: compute the overlaps directly.
      return _span_overlaps(source_start, source_limit, target_start,
                            target_limit, contains, contained_by,
                            partial_overlap)
    elif all(isinstance(t, ragged_tensor.RaggedTensor) for t in span_tensors):
      if not ndims:
        raise ValueError('For ragged inputs, the shape.ndims of at least one '
                         'span tensor must be statically known.')
      (rank,) = ndims  # Exactly one known rank, per the assert above.
      if rank == 2:
        # One batch dimension followed by the (ragged) span dimension:
        # `_span_overlaps` handles this case directly.
        return _span_overlaps(source_start, source_limit, target_start,
                              target_limit, contains, contained_by,
                              partial_overlap)
      else:
        # Handle ragged batch dimension by recursion on values: all four
        # tensors must share the outermost row partition, so we check that,
        # recurse on the flat values, and re-wrap with the same row_splits.
        row_splits = span_tensors[0].row_splits
        shape_checks = [
            check_ops.assert_equal(
                t.row_splits,
                row_splits,
                message='Mismatched ragged shapes for batch dimensions')
            for t in span_tensors[1:]
        ]
        with ops.control_dependencies(shape_checks):
          return ragged_tensor.RaggedTensor.from_row_splits(
              span_overlaps(source_start.values, source_limit.values,
                            target_start.values, target_limit.values,
                            contains, contained_by, partial_overlap),
              row_splits)
    else:
      # Mix of dense and ragged tensors.
      raise ValueError('Span tensors must all have the same ragged_rank')