in tfx_addons/sampling/example/sampler_pipeline_local.py [0:0]
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Builds an example TFX pipeline that includes the Sampler component.

    Args:
      pipeline_name: Name given to the resulting pipeline.
      pipeline_root: Forwarded to the pipeline as its `pipeline_root`.
      data_root: Forwarded to CsvExampleGen as `input_base`.
      module_file: Module file handed to both Transform and Trainer.
      serving_model_dir: Base directory of the Pusher's filesystem destination.
      metadata_path: Path used for the SQLite metadata connection config.
      beam_pipeline_args: Forwarded to the pipeline unchanged.

    Returns:
      A `pipeline.Pipeline` containing all components below, with caching
      disabled.
    """
    # Ingest the CSV data, then derive statistics, a schema and a validator
    # from it.
    csv_gen = CsvExampleGen(input_base=data_root)
    raw_examples = csv_gen.outputs['examples']

    stats_gen = StatisticsGen(examples=raw_examples)
    statistics = stats_gen.outputs['statistics']

    schema_inferrer = SchemaGen(statistics=statistics,
                                infer_feature_shape=False)
    schema = schema_inferrer.outputs['schema']

    validator = ExampleValidator(statistics=statistics, schema=schema)

    # The Sampler consumes the generated examples; it is configured for the
    # 'train' split only, with 'Class' as the label.
    sampler = Sampler(input_data=raw_examples,
                      splits=['train'],
                      label='Class',
                      shards=10)

    # Transform runs on the Sampler's output rather than on the raw examples.
    preprocess = Transform(examples=sampler.outputs['output_data'],
                           schema=schema,
                           module_file=module_file)

    # Resolver whose 'latest_model' output is fed to the Trainer as base_model.
    base_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        latest_model=Channel(type=Model),
    ).with_id('latest_model_resolver')

    model_trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=preprocess.outputs['transformed_examples'],
        schema=schema,
        base_model=base_model_resolver.outputs['latest_model'],
        transform_graph=preprocess.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))
    trained_model = model_trainer.outputs['model']

    # Resolver whose 'model' output serves as the Evaluator's baseline.
    blessed_model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    ).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # to perform quality validation of a candidate model (compared to a
    # baseline). Accuracy has a lower bound of 0.6 and a change threshold
    # against the baseline.
    accuracy_threshold = tfma.config.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.6}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}))
    # NOTE(review): 'trip_start_hour' is not referenced anywhere else in this
    # function -- confirm the feature exists in the examples being evaluated.
    slicing = [
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['trip_start_hour']),
    ]
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=slicing,
        metrics_specs=[
            tfma.MetricsSpec(thresholds={'accuracy': accuracy_threshold}),
        ])

    # Evaluation reads the examples from the example generator directly, not
    # from the Sampler output.
    model_evaluator = Evaluator(
        examples=raw_examples,
        model=trained_model,
        baseline_model=blessed_model_resolver.outputs['model'],
        eval_config=eval_config)

    # The Pusher receives the Evaluator's blessing and targets a filesystem
    # directory.
    filesystem_target = pusher_pb2.PushDestination.Filesystem(
        base_directory=serving_model_dir)
    model_pusher = Pusher(
        model=trained_model,
        model_blessing=model_evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=filesystem_target))

    components = [
        csv_gen,
        stats_gen,
        schema_inferrer,
        validator,
        sampler,
        preprocess,
        base_model_resolver,
        model_trainer,
        blessed_model_resolver,
        model_evaluator,
        model_pusher,
    ]
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=False,
        metadata_connection_config=connection_config,
        beam_pipeline_args=beam_pipeline_args)