def _make_model_or_none()

in tensorflow_gnn/graph/keras/layers/map_features.py [0:0]


def _make_model_or_none(model_fn, graph_piece_spec, **kwargs):
  """Returns a Model to map this graph piece, or None to leave it alone."""
  if model_fn is None:
    return None  # This graph piece is to be left alone.

  graph_piece_input = tf.keras.layers.Input(type_spec=graph_piece_spec)
  raw_outputs = model_fn(graph_piece_input, **kwargs)
  if raw_outputs is None:
    return None  # As if model_fn were None to begin with.
  if isinstance(raw_outputs, Mapping):
    outputs = dict(raw_outputs)
  else:
    outputs = {const.DEFAULT_STATE_NAME: raw_outputs}

  non_keras_outputs = {k: v for k, v in outputs.items()
                       if not tf.keras.backend.is_keras_tensor(v)}
  if non_keras_outputs:
    raise ValueError(
        "MapFeatures(...=fn) requires the callback fn to return KerasTensors "
        "that depend on the input to fn. For values created from scratch, "
        "use tfgnn.keras.layers.TotalSize()(...) to get the (possibly static) "
        "output size with a proper dependency on the input.\n"
        f"The callback for {kwargs or 'context'} "
        f"returned the following non-KerasTensor outputs: {non_keras_outputs}")

  return tf.keras.Model(graph_piece_input, outputs)