def softmax()

in tensorflow_gnn/graph/normalization_ops.py

import functools
from typing import Optional

import tensorflow as tf

from tensorflow_gnn.graph import graph_constants as const
from tensorflow_gnn.graph import graph_tensor as gt
from tensorflow_gnn.graph import graph_tensor_ops as ops

def softmax(
    graph_tensor: gt.GraphTensor,
    per_tag: const.IncidentNodeOrContextTag,
    *,
    edge_set_name: Optional[const.EdgeSetName] = None,
    node_set_name: Optional[const.NodeSetName] = None,
    feature_value: Optional[gt.Field] = None,
    feature_name: Optional[gt.FieldName] = None) -> gt.Field:
  """Computes softmax over a many-to-one relationship in a GraphTensor.

  This function can be used to compute a softmax normalization...
    * of edge values, across the edges with a common incident node at `per_tag`
      (e.g., SOURCE or TARGET);
    * of node values, across all the nodes in the same graph component;
    * of edge values, across all the edges in the same graph component.

  For non-scalar values, the softmax function is applied element-wise.
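
  For example, attention weights over the incoming edges of each target node
  could be computed from a hypothetical edge set "edges" with a scalar
  feature "logit" (names are illustrative only, not part of the API):

      attention = softmax(graph, const.TARGET,
                          edge_set_name="edges", feature_name="logit")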

  Args:
    graph_tensor: A scalar GraphTensor.
    per_tag: tfgnn.CONTEXT for normalization per graph component, or an incident
      node tag (e.g., tfgnn.SOURCE or tfgnn.TARGET) for normalization per
      common incident node.
    edge_set_name: The name of the edge set on which values are normalized.
      Exactly one of edge_set_name and node_set_name must be set.
    node_set_name: The name of the node set on which values are normalized,
      allowed only if per_tag is CONTEXT. See also edge_set_name.
    feature_value: A ragged or dense tensor with the value; cf. feature_name.
    feature_name: The name of the feature to be used as input value.
      Exactly one of feature_value or feature_name must be set.

  Raises:
    ValueError: if `graph_tensor` does not contain an edge set or node set
      of the given name.

  Returns:
    The softmaxed values. The dimensions do not change from the input.
  """
  # Set up the `value` to be softmaxed with generic `pool` and `broadcast`.
  if bool(edge_set_name is None) + bool(node_set_name is None) != 1:
    raise ValueError("Must pass exactly one of edge_set_name, node_set_name.")
  if edge_set_name is not None:
    value = ops.resolve_value(
        graph_tensor.edge_sets[edge_set_name],
        feature_value=feature_value, feature_name=feature_name)
    pool = functools.partial(
        ops.pool, graph_tensor, per_tag, edge_set_name=edge_set_name)
    broadcast = functools.partial(
        ops.broadcast, graph_tensor, per_tag, edge_set_name=edge_set_name)
  else:
    value = ops.resolve_value(
        graph_tensor.node_sets[node_set_name],
        feature_value=feature_value, feature_name=feature_name)
    pool = functools.partial(
        ops.pool, graph_tensor, per_tag, node_set_name=node_set_name)
    broadcast = functools.partial(
        ops.broadcast, graph_tensor, per_tag, node_set_name=node_set_name)
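  # In either case, `pool` aggregates values to the nodes (or graph
  # components) selected by `per_tag`, and `broadcast` copies each aggregate
  # back to the edges (or nodes) it was computed from.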

  # Compute softmax. Subtract the maxes for numerical stability.
  # Some segment_maxes may be -inf, but that's broadcast nowhere.
  segment_maxes = pool(reduce_type="max", feature_value=value)
  maxes = broadcast(feature_value=segment_maxes)
  exp_value = tf.exp(value - maxes)
  sum_exp_value = pool(reduce_type="sum", feature_value=exp_value)
  return exp_value / broadcast(feature_value=sum_exp_value)
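

For intuition, the core of the computation above reduces to a segment softmax.
The following is a minimal, self-contained sketch (an illustration, not part
of this module) of the edge-to-node case with a dense scalar feature and a
single graph; `ops.pool` and `ops.broadcast` generalize it to ragged
features, multiple graph components, and pooling to context:

import tensorflow as tf

def segment_softmax(values, segment_ids, num_segments):
  # Subtract the per-segment max for numerical stability.
  maxes = tf.math.unsorted_segment_max(values, segment_ids, num_segments)
  exp_values = tf.exp(values - tf.gather(maxes, segment_ids))
  # Divide by the per-segment sum of exponentials.
  sums = tf.math.unsorted_segment_sum(exp_values, segment_ids, num_segments)
  return exp_values / tf.gather(sums, segment_ids)

# Four edge logits, softmaxed over the edges incident to each of two
# target nodes (edges 0,1 point at node 0; edges 2,3 at node 1).
logits = tf.constant([1.0, 2.0, 3.0, 4.0])
targets = tf.constant([0, 0, 1, 1])
print(segment_softmax(logits, targets, num_segments=2))
# ~> [0.269 0.731 0.269 0.731]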