def train_model()

in fairness_indicators/example_model.py [0:0]


def train_model(model_dir,
                train_tf_file,
                label,
                text_feature,
                feature_map,
                module_spec='https://tfhub.dev/google/nnlm-en-dim128/1',
                batch_size=512,
                train_steps=1000,
                learning_rate=0.003,
                hidden_units=(500, 100)):
  """Train a two-class DNNClassifier on a text feature from a TFRecord file.

  Examples are read from `train_tf_file`, parsed with `feature_map`, and given
  a per-example 'weight' equal to `label value + 0.1`. The text feature is
  embedded with a TF Hub text embedding column and fed to a
  `tf.estimator.DNNClassifier` optimized with Adagrad.

  The dataset is read once, in file order: no shuffle or repeat is applied
  here.

  Args:
    model_dir: Directory path to save trained model.
    train_tf_file: File containing training TFRecordDataset.
    label: Groundtruth label. Name of the key in `feature_map` whose parsed
      value is used both as the training target and to derive the weight.
    text_feature: Text feature to be evaluated. Name of the key passed to the
      text embedding column.
    feature_map: Dict of feature names to their data type, passed as
      `features` to `tf.io.parse_single_example`.
    module_spec: A module spec defining the module to instantiate or a path
      where to load a module spec.
    batch_size: Number of examples per training batch. Defaults to 512.
    train_steps: Value passed as `steps` to `classifier.train`. Defaults to
      1000.
    learning_rate: Learning rate for the Adagrad optimizer. Defaults to 0.003.
    hidden_units: Iterable of hidden layer sizes for the DNN. Defaults to
      (500, 100).

  Returns:
    Trained DNNClassifier.
  """

  def train_input_fn():
    """Build the batched training dataset of (features, label) pairs."""

    def parse_function(serialized):
      """Parse one serialized example and attach its 'weight' feature."""
      parsed_example = tf.io.parse_single_example(
          serialized=serialized, features=feature_map)
      # Adds a weight column to deal with unbalanced classes: the weight is
      # the label value plus 0.1.
      # NOTE(review): presumably labels are 0.0/1.0 floats, which would give
      # positives weight 1.1 and negatives 0.1 -- confirm against the data.
      parsed_example['weight'] = tf.add(parsed_example[label], 0.1)
      # The label stays inside the feature dict as well as being returned as
      # the target; only `text_feature` and 'weight' are referenced by the
      # classifier configuration below.
      return (parsed_example, parsed_example[label])

    train_dataset = tf.data.TFRecordDataset(
        filenames=[train_tf_file]).map(parse_function).batch(batch_size)
    return train_dataset

  text_embedding_column = hub.text_embedding_column(
      key=text_feature, module_spec=module_spec)

  # NOTE(review): `tf.train.AdagradOptimizer` looks like the TF 1.x API;
  # confirm the TensorFlow version this file targets before upgrading.
  classifier = tf.estimator.DNNClassifier(
      hidden_units=list(hidden_units),
      weight_column='weight',
      feature_columns=[text_embedding_column],
      n_classes=2,
      optimizer=tf.train.AdagradOptimizer(learning_rate=learning_rate),
      model_dir=model_dir)

  classifier.train(input_fn=train_input_fn, steps=train_steps)
  return classifier