def __init__()

in tensorflow_gnn/graph/keras/layers/gat_v2.py


  def __init__(self,
               *,
               num_heads: int,
               per_head_channels: int,
               receiver_tag: Optional[const.IncidentNodeOrContextTag] = None,
               receiver_feature: const.FieldName = const.DEFAULT_STATE_NAME,
               sender_node_feature: Optional[
                   const.FieldName] = const.DEFAULT_STATE_NAME,
               sender_edge_feature: Optional[const.FieldName] = None,
               use_bias: bool = True,
               edge_dropout: float = 0.,
               attention_activation: Union[str,
                                           Callable[..., Any]] = "leaky_relu",
               activation: Union[str, Callable[..., Any]] = "relu",
               kernel_initializer: Union[
                   None, str, tf.keras.initializers.Initializer] = None,
               **kwargs):
    kwargs.setdefault("name", "gat_v2_convolution")
    super().__init__(**kwargs)

    if num_heads <= 0:
      raise ValueError(f"Number of heads {num_heads} must be greater than 0.")
    self._num_heads = num_heads

    if per_head_channels <= 0:
      raise ValueError(
          f"Per-head channels {per_head_channels} must be greater than 0.")
    self._per_head_channels = per_head_channels

    self._receiver_tag = receiver_tag
    self._receiver_feature = receiver_feature
    self._sender_node_feature = sender_node_feature
    self._sender_edge_feature = sender_edge_feature
    self._use_bias = use_bias

    if not 0 <= edge_dropout < 1:
      raise ValueError(f"Edge dropout {edge_dropout} must be in [0, 1).")
    self._edge_dropout = edge_dropout

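    # `attention_activation` is the nonlinearity used inside the attention
    # score (LeakyReLU by default, as in the GATv2 paper); `activation` is
    # presumably applied to the pooled result in call() (not shown here).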
    self._attention_activation = tf.keras.activations.get(attention_activation)
    self._activation = tf.keras.activations.get(activation)
    self._kernel_initializer = kernel_initializer

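    # The weights below set up multi-head GATv2 scoring, roughly
    #   logit_k = a_k^T attention_activation(W_query x_receiver + W_value x_sender)
    # for each head k, with the per-head vectors a_k realized further below
    # by `self._attention_logits_fn`.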
    # Create the transformations for the query input in all heads.
    self._w_query = tf.keras.layers.Dense(
        per_head_channels * num_heads,
        kernel_initializer=kernel_initializer,
        # This bias gets added to the attention features but not the outputs.
        use_bias=use_bias,
        name="query")

    # Create the transformations for value input from sender nodes and edges.
    if sender_node_feature is not None:
      self._w_sender_node = tf.keras.layers.Dense(
          per_head_channels * num_heads,
          kernel_initializer=kernel_initializer,
          # This bias gets added to the attention features and the outputs.
          use_bias=use_bias,
          name="value_node")
    else:
      self._w_sender_node = None
    if sender_edge_feature is not None:
      self._w_sender_edge = tf.keras.layers.Dense(
          per_head_channels * num_heads,
          kernel_initializer=kernel_initializer,
          # This bias would be redundant with self._w_sender_node.
          use_bias=use_bias and self._w_sender_node is None,
          name="value_edge")
    else:
      self._w_sender_edge = None
    if self._w_sender_node is None and self._w_sender_edge is None:
      raise ValueError("GATv2Attention initialized with no inputs.")

    # Create attention logits layers, one for each head. Note that we can't
    # use a single Dense layer that outputs `num_heads` units because we need
    # to apply a different attention function a_k to its corresponding
    # W_k-transformed features.
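    # In the einsum "...ik,ki->...i", i indexes heads and k indexes per-head
    # channels, so column i of the kernel acts as the attention vector a_i.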
    self._attention_logits_fn = tf.keras.layers.experimental.EinsumDense(
        "...ik,ki->...i",
        output_shape=(None, num_heads, 1),  # TODO(b/205825425): (num_heads,)
        kernel_initializer=kernel_initializer,
        name="attn_logits")