def _convolve()

in tensorflow_gnn/graph/keras/layers/gat_v2.py [0:0]


  def _convolve(self, *,
                sender_node_input: Optional[tf.Tensor],
                sender_edge_input: Optional[tf.Tensor],
                receiver_input: tf.Tensor,
                broadcast_from_sender_node: Callable[[tf.Tensor], tf.Tensor],
                broadcast_from_receiver: Callable[[tf.Tensor], tf.Tensor],
                softmax_per_receiver: Callable[[tf.Tensor], tf.Tensor],
                pool_to_receiver: Callable[[str, tf.Tensor], tf.Tensor],
                training: bool) -> tf.Tensor:
    """Computes the GATv2 attention convolution for one receiver set.

    Input Tensors arrive in their original shapes and are broadcast so that
    their leading dimension indexes the items being attended to (typically
    edges; nodes when convolving towards context). At the end, the weighted
    values are pooled back into a Tensor whose leading dimension indexes
    receivers, via `pool_to_receiver`.

    Args:
      sender_node_input: The input Tensor from the sender NodeSet, or None.
        See broadcast_from_sender_node.
      sender_edge_input: The input Tensor from the sender EdgeSet, or None.
        If present, this Tensor is already indexed by the items to attend to.
      receiver_input: The input Tensor from the receiver NodeSet or Context.
        See broadcast_from_receiver.
      broadcast_from_sender_node: A function that broadcasts a Tensor
        indexed like sender_node_input to a Tensor indexed by the items
        that are attended to.
      broadcast_from_receiver: A function that broadcasts a Tensor
        indexed like receiver_input to a Tensor indexed by the items
        that are attended to.
      softmax_per_receiver: A function that accepts an item-indexed tensor,
        applies softmax normalization to values with a common receiver and
        same trailing indices, and returns the result with unchanged shape.
      pool_to_receiver: A function that pools an item-indexed Tensor to a
        receiver-indexed tensor by summation across items with the same
        receiver.
      training: A boolean. If true, compute the result of training rather
        than inference.

    Returns:
      A Tensor whose leading dimension is indexed by receivers, with the
      result of the convolution.
    """
    # The per-head attention query, broadcast out to the attended items.
    # Shape: [num_items, *extra_dims, num_heads, channels_per_head].
    query = broadcast_from_receiver(
        self._split_heads(self._w_query(receiver_input)))
    # TODO(b/205960151): Optionally include a context feature.

    # The attention value is the sum of the configured inputs, each one
    # transformed and brought to item-indexed shape first.
    # Shape: [num_items, *extra_dims, num_heads, channels_per_head].
    transformed_inputs = []
    if sender_node_input is not None:
      node_term = self._split_heads(self._w_sender_node(sender_node_input))
      transformed_inputs.append(broadcast_from_sender_node(node_term))
    if sender_edge_input is not None:
      transformed_inputs.append(
          self._split_heads(self._w_sender_edge(sender_edge_input)))
    assert transformed_inputs, (
        "Internal error: no values, __init__ should catch this.")
    combined_value = tf.add_n(transformed_inputs)

    # GATv2 applies the activation to query + value BEFORE the logits,
    # unlike the original GAT.
    # Shape: [num_items, *extra_dims, num_heads, channels_per_head].
    pre_logit_features = self._attention_activation(query + combined_value)

    # Per-item attention logits, with a trailing size-1 channels axis so they
    # broadcast against the values, then softmax across items that share a
    # receiver. Shape: [num_items, *extra_dims, num_heads, 1].
    raw_logits = tf.expand_dims(
        self._attention_logits_fn(pre_logit_features), axis=-1)
    coefficients = softmax_per_receiver(raw_logits)

    if training:
      # Drop out normalized attention coefficients, as done in the original
      # GAT paper; this acts like edge dropout. tf.nn.dropout rescales the
      # kept values, so per-receiver sums stay 1 in expectation.
      coefficients = tf.nn.dropout(coefficients, self._edge_dropout)

    # Weight the values by the attention coefficients and sum per receiver
    # (the sum of softmax-weighted values is their weighted average).
    # Receivers without incoming items receive the empty sum, i.e., zeros.
    # Shape: [num_receivers, *extra_dims, num_heads, channels_per_head].
    weighted_values = combined_value * coefficients
    pooled = pool_to_receiver("sum", weighted_values)
    # Final nonlinearity, then the heads are merged back into one axis.
    pooled = self._activation(pooled)
    return self._merge_heads(pooled)