def call()

in tensorflow_gnn/graph/keras/layers/next_state.py [0:0]


  def call(
      self, inputs: Tuple[
          const.FieldOrFields, const.FieldsNest, const.FieldsNest
      ]) -> const.FieldOrFields:
    """Computes the residual next state of the updated graph piece.

    All tensors in `inputs` are flattened (in `tf.nest` order), concatenated
    along the last axis and fed through `self._residual_block`. The result is
    added to the skip connection feature and passed through
    `self._activation`.

    The skip connection feature is taken from `inputs[0]`. If that is a single
    `tf.Tensor` or `tf.RaggedTensor`, it is used directly. Otherwise it is
    looked up as `inputs[0][self._skip_connection_feature_name]`.

    Args:
      inputs: A triple whose first element holds the input feature(s) of the
        updated graph piece. The other two elements are nests of further
        features. All three are concatenated into the input of the residual
        block.

    Returns:
      `self._activation(self._residual_block(concatenated_inputs) + skip)`,
      where `skip` is the skip connection feature described above.

    Raises:
      KeyError: if `inputs[0]` is not a single tensor and has no feature named
        `self._skip_connection_feature_name`.
      ValueError: if the output of `self._residual_block` has a shape that is
        not compatible with the shape of the skip connection feature.
    """
    state = inputs[0]
    if not isinstance(state, (tf.Tensor, tf.RaggedTensor)):
      # Not a single tensor: look up the skip connection feature by name.
      feature_name = self._skip_connection_feature_name
      try:
        skip = state[feature_name]
      except KeyError as err:
        raise KeyError(
            "ResidualNextState() could not find the "
            f"skip connection feature '{feature_name}' "
            f"in the features of the updated graph piece: {list(state)}"
        ) from err
      skip_source = f"input feature '{feature_name}'"
    else:
      skip = state
      skip_source = "single input"

    # The update is computed from the concatenation of *all* inputs, not just
    # from the skip connection feature.
    flat_inputs = tf.nest.flatten(inputs)
    update = self._residual_block(tf.concat(flat_inputs, axis=-1))

    # The residual sum below requires the update to match the input state.
    if not skip.shape.is_compatible_with(update.shape):
      raise ValueError(
          "A ResidualNextState() requires an update_fn whose "
          "output has the same shape as the input state, but got "
          f"output shape {update.shape.as_list()} vs "
          f"input shape {skip.shape.as_list()} "
          f"from {skip_source}.")

    return self._activation(tf.add(update, skip))