in tensorflow_gnn/graph/keras/layers/map_features.py [0:0]
def _make_model_or_none(model_fn, graph_piece_spec, **kwargs):
  """Returns a Model to map this graph piece, or None to leave it alone.

  Args:
    model_fn: The callback for this graph piece, or None. If not None, it is
      called on a Keras Input built from `graph_piece_spec` (plus `**kwargs`).
      It may return None, a single KerasTensor, or a Mapping of names to
      KerasTensors, all of which must depend on that Input.
    graph_piece_spec: The type spec used to create the Keras Input that is
      passed to `model_fn`.
    **kwargs: Forwarded to `model_fn`. They are also shown in the error
      message to identify the callback (empty kwargs are reported as
      'context').

  Returns:
    None if `model_fn` is None or returns None. Otherwise, a `tf.keras.Model`
    from the Input to a dict of outputs. A non-Mapping result is stored
    under the key `const.DEFAULT_STATE_NAME`.

  Raises:
    ValueError: if any output returned by `model_fn` is not a KerasTensor.
  """
  if model_fn is None:
    return None  # This graph piece is to be left alone.
  graph_piece_input = tf.keras.layers.Input(type_spec=graph_piece_spec)
  raw_outputs = model_fn(graph_piece_input, **kwargs)
  if raw_outputs is None:
    return None  # As if model_fn were None to begin with.
  if isinstance(raw_outputs, Mapping):
    outputs = dict(raw_outputs)
  else:
    outputs = {const.DEFAULT_STATE_NAME: raw_outputs}

  def _is_keras_tensor(value):
    # tf.keras.backend.is_keras_tensor() raises ValueError, instead of
    # returning False, for values that are not tensors at all (e.g. Python
    # numbers, NumPy arrays, None). Treat those as non-KerasTensors so the
    # user gets the explanatory error below, not the generic backend one.
    try:
      return tf.keras.backend.is_keras_tensor(value)
    except ValueError:
      return False

  non_keras_outputs = {k: v for k, v in outputs.items()
                       if not _is_keras_tensor(v)}
  if non_keras_outputs:
    raise ValueError(
        "MapFeatures(...=fn) requires the callback fn to return KerasTensors "
        "that depend on the input to fn. For values created from scratch, "
        "use tfgnn.keras.layers.TotalSize()(...) to get the (possibly static) "
        "output size with a proper dependency on the input.\n"
        f"The callback for {kwargs or 'context'} "
        f"returned the following non-KerasTensor outputs: {non_keras_outputs}")
  return tf.keras.Model(graph_piece_input, outputs)