def true_segments_1d()

in moonlight/util/segments.py [0:0]


def true_segments_1d(segments,
                     mode=SegmentsMode.CENTERS,
                     max_gap=0,
                     min_length=0,
                     name=None):
  """Finds the runs of consecutive True values in a 1D boolean tensor.

  True runs that are separated by a False gap of at most `max_gap` values are
  joined into one run whose length includes the gap. Afterwards, runs shorter
  than `min_length` are dropped.

  Args:
    segments: 1D boolean tensor.
    mode: The SegmentsMode. With CENTERS, each run is reported by
      `start + (length - 1) // 2`; with any other value, by its start.
    max_gap: Longest False gap between two True runs that is bridged. int.
    min_length: Shortest (merged) run that is still returned. int.
    name: Optional name for the op.

  Returns:
    positions: int32 tensor holding the start or the rounded center of each
        returned True run, depending on `mode`.
    lengths: int32 tensor holding the length of each returned True run.
  """
  with tf.name_scope(name, "true_segments", [segments]):
    segments = tf.convert_to_tensor(segments, tf.bool)
    starts, lengths = _segments_1d(segments, mode=SegmentsMode.STARTS)
    none_true = tf.logical_not(tf.reduce_any(segments))

    # Runs alternate between True and False, so every True run shares the
    # parity of the first one. Run 0 is the reference when the first value is
    # True or when there is no True value at all (which also covers empty
    # input); otherwise run 1 is.
    begins_true = tf.reduce_any(segments[0:1])
    first_run = tf.cond(
        tf.logical_or(begins_true, none_true),
        lambda: tf.constant(0),
        lambda: tf.constant(1))

    num_runs = tf.shape(starts)[0]
    run_ids = tf.range(num_runs)
    is_true_run = tf.equal(run_ids % 2, first_run % 2)

    # A gap is a run of the opposite parity that lies strictly between the
    # first True run and the last run, i.e. it has True runs on both sides.
    is_interior = tf.logical_and(
        tf.greater(run_ids, first_run), tf.less(run_ids, num_runs - 1))
    is_gap = tf.logical_and(tf.logical_not(is_true_run), is_interior)
    is_bridged = tf.logical_and(is_gap, tf.less_equal(lengths, max_gap))

    # Group neighbouring runs that belong together (True runs plus the gaps
    # being bridged). Every remaining False run forms a group of its own.
    in_merged_run = tf.logical_or(is_true_run, is_bridged)
    group_starts, _ = _segments_1d(in_merged_run, mode=SegmentsMode.STARTS)

    # Each group begins where its first original run begins.
    merged_starts = tf.gather(starts, group_starts)

    # Map every original run to the index of its group: put a 1 at the first
    # run of every group except group 0, then take the running total.
    later_group_starts = group_starts[1:]
    group_boundaries = tf.sparse_to_dense(
        sparse_indices=tf.cast(later_group_starts[:, None], tf.int64),
        output_shape=tf.cast(num_runs[None], tf.int64),
        sparse_values=tf.ones_like(later_group_starts))
    group_ids = tf.cumsum(group_boundaries)

    # The length of a group is the total length of the runs it contains.
    merged_lengths = tf.segment_sum(lengths, group_ids)

    if mode is SegmentsMode.CENTERS:
      positions = merged_starts + tf.floordiv(merged_lengths - 1, 2)
    else:
      positions = merged_starts

    # The groups alternate as well, so the True groups are every other one.
    # Without any True value the only run is a False one; starting at index 1
    # skips it so that nothing is returned.
    first_true_group = first_run + tf.to_int32(none_true)
    positions = positions[first_true_group::2]
    merged_lengths = merged_lengths[first_true_group::2]

    # Drop the runs that are shorter than min_length. The mask is declared as
    # rank 1 before it is handed to tf.boolean_mask.
    long_enough = tf.greater_equal(merged_lengths, min_length)
    long_enough.set_shape([None])
    return (tf.boolean_mask(positions, long_enough),
            tf.boolean_mask(merged_lengths, long_enough))