in tensorflow_gnn/graph/keras/layers/next_state.py [0:0]
def call(
    self, inputs: Tuple[
        const.FieldOrFields, const.FieldsNest, const.FieldsNest
    ]) -> const.FieldOrFields:
    """Returns `activation(residual_block(concat(inputs)) + skip_feature)`.

    Every tensor in `inputs` (including the first element itself) is
    flattened with `tf.nest.flatten`, concatenated along the last axis and
    passed through `self._residual_block`. The block's output is added to a
    skip connection feature taken from `inputs[0]`, and `self._activation`
    is applied to the sum.

    Args:
      inputs: A triple whose first element is the input of the graph piece
        being updated. That element is either a single `tf.Tensor` /
        `tf.RaggedTensor`, which is used directly as the skip connection
        feature, or a collection of features indexed by
        `self._skip_connection_feature_name` (presumably a dict of features;
        TODO confirm). The other two elements are further nests of input
        tensors that only contribute to the concatenated block input.

    Returns:
      The activated sum of the residual block output and the skip
      connection feature.

    Raises:
      KeyError: if `inputs[0]` is not a single tensor and has no feature
        named `self._skip_connection_feature_name`.
      ValueError: if the residual block output's shape is not compatible
        with the skip connection feature's shape.
    """
    piece_input = inputs[0]

    # Select the tensor that bypasses the residual block, and remember where
    # it came from so that a shape mismatch can be reported precisely.
    if not isinstance(piece_input, (tf.Tensor, tf.RaggedTensor)):
        feature_name = self._skip_connection_feature_name
        try:
            shortcut = piece_input[feature_name]
        except KeyError as missing:
            raise KeyError(
                "ResidualNextState() could not find the "
                f"skip connection feature '{feature_name}' "
                f"in the features of the updated graph piece: {list(piece_input)}"
            ) from missing
        shortcut_origin = f"input feature '{feature_name}'"
    else:
        shortcut = piece_input
        shortcut_origin = "single input"

    # The block sees all inputs side by side, inputs[0] included.
    update = self._residual_block(tf.concat(tf.nest.flatten(inputs), axis=-1))

    # The residual sum is only well-defined when the block preserves shape.
    if shortcut.shape.is_compatible_with(update.shape):
        return self._activation(tf.add(update, shortcut))

    raise ValueError(
        "A ResidualNextState() requires an update_fn whose "
        "output has the same shape as the input state, but got "
        f"output shape {update.shape.as_list()} vs "
        f"input shape {shortcut.shape.as_list()} "
        f"from {shortcut_origin}.")