in tfx_addons/sampling/example/sampler_utils.py [0:0]
def trainer_fn(trainer_fn_args, schema):
  """Assemble the estimator together with its training and eval specs.

  Args:
    trainer_fn_args: Holds the training arguments as name/value pairs. The
      attributes read here are `transform_output`, `train_files`,
      `eval_files`, `data_accessor`, `train_steps`, `eval_steps`,
      `serving_model_dir` and `base_model`.
    schema: Schema of the training examples; it is forwarded to the serving
      and eval input receiver builders.

  Returns:
    A dict with the following keys:
      - estimator: The estimator used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Zero-argument callable that builds the eval
        input receiver.
  """
  # Fixed hyperparameters for this example.
  base_layer_width = 100  # Number of nodes in the first DNN layer.
  layer_count = 4
  width_decay = 0.7
  batch_size_train = 40
  batch_size_eval = 40

  # Layer widths shrink exponentially with depth and never drop below 2.
  layer_widths = []
  for depth in range(layer_count):
    layer_widths.append(max(2, int(base_layer_width * width_decay**depth)))

  transform_graph = tft.TFTransformOutput(trainer_fn_args.transform_output)

  # The closures below take no arguments and defer all work until they are
  # actually invoked.
  def feed_train():
    return _input_fn(trainer_fn_args.train_files,
                     trainer_fn_args.data_accessor,
                     transform_graph,
                     batch_size=batch_size_train)

  def feed_eval():
    return _input_fn(trainer_fn_args.eval_files,
                     trainer_fn_args.data_accessor,
                     transform_graph,
                     batch_size=batch_size_eval)

  def make_serving_receiver():
    return _example_serving_receiver_fn(transform_graph, schema)

  def make_eval_receiver():
    # Input receiver used for TFMA processing.
    return _eval_input_receiver_fn(transform_graph, schema)

  training_spec = tf.estimator.TrainSpec(
      feed_train, max_steps=trainer_fn_args.train_steps)

  # NOTE(review): the exporter is named 'chicago-taxi' while the eval spec is
  # named 'credit-fraud-eval' -- presumably a leftover from another example.
  # The name is likely part of the export path, so confirm before renaming.
  final_exporter = tf.estimator.FinalExporter('chicago-taxi',
                                              make_serving_receiver)
  evaluation_spec = tf.estimator.EvalSpec(feed_eval,
                                          steps=trainer_fn_args.eval_steps,
                                          exporters=[final_exporter],
                                          name='credit-fraud-eval')

  # Keep multiple checkpoint files for distributed training. Note that
  # keep_checkpoint_max should be greater than or equal to the number of
  # replicas to avoid a race condition.
  estimator_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=5).replace(
          model_dir=trainer_fn_args.serving_model_dir)

  base_model = trainer_fn_args.base_model
  model = _build_estimator(hidden_units=layer_widths,
                           config=estimator_config,
                           warm_start_from=base_model)

  return {
      'estimator': model,
      'train_spec': training_spec,
      'eval_spec': evaluation_spec,
      'eval_input_receiver_fn': make_eval_receiver
  }