def __init__()

in lingvo/core/attention.py [0:0]


  def __init__(self, params):
    """Constructs an LocationSensitiveAttention object."""
    super().__init__(params)
    p = self.params
    self._is_quantized = p.qdomain.default is not None
    assert not p.packed_input, ('Packed input is not supported yet for '
                                'LocationSensitiveAttention.')

    if p.atten_dropout_prob != 0:
      raise NotImplementedError('dropout is not supported')

    def AttenLogits(inputs):
      """Generates logits."""
      fns = self.fns

      def CollapseOutDim(x):
        return tf.reshape(x, [-1, tf.shape(x)[-1]])

      # [tb, location_feature_dim, sl] => [sl, tb, location_feature_dim]
      location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
      location_hidden = fns.qmatmul(
          CollapseOutDim(location_feats),
          inputs.location_var,
          qout_name='logits_mul')

      sl = py_utils.GetShape(location_feats)[0]
      tb = py_utils.GetShape(location_feats)[1]
      hd = py_utils.GetShape(inputs.location_var)[1]
      location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
      sb = py_utils.GetShape(inputs.query_vec_reshaped)[2]
      bs_mult = py_utils.GetShape(inputs.query_vec_reshaped)[1]
      location_hidden = tf.reshape(location_hidden, [sl, bs_mult, sb, hd])

      # Shape of summed is [sl, tb/sb, sb, hidden_dim].
      summed = fns.qadd(
          inputs.concated_source_vecs,
          inputs.query_vec_reshaped,
          qout_name='logits_add')
      summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
      summed = fns.qtanh(summed)
      # logits is of shape [sl * tb/sb * sb, 1]: the dot product between v
      # and every row of 'summed'. The result is then reshaped to
      # [sl, tb/sb, sb].
      logits = fns.qmatmul(
          tf.reshape(summed, [-1, p.hidden_dim]),
          tf.reshape(inputs.hidden_v, [p.hidden_dim, 1]),
          qout_name='logits')
      logits = tf.reshape(logits, py_utils.GetShape(summed)[:3])
      return logits

    def AttenLogitsSameBatchSize(inputs):
      """Generates logits.

      Optimized code path for when the target and the source have the same batch
      size.

      Args:
        inputs: a NestedMap containing:
          - concated_source_vecs: Tensor of shape [sl, batch, dim]
          - query_vec_transformed: Tensor of shape [batch, dim]
          - hidden_v: Tensor of shape [dim]
          - location_feats: Tensor of shape [batch, location_feature_dim, sl]
          - location_var: Tensor of shape [location_feature_dim, dim]

      Returns:
        logits in the shape [sl, batch_size].
      """

      def CollapseOutDim(x):
        return tf.reshape(x, [-1, tf.shape(x)[-1]])

      fns = self.fns
      # [batch, location_feature_dim, sl] => [sl, batch, location_feature_dim]
      location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
      location_hidden = fns.qmatmul(
          CollapseOutDim(location_feats),
          inputs.location_var,
          qout_name='logits_mul')
      sl = tf.shape(location_feats)[0]
      tb = tf.shape(location_feats)[1]
      hd = tf.shape(inputs.location_var)[1]
      location_hidden = tf.reshape(location_hidden, [sl, tb, hd])

      # Shape of summed is [sl, sb, hidden_dim].
      summed = fns.qadd(
          inputs.concated_source_vecs,
          tf.expand_dims(inputs.query_vec_transformed, 0),
          qout_name='logits_add')

      summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
      summed = fns.qtanh(summed)

      # logits is of shape [sl * sb, 1]: the dot product between v
      # and every row of 'summed'. The result is then reshaped to
      # [sl, sb].
      logits = fns.qmatmul(
          tf.reshape(summed, [-1, p.hidden_dim]),
          tf.reshape(inputs.hidden_v, [p.hidden_dim, 1]),
          qout_name='logits')
      logits = tf.reshape(logits, py_utils.GetShape(summed)[:2])
      return logits

    def Atten(hidden_var, query_var, source_padding, concated_source_vecs,
              concated_source_contexts, query_vec, attention_state,
              location_filter_var, location_var, per_step_source_padding):
      """Computes the attention context vector."""
      p = self.params
      # attention_state is of shape [batch, len(p.location_features), slen].
      # It contains the previous and accumulated attention probabilities.
      attention_state = py_utils.HasShape(attention_state,
                                          [-1, len(p.location_features), -1])

      fns = self.fns
      location_feats = self._ApplyConv(attention_state, location_filter_var)
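      # location_feats is of shape [batch, location_feature_dim, sl].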

      # concated_source_vecs is of shape [sl, sb, dims]
      # concated_source_contexts is of shape [sb, sl, context_dim]
      # query_vec is of shape [tb, dims]
      sb = py_utils.GetShape(concated_source_vecs)[1]
      tb = py_utils.GetShape(query_vec)[0]
      multiplier = tb // sb
      # concated_source_vecs is reshaped to [sl, 1, sb, hidden_dims]
      concated_source_vecs = tf.expand_dims(concated_source_vecs, 1)
      query_vec_transformed = fns.qmatmul(
          query_vec, query_var, qout_name='atten_matmul')
      # query_vec_transformed is reshaped to [1, tb/sb, sb, hidden_dim].
      query_vec_reshaped = tf.reshape(query_vec_transformed,
                                      [1, multiplier, sb, p.hidden_dim])
      # logits is of shape [sl, tb/sb, sb]
      logits = _ConditionalCallDefun(
          self._is_quantized, AttenLogits,
          py_utils.NestedMap(
              concated_source_vecs=concated_source_vecs,
              query_vec_reshaped=query_vec_reshaped,
              hidden_v=hidden_var,
              location_feats=location_feats,
              location_var=location_var))
      # Take out the padding states.
      # source_padding is of shape [sl, sb]; reshape it to [sl, 1, sb].
      source_padding = tf.expand_dims(source_padding, 1)
      per_step_source_padding = tf.reshape(
          tf.transpose(per_step_source_padding), [-1, multiplier, sb])
      source_padding = tf.add(source_padding, per_step_source_padding)
      source_padding = self.QRAct(source_padding,
                                  quant_utils.QDistribution.PADDING)

      # Reshape logits to a matrix of shape [tb, sl] and take the
      # softmax to compute the probabilities.
      logits = tf.transpose(tf.reshape(logits, [-1, tb]))
      source_padding = tf.transpose(tf.reshape(source_padding, [-1, tb]))
      probs = self._PaddedSoftmax(logits, source_padding)
      # Reshape probs to be of shape [tb/sb, sb, sl].
      probs_reshaped = tf.reshape(probs, [multiplier, sb, -1])
      # Transpose probs to be of shape [sb, tb/sb, sl]
      probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
      # [sb, tb/sb, sl] * [sb, sl, context_dim] = [sb, tb/sb, context_dim]
      summed = fns.qbatchmatmul(
          tf.cast(probs_reshaped, concated_source_contexts.dtype),
          concated_source_contexts,
          qout_name='atten_context')
      # Transpose summed to be of shape [tb/sb, sb, context_dim].
      summed = tf.transpose(summed, [1, 0, 2])
      return tf.reshape(summed, [tb, -1]), probs

    def AttenSameBatchSize(hidden_var, query_var, source_padding,
                           concated_source_vecs, concated_source_contexts,
                           query_vec, attention_state, location_filter_var,
                           location_var, per_step_source_padding):
      """Computes the attention context vector.

      Optimized code path for when source and target have the same batch size.
      """
      del per_step_source_padding
      p = self.params
      # attention_state is of shape [batch, len(p.location_features), slen].
      # It contains the previous and accumulated attention probabilities.
      attention_state = py_utils.HasShape(attention_state,
                                          [-1, len(p.location_features), -1])

      fns = self.fns
      location_feats = self._ApplyConv(attention_state, location_filter_var)
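      # location_feats is of shape [batch, location_feature_dim, sl].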
      query_vec_transformed = fns.qmatmul(
          query_vec, query_var, qout_name='atten_matmul')
      # logits is of shape [sl, sb]
      logits = _ConditionalCallDefun(
          not self._is_quantized, AttenLogitsSameBatchSize,
          py_utils.NestedMap(
              concated_source_vecs=concated_source_vecs,
              query_vec_transformed=query_vec_transformed,
              hidden_v=hidden_var,
              location_feats=location_feats,
              location_var=location_var))
      # => [sl, tb]
      logits.set_shape(source_padding.shape)
      # Reshape logits to a matrix of shape [tb, sl] and take the
      # softmax to compute the probabilities.
      logits = tf.transpose(logits)
      source_padding = tf.transpose(source_padding)
      probs = self._PaddedSoftmax(logits, source_padding)
      summed = fns.qbatchmatmul(
          tf.cast(tf.expand_dims(probs, 1), concated_source_contexts.dtype),
          concated_source_contexts,
          qout_name='atten_context')
      return tf.squeeze(summed, 1), probs

    if p.same_batch_size:
      self._ctx_vec = AttenSameBatchSize
    else:
      self._ctx_vec = Atten

    def EncodeSource(src_w, vecs, ctxs):
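      """Projects source vecs with src_w; transposes ctxs to [batch, time, dim]."""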
      fns = self.fns
      time, batch = py_utils.GetShape(vecs, 2)
      ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
      transformed_vecs = tf.reshape(
          fns.qmatmul(
              tf.reshape(vecs, [-1, p.source_dim]),
              src_w,
              qout_name='encode_matmul'), [time, batch, -1])
      transposed_ctxs = tf.transpose(ctxs, [1, 0, 2])
      return transformed_vecs, transposed_ctxs

    self._encode_source = EncodeSource
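
For reference, the logits computed in AttenLogitsSameBatchSize are standard
additive, location-sensitive attention scores. Below is a minimal standalone
sketch of that computation in plain TensorFlow 2, with dummy shapes and without
the quantization-aware fns.q* wrappers; the variable names mirror the code
above, but the example itself is illustrative and not part of lingvo.

import tensorflow as tf

# Dummy sizes: sl = source length, batch, dim = hidden_dim,
# loc_dim = location_feature_dim.
sl, batch, dim, loc_dim = 7, 4, 16, 2
concated_source_vecs = tf.random.normal([sl, batch, dim])
query_vec_transformed = tf.random.normal([batch, dim])
hidden_v = tf.random.normal([dim])
location_feats = tf.random.normal([batch, loc_dim, sl])
location_var = tf.random.normal([loc_dim, dim])

# [batch, loc_dim, sl] => [sl, batch, loc_dim], then project to [sl, batch, dim].
loc = tf.transpose(location_feats, [2, 0, 1])
location_hidden = tf.reshape(
    tf.matmul(tf.reshape(loc, [-1, loc_dim]), location_var), [sl, batch, dim])

# Additive attention: v^T tanh(source + query + location).
summed = tf.tanh(concated_source_vecs +
                 query_vec_transformed[tf.newaxis, :, :] + location_hidden)
logits = tf.reshape(
    tf.matmul(tf.reshape(summed, [-1, dim]), tf.reshape(hidden_v, [dim, 1])),
    [sl, batch])
print(logits.shape)  # (7, 4)

The production code above differs mainly in routing every op through the
quantization-aware wrappers (fns.qmatmul, fns.qadd, fns.qtanh) and, in the
AttenLogits path, in supporting a target batch that is a multiple of the
source batch.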