def __call__()

in tensorflow_ranking/python/keras/model.py [0:0]


  def __call__(self) -> Tuple[TensorDict, TensorDict]:
    """Creates Keras input placeholders for context and example features.

    See `InputCreator`.

    Context features use the shape implied by their feature spec. Example
    features get one extra leading `None` dimension in front of that shape.

    Returns:
      A tuple `(context_inputs, example_inputs)`. Each is a dict that maps every
      feature name in `self._context_feature_spec` (respectively
      `self._example_feature_spec`) to a `tf.keras.Input` built from its spec.

    Raises:
      ValueError: If a feature spec is not a `tf.io.FixedLenFeature`,
        `tf.io.VarLenFeature` or `tf.io.RaggedFeature`.
    """

    def get_keras_input(feature_spec, name, is_example=False):
      """Returns a `tf.keras.Input` whose shape/dtype matches `feature_spec`.

      Args:
        feature_spec: A `tf.io.FixedLenFeature`, `tf.io.VarLenFeature` or
          `tf.io.RaggedFeature`.
        name: Name given to the Keras input.
        is_example: If True, prepend an extra unknown-sized dimension to the
          shape (used for example features).

      Raises:
        ValueError: If `feature_spec` is of an unsupported type.
      """
      # Example features carry one extra leading unknown-sized dimension.
      # `tf.keras.Input` shapes exclude the batch dimension.
      list_dim = (None,) if is_example else ()
      if isinstance(feature_spec, tf.io.FixedLenFeature):
        return tf.keras.Input(
            shape=list_dim + tuple(feature_spec.shape),
            dtype=feature_spec.dtype,
            name=name)
      elif isinstance(feature_spec, tf.io.VarLenFeature):
        # Must be the tuple `(1,)`; a bare `(1)` is just the int 1.
        return tf.keras.Input(
            shape=list_dim + (1,),
            dtype=feature_spec.dtype,
            name=name,
            sparse=True)
      elif isinstance(feature_spec, tf.io.RaggedFeature):
        # `len(partitions) + 1` unknown-sized dimensions: one per partition
        # plus one for the values.
        num_ragged_dims = len(feature_spec.partitions) + 1
        return tf.keras.Input(
            shape=list_dim + (None,) * num_ragged_dims,
            dtype=feature_spec.dtype,
            name=name,
            ragged=True)
      else:
        raise ValueError("{} is not supported.".format(feature_spec))

    context_inputs = {
        name: get_keras_input(spec, name)
        for name, spec in self._context_feature_spec.items()
    }
    example_inputs = {
        name: get_keras_input(spec, name, is_example=True)
        for name, spec in self._example_feature_spec.items()
    }
    return context_inputs, example_inputs