in tensorflow_gnn/graph/keras/layers/gat_v2.py [0:0]
def call(self, graph: gt.GraphTensor, *,
         edge_set_name: Optional[gt.EdgeSetName] = None,
         node_set_name: Optional[gt.NodeSetName] = None,
         receiver_tag: Optional[const.IncidentNodeOrContextTag] = None,
         training: bool = None) -> gt.GraphTensor:
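  """Runs the convolution towards the receiver and returns the pooled result.

  The arguments select exactly one of three modes: convolving over an edge
  set towards one of its incident node sets, pooling an edge set to context,
  or pooling a node set to context.

  Args:
    graph: A scalar GraphTensor with the edge set or node set to operate on.
    edge_set_name: The name of the edge set to convolve over or to pool to
      context.
    node_set_name: The name of the node set to pool to context; allowed only
      with receiver_tag=const.CONTEXT and instead of edge_set_name.
    receiver_tag: const.SOURCE, const.TARGET or const.CONTEXT; must be set
      here or at init time, but not to contradictory values.
    training: The Keras training flag, forwarded to `_convolve()`.

  Returns:
    The result of `_convolve()` for the selected receiver.
  """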
  # pylint: disable=g-long-lambda
  # Normalize inputs.
  gt.check_scalar_graph_tensor(graph, "GATv2Convolution")
  # TODO(b/205960151): make a helper for this or align with graph_ops.py
  if receiver_tag is None:
    if self._receiver_tag is None:
      raise ValueError("GATv2Convolution requires receiver_tag to be set "
                       "at init or call time")
    receiver_tag = self._receiver_tag
  else:
    if self._receiver_tag not in [None, receiver_tag]:
      raise ValueError(
          f"GATv2Convolution(..., receiver_tag={self._receiver_tag}) "
          f"was called with contradictory value receiver_tag={receiver_tag}")
  # Find the receiver graph piece (NodeSet or Context), the EdgeSet (if any)
  # and the sender NodeSet (if any) with its broadcasting function.
  if receiver_tag == const.CONTEXT:
    if (edge_set_name is None) + (node_set_name is None) != 1:
      raise ValueError(
          "Must pass exactly one of edge_set_name, node_set_name "
          "for receiver_tag CONTEXT.")
    if edge_set_name is not None:
      # Pooling from EdgeSet to Context; no node set involved.
      name_kwarg = dict(edge_set_name=edge_set_name)
      edge_set = graph.edge_sets[edge_set_name]
      sender_node_set = None
      broadcast_from_sender_node = None
    else:
      # Pooling from NodeSet to Context; no EdgeSet involved.
      name_kwarg = dict(node_set_name=node_set_name)
      edge_set = None
      sender_node_set = graph.node_sets[node_set_name]
      # Values are computed per sender node, no need to broadcast.
      broadcast_from_sender_node = lambda x: x
    receiver_piece = graph.context
  else:
    # Convolving from nodes to nodes.
    if edge_set_name is None or node_set_name is not None:
      raise ValueError("Must pass edge_set_name, not node_set_name")
    name_kwarg = dict(edge_set_name=edge_set_name)
    edge_set = graph.edge_sets[edge_set_name]
    sender_node_tag = reverse_tag(receiver_tag)
    sender_node_set = graph.node_sets[
        edge_set.adjacency.node_set_name(sender_node_tag)]
    broadcast_from_sender_node = lambda x: ops.broadcast_node_to_edges(
        graph, edge_set_name, sender_node_tag, feature_value=x)
    receiver_piece = graph.node_sets[
        edge_set.adjacency.node_set_name(receiver_tag)]
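
  # Either branch above has set `name_kwarg` to identify the edge set or node
  # set to operate on, plus `receiver_piece` and the possibly-None `edge_set`,
  # `sender_node_set` and `broadcast_from_sender_node`.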

  # Set up the broadcast/pool ops for the receiver. The tag and name arguments
  # conveniently encode the distinction between operating over edge/node,
  # node/context or edge/context.
  broadcast_from_receiver = lambda x: ops.broadcast(
      graph, receiver_tag, **name_kwarg, feature_value=x)
  # If the call/convolve split gets reused beyond this class, this shouldn't
  # be hardwired to softmax but support binding args for a custom (set of)
  # functions with this interface.
  softmax_per_receiver = lambda x: normalization_ops.softmax(
      graph, receiver_tag, **name_kwarg, feature_value=x)
  pool_to_receiver = lambda reduce_type, x: ops.pool(
      graph, receiver_tag, **name_kwarg, reduce_type=reduce_type,
      feature_value=x)
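
  # These lambdas close over `graph`, `receiver_tag` and `name_kwarg`, so that
  # `_convolve()` can broadcast, normalize and pool per receiver without
  # dealing with GraphTensor structure itself.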

  # Set up the inputs.
  receiver_input = receiver_piece[self._receiver_feature]
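  # A sender input is read only if both the graph piece and the feature name
  # configured at init time are present; otherwise it stays None and is passed
  # to `_convolve()` as such.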
  if None not in [sender_node_set, self._sender_node_feature]:
    sender_node_input = sender_node_set[self._sender_node_feature]
  else:
    sender_node_input = None
  if None not in [edge_set, self._sender_edge_feature]:
    sender_edge_input = edge_set[self._sender_edge_feature]
  else:
    sender_edge_input = None

  return self._convolve(
      sender_node_input=sender_node_input,
      sender_edge_input=sender_edge_input,
      receiver_input=receiver_input,
      broadcast_from_sender_node=broadcast_from_sender_node,
      broadcast_from_receiver=broadcast_from_receiver,
      softmax_per_receiver=softmax_per_receiver,
      pool_to_receiver=pool_to_receiver,
      training=training)
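
# Usage sketch (not part of this file; assumes a constructed GATv2Convolution
# instance `conv` and a scalar GraphTensor `graph` with an edge set "edges"
# and a node set "docs"):
#
#   conv(graph, edge_set_name="edges", receiver_tag=const.TARGET)   # nodes -> nodes
#   conv(graph, edge_set_name="edges", receiver_tag=const.CONTEXT)  # edges -> context
#   conv(graph, node_set_name="docs", receiver_tag=const.CONTEXT)   # nodes -> context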