in tensorflow_gnn/graph/keras/layers/gat_v2.py:
def __init__(self,
             *,
             num_heads: int,
             per_head_channels: int,
             receiver_tag: Optional[const.IncidentNodeOrContextTag] = None,
             receiver_feature: const.FieldName = const.DEFAULT_STATE_NAME,
             sender_node_feature: Optional[
                 const.FieldName] = const.DEFAULT_STATE_NAME,
             sender_edge_feature: Optional[const.FieldName] = None,
             use_bias: bool = True,
             edge_dropout: float = 0.,
             attention_activation: Union[str,
                                         Callable[..., Any]] = "leaky_relu",
             activation: Union[str, Callable[..., Any]] = "relu",
             kernel_initializer: Union[
                 None, str, tf.keras.initializers.Initializer] = None,
             **kwargs):
  kwargs.setdefault("name", "gat_v2_convolution")
  super().__init__(**kwargs)
  if num_heads <= 0:
    raise ValueError(f"Number of heads {num_heads} must be greater than 0.")
  self._num_heads = num_heads
  if per_head_channels <= 0:
    raise ValueError(
        f"Per-head channels {per_head_channels} must be greater than 0.")
  self._per_head_channels = per_head_channels
  self._receiver_tag = receiver_tag
  self._receiver_feature = receiver_feature
  self._sender_node_feature = sender_node_feature
  self._sender_edge_feature = sender_edge_feature
  self._use_bias = use_bias
  if not 0 <= edge_dropout < 1:
    raise ValueError(f"Edge dropout {edge_dropout} must be in [0, 1).")
  self._edge_dropout = edge_dropout
  self._attention_activation = tf.keras.activations.get(attention_activation)
  self._activation = tf.keras.activations.get(activation)
  self._kernel_initializer = kernel_initializer
  # Create the transformations for the query input in all heads.
  self._w_query = tf.keras.layers.Dense(
      per_head_channels * num_heads,
      kernel_initializer=kernel_initializer,
      # This bias gets added to the attention features but not the outputs.
      use_bias=use_bias,
      name="query")
  # Create the transformations for value input from sender nodes and edges.
  if sender_node_feature is not None:
    self._w_sender_node = tf.keras.layers.Dense(
        per_head_channels * num_heads,
        kernel_initializer=kernel_initializer,
        # This bias gets added to the attention features and the outputs.
        use_bias=use_bias,
        name="value_node")
  else:
    self._w_sender_node = None
  if sender_edge_feature is not None:
    self._w_sender_edge = tf.keras.layers.Dense(
        per_head_channels * num_heads,
        kernel_initializer=kernel_initializer,
        # This bias would be redundant with self._w_sender_node.
        use_bias=use_bias and self._w_sender_node is None,
        name="value_edge")
  else:
    self._w_sender_edge = None
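  # At least one value input is required; sender nodes only, sender edges
  # only, or both together are the supported combinations.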
  if self._w_sender_node is None and self._w_sender_edge is None:
    raise ValueError("GATv2Attention initialized with no inputs.")
  # Create attention logits layers, one for each head. Note that we can't
  # use a single Dense layer that outputs `num_heads` units because we need
  # to apply a different attention function a_k to its corresponding
  # W_k-transformed features.
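  # Concretely, the einsum "...ik,ki->...i" contracts the channels axis k
  # against one learned attention vector per head (kernel column i), so the
  # logit of head i is the dot product of that head's activated features
  # with its own vector a_i.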
  self._attention_logits_fn = tf.keras.layers.experimental.EinsumDense(
      "...ik,ki->...i",
      output_shape=(None, num_heads, 1),  # TODO(b/205825425): (num_heads,)
      kernel_initializer=kernel_initializer,
      name="attn_logits")