def model_fn()

in recommended-item-search/softmax_model.py [0:0]


def model_fn(features, labels, mode, params):
  """A recommendation model for the movielens dataset.

  User embeddings are produced by feeding the movie_ids feature through an
  input layer and a feed-forward network. Movie embeddings are the embedding
  weights of the 'movie_ids' feature column. Scoring a user against every
  movie is therefore a single matrix product.

  Args:
    features: sequences of movie_ids (batch_size x sequence_size), either as a
      bare tensor or as a dict with a 'movie_ids' key.
    labels: unused; labels are generated from features['movie_ids'] by
      `generate_labels`.
    mode: a `tf.estimator.ModeKeys` value (TRAIN, EVAL or PREDICT).
    params: object exposing the attributes `metadata_path`, `hidden_dims`,
      `activation_name`, `learning_rate`, `lr_decay_steps` and
      `lr_decay_rate`.

  Returns:
    A `tf.estimator.EstimatorSpec` for `mode`. The spec is built as follows:
      PREDICT: exports 'user_embeddings'.
      TRAIN: loss and an Adagrad train_op with exponential learning-rate
        decay.
      EVAL: loss and 'precision_at_10'.

  Raises:
    ValueError: if `mode` is not TRAIN, EVAL or PREDICT.
  """
  # Fill the difference between training input (a dict) and prediction
  # input, which may arrive as a bare movie_ids tensor.
  if not isinstance(features, dict):
    features = {'movie_ids': features}

  # Build user embeddings from the input features.
  feature_columns = get_feature_columns(
      metadata_path=params.metadata_path,
      embeddings_dim=params.hidden_dims[-1])
  user_input = tf.feature_column.input_layer(
      features=features, feature_columns=feature_columns)
  user_embeddings = build_network(
      inputs=user_input, hidden_dims=params.hidden_dims,
      activation_fn=get_activation_fn(params.activation_name))

  # Reuse the embedding table that input_layer created for 'movie_ids' as the
  # movie embeddings.
  # NOTE(review): this relies on two assumptions. First, input_layer's default
  # 'input_layer' variable scope. Second, get_feature_columns defining an
  # embedding column named 'movie_ids'. Confirm if either changes.
  with tf.variable_scope('input_layer', reuse=True):
    movie_embeddings = tf.get_variable('movie_ids_embedding/embedding_weights')

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Serving only needs the user embeddings. Labels and loss are deliberately
    # not built here, so serving inputs never go through generate_labels.
    predictions = {
        'user_embeddings': user_embeddings
    }
    export_outputs = {
        'predictions': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, export_outputs=export_outputs)

  if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
    # Previously this fell through and returned None; fail loudly instead.
    raise ValueError('Unsupported mode: {}'.format(mode))

  # The `labels` argument is ignored: generate labels from
  # features['movie_ids'].
  labels = generate_labels(features)
  loss = softmax_loss(user_embeddings, movie_embeddings, labels)

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    learning_rate = tf.train.exponential_decay(
        learning_rate=params.learning_rate, global_step=global_step,
        decay_steps=params.lr_decay_steps, decay_rate=params.lr_decay_rate)
    optimizer = tf.train.AdagradOptimizer(learning_rate)
    train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)

  # EVAL: score every movie for each user (user x movie dot products) and
  # measure precision@10 against the generated labels.
  predictions = tf.matmul(
      user_embeddings, movie_embeddings, transpose_b=True)
  eval_metric_ops = {
      'precision_at_10': tf.metrics.precision_at_k(
          labels=labels, predictions=predictions, k=10)
  }
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)