in tensorflow_ranking/python/keras/model.py [0:0]
def __call__(self) -> Tuple[TensorDict, TensorDict]:
  """See `InputCreator`.

  Builds one `tf.keras.Input` per entry of `self._context_feature_spec` and
  `self._example_feature_spec`.

  Returns:
    A `(context_inputs, example_inputs)` tuple. Each element is a dict that
    maps a feature name to the `tf.keras.Input` created for its spec.

  Raises:
    ValueError: If any spec is not a `tf.io.FixedLenFeature`,
      `tf.io.VarLenFeature` or `tf.io.RaggedFeature`.
  """

  def get_keras_input(feature_spec, name: str, is_example: bool = False):
    """Returns a `tf.keras.Input` whose shape/dtype match `feature_spec`.

    Args:
      feature_spec: A `tf.io.FixedLenFeature`, `tf.io.VarLenFeature` or
        `tf.io.RaggedFeature`.
      name: Name given to the created Keras input.
      is_example: If True, one extra unknown-size (`None`) dimension is
        prepended to the per-feature shape (after Keras' implicit batch
        dimension) -- presumably the list/example dimension; confirm against
        the `InputCreator` contract.

    Raises:
      ValueError: If `feature_spec` is of an unsupported type.
    """
    # The only shape difference between example and context features is this
    # leading `None` dimension, so compute it once for every branch.
    leading_dims = (None,) if is_example else ()

    if isinstance(feature_spec, tf.io.FixedLenFeature):
      return tf.keras.Input(
          shape=leading_dims + tuple(feature_spec.shape),
          dtype=feature_spec.dtype,
          name=name)

    if isinstance(feature_spec, tf.io.VarLenFeature):
      # `(1,)` is a one-element tuple; the previous `(1)` was just the int 1,
      # which not every Keras version accepts as a `shape`.
      return tf.keras.Input(
          shape=leading_dims + (1,),
          dtype=feature_spec.dtype,
          name=name,
          sparse=True)

    if isinstance(feature_spec, tf.io.RaggedFeature):
      # All `len(partitions) + 1` per-feature dimensions are left unknown.
      ragged_dims = (None,) * (len(feature_spec.partitions) + 1)
      return tf.keras.Input(
          shape=leading_dims + ragged_dims,
          dtype=feature_spec.dtype,
          name=name,
          ragged=True)

    raise ValueError("{} is not supported.".format(feature_spec))

  context_inputs = {
      name: get_keras_input(spec, name)
      for name, spec in self._context_feature_spec.items()
  }
  example_inputs = {
      name: get_keras_input(spec, name, is_example=True)
      for name, spec in self._example_feature_spec.items()
  }
  return context_inputs, example_inputs