in fairness_indicators/example_model.py [0:0]
def train_model(model_dir,
                train_tf_file,
                label,
                text_feature,
                feature_map,
                module_spec='https://tfhub.dev/google/nnlm-en-dim128/1',
                hidden_units=None,
                learning_rate=0.003,
                batch_size=512,
                steps=1000,
                weight_offset=0.1):
  """Train a binary DNNClassifier on a TF-Hub embedded text feature.

  Every training example gets a weight of `label + weight_offset` (the label
  is cast to float32 first). Examples with a larger label value therefore
  contribute more to the loss, which is used here to deal with unbalanced
  classes. For 0/1 labels and the default offset, the weights are 1.1 and 0.1
  (assumes binary 0/1 labels -- confirm against the data).

  Args:
    model_dir: Directory path to save trained model.
    train_tf_file: File containing training TFRecordDataset.
    label: Name of the groundtruth label feature in `feature_map`.
    text_feature: Name of the text feature to embed and train on.
    feature_map: Dict of feature names to their parsing spec; passed as
      `features` to `tf.io.parse_single_example`.
    module_spec: A module spec defining the module to instantiate or a path
      where to load a module spec.
    hidden_units: Sizes of the DNN hidden layers. Defaults to [500, 100].
    learning_rate: Learning rate for the Adagrad optimizer.
    batch_size: Number of examples per training batch.
    steps: Maximum number of training steps. The dataset is read in a single
      pass (it is not repeated), so training ends early if the input is
      exhausted before `steps` batches have been consumed.
    weight_offset: Constant added to the label to form the example weight.

  Returns:
    Trained DNNClassifier.
  """
  if hidden_units is None:
    hidden_units = [500, 100]
  else:
    hidden_units = list(hidden_units)

  def train_input_fn():
    """Train Input function."""

    def parse_function(serialized):
      parsed_example = tf.io.parse_single_example(
          serialized=serialized, features=feature_map)
      # Adds a weight column to deal with unbalanced classes. The label is
      # cast to float first: adding a float offset to an integer label
      # tensor would otherwise fail on dtype conversion.
      parsed_example['weight'] = tf.add(
          tf.cast(parsed_example[label], tf.float32), weight_offset)
      return (parsed_example, parsed_example[label])

    train_dataset = tf.data.TFRecordDataset(
        filenames=[train_tf_file]).map(parse_function).batch(batch_size)
    return train_dataset

  text_embedding_column = hub.text_embedding_column(
      key=text_feature, module_spec=module_spec)

  classifier = tf.estimator.DNNClassifier(
      hidden_units=hidden_units,
      weight_column='weight',
      feature_columns=[text_embedding_column],
      n_classes=2,
      optimizer=tf.train.AdagradOptimizer(learning_rate=learning_rate),
      model_dir=model_dir)

  classifier.train(input_fn=train_input_fn, steps=steps)
  return classifier