def call()

in tensorflow_recommenders/layers/feature_interaction/dot_interaction.py [0:0]


  def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
    """Performs the interaction operation on the tensors in the list.

    The tensors represent transformed dense features and embedded categorical
    features. Every pair of features is interacted via a dot product, and the
    lower-triangular portion of the resulting interaction matrix is kept.
    Pre-condition: The tensors should all have the same shape.

    Args:
      inputs: List of features with shapes [batch_size, feature_dim].

    Returns:
      activations: Tensor of shape [batch_size, out_dim] representing the
      interacted features, where `out_dim` is `num_features * num_features`
      if skip_gather is True, otherwise
      `num_features * (num_features + 1) / 2` if self_interaction is True and
      `num_features * (num_features - 1) / 2` if self_interaction is False.

    Raises:
      ValueError: If `inputs` is empty, or if the input tensors cannot be
        stacked because their shapes differ.
    """
    num_features = len(inputs)
    if num_features == 0:
      raise ValueError("Expected at least one input tensor, got none.")

    # stacked_features shape: [batch_size, num_features, feature_dim].
    # tf.stack rejects inputs whose shapes differ, whereas concatenating along
    # the last axis and reshaping could silently regroup features of unequal
    # width into the wrong number of features.
    try:
      stacked_features = tf.stack(inputs, axis=1)
    except (ValueError, tf.errors.InvalidArgumentError) as e:
      raise ValueError(f"Input tensors' dimensions must be equal, original "
                       f"error message: {e}") from e
    batch_size = tf.shape(stacked_features)[0]

    # xactions shape: [batch_size, num_features, num_features]; entry (i, j)
    # is the dot product of feature i with feature j.
    xactions = tf.matmul(stacked_features, stacked_features, transpose_b=True)

    # The selection mask depends only on num_features, so build it once as a
    # static [num_features, num_features] boolean matrix rather than a
    # batch-sized tensor. Entry (i, j) is in the lower triangle when i >= j.
    row_ids = tf.range(num_features)[:, tf.newaxis]
    col_ids = tf.range(num_features)[tf.newaxis, :]
    if self._self_interaction:
      # Lower-triangular portion including the diagonal.
      lower_tri_mask = row_ids >= col_ids
      out_dim = num_features * (num_features + 1) // 2
    else:
      # Lower-triangular portion excluding the diagonal.
      lower_tri_mask = row_ids > col_ids
      out_dim = num_features * (num_features - 1) // 2

    if self._skip_gather:
      # Keep the full matrix but set everything outside the selected
      # lower-triangular portion to zero; the mask broadcasts over the batch.
      activations = tf.where(lower_tri_mask, xactions, tf.zeros_like(xactions))
      out_dim = num_features * num_features
    else:
      # Gather the selected entries of each example's interaction matrix in
      # row-major order; axis=1 applies the [n, n] mask to dims 1 and 2.
      activations = tf.boolean_mask(xactions, lower_tri_mask, axis=1)
    return tf.reshape(activations, (batch_size, out_dim))