def call()

in tensorflow_gnn/graph/keras/layers/gat_v2.py [0:0]


  def call(self, graph: gt.GraphTensor, *,
           edge_set_name: Optional[gt.EdgeSetName] = None,
           node_set_name: Optional[gt.NodeSetName] = None,
           receiver_tag: Optional[const.IncidentNodeOrContextTag] = None,
           training: Optional[bool] = None) -> gt.GraphTensor:
    """Resolves the graph pieces for one convolution and calls `_convolve()`.

    This method only does bookkeeping. It validates the arguments. It looks up
    the receiver piece (a NodeSet or the Context), the EdgeSet (if any) and the
    sender NodeSet (if any). It binds the broadcast/softmax/pool operations for
    that configuration and reads the configured input features. It then
    delegates the actual computation to `self._convolve()`.

    Args:
      graph: A scalar GraphTensor (checked by `gt.check_scalar_graph_tensor`).
      edge_set_name: The EdgeSet to convolve over. Required unless
        `receiver_tag` is `const.CONTEXT` and `node_set_name` is passed instead.
      node_set_name: Only allowed with `receiver_tag == const.CONTEXT`. It
        selects pooling from that NodeSet to the context without an EdgeSet.
      receiver_tag: The incident node tag (or `const.CONTEXT`) of the graph
        piece that receives the result. If None, the value set at init time
        is used. If both are set, they must agree.
      training: Passed through unchanged to `self._convolve()`.

    Returns:
      The result of `self._convolve()` for the resolved inputs.

    Raises:
      ValueError: If `receiver_tag` is unset both at init and call time, if it
        contradicts the init-time value, or if the combination of
        `edge_set_name` and `node_set_name` is invalid for `receiver_tag`.
    """
    # pylint: disable=g-long-lambda

    # Normalize inputs.
    gt.check_scalar_graph_tensor(graph, "GATv2Convolution")
    # TODO(b/205960151): make a helper for this or align with graph_ops.py
    if receiver_tag is None:
      if self._receiver_tag is None:
        raise ValueError("GATv2Convolution requires receiver_tag to be set "
                         "at init or call time")
      receiver_tag = self._receiver_tag
    else:
      if self._receiver_tag not in [None, receiver_tag]:
        # Note the trailing space after ")": without it, the two f-string
        # pieces run together in the message.
        raise ValueError(
            f"GATv2Convolution(..., receiver_tag={self._receiver_tag}) "
            f"was called with contradictory value receiver_tag={receiver_tag}")

    # Find the receiver graph piece (NodeSet or Context), the EdgeSet (if any)
    # and the sender NodeSet (if any) with its broadcasting function.
    if receiver_tag == const.CONTEXT:
      # Exactly one of the two names must be given; bools add as 0/1.
      if (edge_set_name is None) + (node_set_name is None) != 1:
        raise ValueError(
            "Must pass exactly one of edge_set_name, node_set_name "
            "for receiver_tag CONTEXT.")
      if edge_set_name is not None:
        # Pooling from EdgeSet to Context; no node set involved.
        name_kwarg = dict(edge_set_name=edge_set_name)
        edge_set = graph.edge_sets[edge_set_name]
        sender_node_set = None
        broadcast_from_sender_node = None
      else:
        # Pooling from NodeSet to Context, no EdgeSet involved.
        name_kwarg = dict(node_set_name=node_set_name)
        edge_set = None
        sender_node_set = graph.node_sets[node_set_name]
        # Values are computed per sender node, no need to broadcast
        broadcast_from_sender_node = lambda x: x
      receiver_piece = graph.context
    else:
      # Convolving from nodes to nodes.
      if edge_set_name is None or node_set_name is not None:
        raise ValueError("Must pass edge_set_name, not node_set_name")
      name_kwarg = dict(edge_set_name=edge_set_name)
      edge_set = graph.edge_sets[edge_set_name]
      # The sender is the other endpoint of the edge set.
      sender_node_tag = reverse_tag(receiver_tag)
      sender_node_set = graph.node_sets[
          edge_set.adjacency.node_set_name(sender_node_tag)]
      broadcast_from_sender_node = lambda x: ops.broadcast_node_to_edges(
          graph, edge_set_name, sender_node_tag, feature_value=x)
      receiver_piece = graph.node_sets[
          edge_set.adjacency.node_set_name(receiver_tag)]

    # Set up the broadcast/pool ops for the receiver. The tag and name arguments
    # conveniently encode the distinction between operating over edge/node,
    # node/context or edge/context.
    broadcast_from_receiver = lambda x: ops.broadcast(
        graph, receiver_tag, **name_kwarg, feature_value=x)
    # If the call/convolve split gets reused beyond this class, this shouldn't
    # be hardwired to softmax but support binding args for a custom (set of)
    # functions with this interface.
    softmax_per_receiver = lambda x: normalization_ops.softmax(
        graph, receiver_tag, **name_kwarg, feature_value=x)
    pool_to_receiver = lambda reduce_type, x: ops.pool(
        graph, receiver_tag, **name_kwarg, reduce_type=reduce_type,
        feature_value=x)

    # Set up the inputs. Use identity checks against None. A membership test
    # like `None in [piece, ...]` would fall back to `==` on the graph piece,
    # and graph-piece types may overload equality.
    receiver_input = receiver_piece[self._receiver_feature]
    if sender_node_set is not None and self._sender_node_feature is not None:
      sender_node_input = sender_node_set[self._sender_node_feature]
    else:
      sender_node_input = None
    if edge_set is not None and self._sender_edge_feature is not None:
      sender_edge_input = edge_set[self._sender_edge_feature]
    else:
      sender_edge_input = None

    return self._convolve(
        sender_node_input=sender_node_input,
        sender_edge_input=sender_edge_input,
        receiver_input=receiver_input,
        broadcast_from_sender_node=broadcast_from_sender_node,
        broadcast_from_receiver=broadcast_from_receiver,
        softmax_per_receiver=softmax_per_receiver,
        pool_to_receiver=pool_to_receiver,
        training=training)