in text-semantic-search/embeddings_extraction/etl/pipeline.py
# Imports assumed for the legacy (Beam / TF Transform 1.x-era) APIs used
# below; beam.io.BigQuerySource and datastoreio v1 are the old interfaces.
import apache_beam as beam
from apache_beam.io.gcp.datastore.v1.datastoreio import WriteToDatastore
from tensorflow_transform import coders as tft_coders
from tensorflow_transform.beam import impl


def run(pipeline_options, known_args):
  pipeline = beam.Pipeline(options=pipeline_options)
  gcp_project = pipeline_options.get_all_options()['project']

  # tf.Transform needs a temp location for its analysis artifacts.
  with impl.Context(known_args.transform_temp_dir):

    # Read the raw articles from BigQuery (legacy BigQuerySource API).
    articles = (
        pipeline
        | 'Read articles from BigQuery' >> beam.io.Read(beam.io.BigQuerySource(
            project=gcp_project,
            query=get_source_query(known_args.limit),
            use_standard_sql=True))
    )

    # Pair the PCollection with its metadata, then run the tf.Transform
    # analyze-and-transform pass that computes the embeddings.
    articles_dataset = (articles, get_metadata())
    embeddings_dataset, _ = (
        articles_dataset
        | 'Extract embeddings' >> impl.AnalyzeAndTransformDataset(preprocess_fn)
    )
    embeddings, transformed_metadata = embeddings_dataset

    # Write the embeddings as tf.Example records: roughly one shard per
    # 25,000 articles (Beam treats num_shards=0 as runner-chosen sharding).
    embeddings | 'Write embeddings to TFRecords' >> beam.io.tfrecordio.WriteToTFRecord(
        file_path_prefix=known_args.output_dir,
        file_name_suffix='.tfrecords',
        coder=tft_coders.example_proto_coder.ExampleProtoCoder(
            transformed_metadata.schema),
        num_shards=int(known_args.limit / 25000)
    )

    # In parallel, store the raw articles in Datastore for lookup by id.
    (
        articles
        | 'Convert to entity' >> beam.Map(
            lambda input_features: create_entity(input_features, known_args.kind))
        | 'Write to Datastore' >> WriteToDatastore(project=gcp_project)
    )

    # Optionally dump the embeddings as text for debugging.
    if known_args.enable_debug:
      embeddings | 'Debug Output' >> beam.io.textio.WriteToText(
          file_path_prefix=known_args.debug_output_prefix,
          file_name_suffix='.txt')

  job = pipeline.run()

  # Block only on the local runner; on Dataflow the job runs asynchronously.
  if pipeline_options.get_all_options()['runner'] == 'DirectRunner':
    job.wait_until_finish()
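
The helpers referenced above (get_source_query, get_metadata, preprocess_fn, and create_entity) are defined elsewhere in this module. A minimal sketch of what they might look like, assuming a BigQuery table with id and title columns and the Universal Sentence Encoder from TF Hub as the embedding model; the project/dataset/table names, column names, Hub module URL, and entity properties below are illustrative assumptions, not the module's actual values:

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema
from google.cloud.proto.datastore.v1 import entity_pb2
from googledatastore import helper as datastore_helper


def get_source_query(limit):
  # Hypothetical query; the real project/dataset/table are assumptions.
  return """
      SELECT id, title
      FROM `my-project.my_dataset.articles`
      LIMIT {}
  """.format(limit)


def get_metadata():
  # Schema of the raw BigQuery rows, expressed as tf.Transform metadata.
  return dataset_metadata.DatasetMetadata(dataset_schema.from_feature_spec({
      'id': tf.FixedLenFeature([], tf.string),
      'title': tf.FixedLenFeature([], tf.string),
  }))


def preprocess_fn(input_features):
  # tf.Transform traces this function into a TF graph; the Hub module
  # (assumed URL) turns each title into a fixed-size embedding vector.
  embed = hub.Module('https://tfhub.dev/google/universal-sentence-encoder/2')
  return {
      'id': input_features['id'],
      'title': input_features['title'],
      'embedding': embed(input_features['title']),
  }


def create_entity(input_features, kind):
  # One Datastore entity per article, keyed by id, so serving code can
  # map a matched embedding back to its title.
  entity = entity_pb2.Entity()
  datastore_helper.add_key_path(entity.key, kind, input_features['id'])
  datastore_helper.add_properties(entity, {'title': input_features['title']})
  return entity

A side benefit of computing the embeddings inside preprocess_fn: tf.Transform exports that graph with its transform artifacts, so the same function can be re-applied to live queries at serving time and query embeddings stay consistent with the precomputed article embeddings.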