in self-paced-labs/vertex-ai/vertex-pipelines/tfx/tfx_taxifare_tips/tfx_pipeline/pipeline.py [0:0]
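# NOTE: imports are not shown in this excerpt. The set below is a plausible
# reconstruction for TFX 1.x based on how the names are used in this function;
# the project-local import paths (config, features, bq_datasource_utils) are
# assumptions and may differ in this repo.
import logging
import os
from typing import Text

import tensorflow_model_analysis as tfma
from tfx.components import (
    Evaluator,
    ExampleValidator,
    Pusher,
    StatisticsGen,
    Transform,
)
from tfx.dsl.components.common.importer import Importer
from tfx.dsl.components.common.resolver import Resolver
from tfx.dsl.input_resolution.strategies.latest_blessed_model_strategy import (
    LatestBlessedModelStrategy,
)
from tfx.extensions.google_cloud_ai_platform.trainer.component import (
    Trainer as VertexTrainer,
)
from tfx.extensions.google_cloud_big_query.example_gen.component import (
    BigQueryExampleGen,
)
from tfx.orchestration.pipeline import Pipeline
from tfx.proto import example_gen_pb2, pusher_pb2, transform_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model, ModelBlessing, Schema

from tfx_pipeline import config  # assumed local module path
from tfx_pipeline import features  # assumed local module path
from tfx_pipeline import bq_datasource_utils  # assumed local module path

# SCHEMA_DIR, TRANSFORM_MODULE_FILE, and TRAIN_MODULE_FILE are assumed to be
# module-level path constants defined elsewhere in this package.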
def create_pipeline(pipeline_name: Text, pipeline_root: Text):
"""
Args:
pipeline_name:
pipeline_root:
num_epochs:
batch_size:
learning_rate:
hidden_units:
Returns:
pipeline:
"""
# Get train split BigQuery query.
train_sql_query = bq_datasource_utils.get_training_source_query(
config.GOOGLE_CLOUD_PROJECT_ID,
config.GOOGLE_CLOUD_REGION,
config.DATASET_DISPLAY_NAME,
ml_use="UNASSIGNED",
limit=int(config.TRAIN_LIMIT),
)
    # Configure train and eval splits for model training and evaluation.
    # Hash bucket counts determine the relative proportions of each split.
train_output_config = example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(
splits=[
example_gen_pb2.SplitConfig.Split(
name="train", hash_buckets=int(config.NUM_TRAIN_SPLITS)
),
example_gen_pb2.SplitConfig.Split(
name="eval", hash_buckets=int(config.NUM_EVAL_SPLITS)
),
]
)
)
# Generate train split examples.
train_example_gen = BigQueryExampleGen(
query=train_sql_query,
output_config=train_output_config,
).with_id("TrainDataGen")
    # Get test split BigQuery query.
test_sql_query = bq_datasource_utils.get_training_source_query(
config.GOOGLE_CLOUD_PROJECT_ID,
config.GOOGLE_CLOUD_REGION,
config.DATASET_DISPLAY_NAME,
ml_use="TEST",
limit=int(config.TEST_LIMIT),
)
# Configure test split for model evaluation.
test_output_config = example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(
splits=[
example_gen_pb2.SplitConfig.Split(name="test", hash_buckets=1),
]
)
)
    # Generate test split examples.
test_example_gen = BigQueryExampleGen(
query=test_sql_query,
output_config=test_output_config,
).with_id("TestDataGen")
    # Import the curated schema artifact from SCHEMA_DIR.
schema_importer = Importer(
source_uri=SCHEMA_DIR,
artifact_type=Schema,
).with_id("SchemaImporter")
    # Alternative: ImportSchemaGen can replace the Importer above.
    # schema_importer = ImportSchemaGen(schema_file=SCHEMA_FILE).with_id("SchemaImporter")
# Generate dataset statistics.
statistics_gen = StatisticsGen(
examples=train_example_gen.outputs["examples"]
).with_id("StatisticsGen")
    # Optionally, infer the schema from statistics instead of importing it:
# schema_gen = SchemaGen(
# statistics=statistics_gen.outputs["statistics"], infer_feature_shape=True
# )
# Example validation.
example_validator = ExampleValidator(
statistics=statistics_gen.outputs["statistics"],
schema=schema_importer.outputs["result"],
).with_id("ExampleValidator")
# Data transformation.
transform = Transform(
examples=train_example_gen.outputs["examples"],
schema=schema_importer.outputs["result"],
module_file=TRANSFORM_MODULE_FILE,
# This is a temporary workaround to run on Dataflow.
force_tf_compat_v1=config.BEAM_RUNNER == "DataflowRunner",
splits_config=transform_pb2.SplitsConfig(
analyze=["train"], transform=["train", "eval"]
),
).with_id("Tranform")
# Add dependency from example_validator to transform.
transform.add_upstream_node(example_validator)
# Train model on Vertex AI.
trainer = VertexTrainer(
module_file=TRAIN_MODULE_FILE,
examples=transform.outputs["transformed_examples"],
transform_graph=transform.outputs["transform_graph"],
custom_config=config.VERTEX_TRAINING_CONFIG,
).with_id("ModelTrainer")
# Get the latest blessed model (baseline) for model validation.
baseline_model_resolver = Resolver(
strategy_class=LatestBlessedModelStrategy,
model=Channel(type=Model),
model_blessing=Channel(type=ModelBlessing),
).with_id("BaselineModelResolver")
# Prepare evaluation config.
eval_config = tfma.EvalConfig(
model_specs=[
tfma.ModelSpec(
signature_name="serving_tf_example",
label_key=features.TARGET_FEATURE_NAME,
prediction_key="probabilities",
)
],
slicing_specs=[
tfma.SlicingSpec(),
],
metrics_specs=[
tfma.MetricsSpec(
metrics=[
tfma.MetricConfig(class_name="ExampleCount"),
tfma.MetricConfig(
class_name="BinaryAccuracy",
threshold=tfma.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={"value": float(config.ACCURACY_THRESHOLD)}
),
# Change threshold will be ignored if there is no
# baseline model resolved from MLMD (first run).
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={"value": -1e-10},
),
),
),
]
)
],
)
# Model evaluation.
evaluator = Evaluator(
examples=test_example_gen.outputs["examples"],
example_splits=["test"],
model=trainer.outputs["model"],
baseline_model=baseline_model_resolver.outputs["model"],
eval_config=eval_config,
schema=schema_importer.outputs["result"],
).with_id("ModelEvaluator")
exported_model_location = os.path.join(
config.MODEL_REGISTRY_URI, config.MODEL_DISPLAY_NAME
)
push_destination = pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=exported_model_location
)
)
# Push custom model to model registry.
pusher = Pusher(
model=trainer.outputs["model"],
model_blessing=evaluator.outputs["blessing"],
push_destination=push_destination,
).with_id("ModelPusher")
pipeline_components = [
train_example_gen,
test_example_gen,
schema_importer,
statistics_gen,
# schema_gen,
example_validator,
transform,
trainer,
baseline_model_resolver,
evaluator,
pusher,
]
logging.info(
"Pipeline components: %s",
", ".join([component.id for component in pipeline_components]),
)
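    # Default to DirectRunner args; switch to Dataflow args when configured.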
beam_pipeline_args = config.BEAM_DIRECT_PIPELINE_ARGS
if config.BEAM_RUNNER == "DataflowRunner":
beam_pipeline_args = config.BEAM_DATAFLOW_PIPELINE_ARGS
logging.info("Beam pipeline args: %s", beam_pipeline_args)
return Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=pipeline_components,
beam_pipeline_args=beam_pipeline_args,
        enable_cache=bool(int(config.ENABLE_CACHE)),
)
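

# A minimal sketch of compiling this pipeline for Vertex AI Pipelines with
# TFX's KubeflowV2DagRunner. The pipeline name and root below are illustrative
# assumptions, not values defined in this module.
if __name__ == "__main__":
    from tfx.orchestration.kubeflow.v2.kubeflow_v2_dag_runner import (
        KubeflowV2DagRunner,
        KubeflowV2DagRunnerConfig,
    )

    runner = KubeflowV2DagRunner(
        config=KubeflowV2DagRunnerConfig(),
        output_filename="tfx_taxifare_tips_pipeline.json",  # compiled job spec
    )
    # Writes the pipeline job spec JSON, ready to submit to Vertex AI Pipelines.
    runner.run(
        create_pipeline(
            pipeline_name="tfx-taxifare-tips",  # assumed display name
            pipeline_root="gs://<your-bucket>/pipeline_root",  # assumed GCS root
        )
    )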