in tensorflow_recommenders/layers/feature_interaction/dot_interaction.py [0:0]
def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
  """Performs the interaction operation on the tensors in the list.

  The tensors represent transformed dense features and embedded categorical
  features.

  Pre-condition: The tensors should all have the same shape.

  Args:
    inputs: List of features with shapes [batch_size, feature_dim].

  Returns:
    activations: Tensor representing interacted features. Its second dimension
      is `num_features * num_features` if skip_gather is True, otherwise
      `num_features * (num_features + 1) / 2` if self_interaction is True and
      `num_features * (num_features - 1) / 2` if self_interaction is False.
  """
  num_features = len(inputs)
  batch_size = tf.shape(inputs[0])[0]
  feature_dim = tf.shape(inputs[0])[1]
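  # Note: batch_size and feature_dim are read as dynamic shapes from the
  # first input, so the layer also works when the batch size is not known
  # statically.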
  # concat_features shape: batch_size, num_features, feature_dim
  try:
    concat_features = tf.concat(inputs, axis=-1)
    concat_features = tf.reshape(concat_features,
                                 [batch_size, -1, feature_dim])
  except (ValueError, tf.errors.InvalidArgumentError) as e:
    raise ValueError(f"Input tensors' dimensions must be equal, original "
                     f"error message: {e}")

  # Interact features, select lower-triangular portion, and re-shape.
  xactions = tf.matmul(concat_features, concat_features, transpose_b=True)
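  # xactions shape: batch_size, num_features, num_features. Entry (i, j) is
  # the dot product between feature i and feature j of the same example.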
  ones = tf.ones_like(xactions)
  if self._self_interaction:
    # Select the lower-triangular portion including the diagonal.
    lower_tri_mask = tf.linalg.band_part(ones, -1, 0)
    upper_tri_mask = ones - lower_tri_mask
    out_dim = num_features * (num_features + 1) // 2
  else:
    # Select the lower-triangular portion excluding the diagonal.
    upper_tri_mask = tf.linalg.band_part(ones, 0, -1)
    lower_tri_mask = ones - upper_tri_mask
    out_dim = num_features * (num_features - 1) // 2

  if self._skip_gather:
    # Set the upper-triangular part of the interaction matrix to zeros and
    # keep the full matrix instead of gathering the lower-triangular entries.
    activations = tf.where(condition=tf.cast(upper_tri_mask, tf.bool),
                           x=tf.zeros_like(xactions),
                           y=xactions)
    out_dim = num_features * num_features
  else:
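    # tf.boolean_mask flattens the selected lower-triangular entries into a
    # single vector of length batch_size * out_dim; it is reshaped below.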
    activations = tf.boolean_mask(xactions, lower_tri_mask)
  activations = tf.reshape(activations, (batch_size, out_dim))
  return activations
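
A minimal usage sketch (assuming the layer is exported as
`tfrs.layers.feature_interaction.DotInteraction` with `self_interaction` and
`skip_gather` constructor flags; the random inputs are only for illustration):

import tensorflow as tf
import tensorflow_recommenders as tfrs

# Three features for a batch of two examples, all with feature_dim = 4.
dense = tf.random.normal([2, 4])
emb_a = tf.random.normal([2, 4])
emb_b = tf.random.normal([2, 4])

layer = tfrs.layers.feature_interaction.DotInteraction(
    self_interaction=False, skip_gather=False)
out = layer([dense, emb_a, emb_b])
# num_features = 3, so out.shape == [2, 3 * (3 - 1) // 2] == [2, 3].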