in tensorflow_ranking/python/keras/feature.py [0:0]
def call(self, inputs, training=None):
  """Transforms the features into dense context features and example features.

  This is the Keras equivalent of `tfr.feature.encode_listwise_features`.

  Args:
    inputs: (dict) Features with a mix of context (2D) and example features
      (3D).
    training: (bool) whether in train or inference mode.

  Returns:
    context_features: (dict) context feature names to dense 2D tensors of
      shape [batch_size, feature_dims].
    example_features: (dict) example feature names to dense 3D tensors of
      shape [batch_size, list_size, feature_dims].

  Raises:
    KeyError: if example feature columns are configured but none of the keys
      from their parsing specs is present in `inputs`.
  """
  features = inputs

  context_features = {}
  if self._context_feature_columns:
    # Populated by the dense layer, keyed by feature column; re-keyed below by
    # the names used in `_context_feature_columns`.
    context_cols_to_tensors = {}
    self._context_dense_layer(
        features,
        training=training,
        cols_to_output_tensors=context_cols_to_tensors)
    context_features = {
        name: context_cols_to_tensors[col]
        for name, col in six.iteritems(self._context_feature_columns)
    }

  example_features = {}
  if self._example_feature_columns:
    # Compute example_features. Note that the key in `example_feature_columns`
    # dict can be different from the key in the `features` dict. We only need
    # to reshape the per-example tensors in `features`. To obtain the keys for
    # per-example features, we use the parsing feature specs.
    example_specs = tf.feature_column.make_parse_example_spec(
        list(six.itervalues(self._example_feature_columns)))
    # Spec keys missing from `features` are skipped, so the batch and list
    # sizes must come from a key that is actually present, not simply from the
    # first spec key (which may be absent and would raise a KeyError).
    present_names = [name for name in example_specs if name in features]
    if not present_names:
      raise KeyError(
          'None of the example feature keys {} is present in the input '
          'features.'.format(sorted(example_specs)))
    # Assumes every per-example tensor shares the leading
    # [batch_size, list_size] dimensions.
    example_shape = tf.shape(input=features[present_names[0]])
    batch_size = example_shape[0]
    list_size = example_shape[1]
    # Collapse the first two dims into [batch_size * list_size] so the dense
    # layer processes one row per example.
    reshaped_example_features = {
        name: utils.reshape_first_ndims(features[name], 2,
                                        [batch_size * list_size])
        for name in present_names
    }
    example_cols_to_tensors = {}
    self._example_dense_layer(
        reshaped_example_features,
        training=training,
        cols_to_output_tensors=example_cols_to_tensors)
    # Restore the [batch_size, list_size] leading dimensions on each output.
    example_features = {
        name: utils.reshape_first_ndims(example_cols_to_tensors[col], 1,
                                        [batch_size, list_size])
        for name, col in six.iteritems(self._example_feature_columns)
    }
  return context_features, example_features