def flat_beam_search()

in lingvo/core/flat_beam_search_helper.py


def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: extension buffer size; 0 to disable, else int >= beam_size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

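  # Flat layout: the decode buffer reserves beam_size token slots per step, so
  # each per-hypothesis attention mask row spans beam_size * max_steps slots.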
  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                    prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size], dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # The extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # The top beam_size extensions are then removed from the buffer and used
    # in the next decoder iteration, and the top 'ext_size' of the remaining
    # extensions are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
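    # For example, with beam_size=2 and ext_size=4, each step proposes
    # 2 hyps * top-2 tokens = 4 candidate extensions; these are merged with the
    # 4 buffered ones, the best 2 are decoded next, and the best 4 of the
    # remainder stay in the buffer.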
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size], dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len
    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    # [ P P P P P P - - - - - - P* - - - ]   ^
    # [ P P P P P P P P P P - - P* - - - ]   | batch
    # [ P - - - - - - - - - - - P* - - - ]   V
    # |<---- prefix len ---->  |<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.
    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
    pfx_time = tf.range(pfx_max)
    pfx_pad = tf.cast(
        tf.less(tf.expand_dims(pfx_time, 0), tf.expand_dims(pfx_len - 1, 1)),
        tf.int32)
    pfx_id = pfx * pfx_pad
    pfx_last = einsum_i32('BT,BT->B', pfx,
                          tf.one_hot(pfx_len - 1, pfx_max, dtype=fprop_dtype))

    buf_time = tf.range(buf_size)
    pfx_time_mask = tf.cast(
        tf.less_equal(tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time, 1)),
        fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_time * pfx_pad

    if debug:
      tpu_summary.tensor('pfx_id', pfx_id)
      tpu_summary.tensor('pfx_len', pfx_len)
      tpu_summary.tensor('pfx_pos', pfx_pos)
      tpu_summary.tensor('pfx_last', pfx_last)

    # Now call the decoder with the prefix minus P*. After this call
    # 'dec_state' contains the key/value cache for the prefix tokens
    # (for transformer models), and the returned 'logits' can either be
    # discarded or rolled into the initial hyp_score. Discarding is simpler.
    with tf.name_scope('prefix_fprop'):
      # TODO(krikun): remove extra type checks
      assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
      assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
      assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
      assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
      assert (t.dtype == tf.int32), (t.dtype)
      logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                       pfx_mask, dec_state, t)
      del logits

    # Now construct the initial state for the rest of the beam search loop.
    # 'tgt_id' is simply 'pfx_last' broadcast to [batch, beam] shape
    # 'tgt_pos' is different for each batch row and is equal to prefix_len - 1
    # 'tgt_segment_id' is always 1 (no packing)
    # 'hyp_score' is 0 for beam=0 and negative for beam>=1
    tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        pfx_last, 1)
    tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        (pfx_len - 1), 1)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

    # TODO(krikun) Here we make the initial 't' constant, determined by the
    # shape of the prefix tensor ('pfx_max'). It is possible to make it dynamic
    # as t ~ max(pfx_len) / beam_size, which would leave more steps for beam
    # search; however 'max' results in a very slow all-to-all on 16x16, and a
    # variable number of decoder steps may result in bad latency.
    t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

    # Initial tgt_mask is such that each token P* has attention on itself
    # (as usual) and on all prefix tokens before it, which are not padding.
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.cast(
        tf.expand_dims(tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
        fprop_dtype)
    tgt_mask += tf.one_hot(
        tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    if debug:
      tpu_summary.tensor('tgt_id', tgt_id)
      tpu_summary.tensor('tgt_pos', tgt_pos)
      tpu_summary.tensor('tgt_mask', tgt_mask)
      tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below.  When the prefix has zero columns, pfx_id has shape
    # [batch, 0] and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

    hist = (h_tgt_id, h_tgt_pos)
    tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask,
                                     dec_state, t)
    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)

    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    nbest_hyps = update_nbest(nbest_hyps, (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)

    top_score += tf.expand_dims(hyp_score, -1)
    top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)
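    # EOS candidates were already folded into the n-best list via eos_score
    # above, so they are heavily penalized here to keep terminated hyps from
    # being extended further.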

    top_score = tf.reshape(top_score, [batch_size, beam_size * k])
    top_id = tf.reshape(top_id, [batch_size, beam_size * k])
    top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)
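      # 'i1' one-hot-encodes the sorted indices, so the einsums above gather
      # the mask rows and token ids into score order.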

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

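    # Each selected extension inherits its parent's mask, whose number of ones
    # equals the number of tokens in the hypothesis so far, i.e. the position
    # of its next token; tgt_pos is thus recovered directly from tgt_mask.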
    tgt_pos = tf.reduce_sum(tgt_mask, -1)
    tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])
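  # 'h' holds the token written at every buffer slot: h_tgt_id stacks to
  # [steps, batch, beam], which transposes and reshapes to [batch, buf_size],
  # aligned with the time axis of nbest_mask.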

  def unmask(h, m):
    with tf.name_scope('unmask'):
      tpu_summary.tensor('unmask_h', h)
      tpu_summary.tensor('unmask_m', m)
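      # cumsum ranks each selected position among the ones in its mask row;
      # subtracting 1 gives the output index, and the one_hot/einsum pair
      # scatters the token ids into a dense [batch, k, output_len] tensor.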
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)
  topk_score = nbest_score_norm

  nbest = (topk_ids, topk_len, topk_score)

  return loop_vars, dec_state, nbest
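
A minimal usage sketch (not from the library): the toy dec_callback and the
constants below are hypothetical and only illustrate the calling convention.
It assumes logits of shape [batch, beam, vocab] and a placeholder scalar
dec_state; a real callback would run one decoder step over the flat buffer
using tgt_mask and update its key/value cache in dec_state.

# Hypothetical example; names and constants here are illustrative only.
from lingvo import compat as tf
from lingvo.core import flat_beam_search_helper

batch_size, beam_size, max_steps, vocab_size = 2, 4, 8, 32


def toy_dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t):
  """Returns uniform logits; a real decoder would attend using tgt_mask."""
  del tgt_segment_id, tgt_pos, tgt_mask, t  # unused by this toy callback
  logits = tf.zeros(
      [batch_size, tf.shape(tgt_id)[1], vocab_size], dtype=tf.float32)
  return logits, dec_state


loop_vars, dec_state, nbest = flat_beam_search_helper.flat_beam_search(
    batch_size=batch_size,
    beam_size=beam_size,
    max_steps=max_steps,
    dec_callback=toy_dec_callback,
    dec_state=tf.zeros([]),  # placeholder; normally the decoder k/v cache
    beam_gap=None,  # disable early stopping for the toy run
    debug=False)
# topk_ids: [batch, nbest, len]; topk_len and topk_score: [batch, nbest]
topk_ids, topk_len, topk_score = nbest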