def run()

in text-semantic-search/embeddings_extraction/etl/pipeline.py
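
In outline, run() builds and launches the Apache Beam pipeline: it reads
articles from BigQuery, extracts their embeddings with a tf.Transform
AnalyzeAndTransformDataset step, writes the embeddings to TFRecord files,
stores the articles as Cloud Datastore entities, and optionally writes the
embeddings as text for debugging.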


# Imports assumed for this excerpt (they are not shown in the source file;
# module paths match the older Beam / tf.Transform releases this code
# targets, e.g. the pre-v1new Datastore connector):
import apache_beam as beam
import tensorflow_transform.beam.impl as impl
from apache_beam.io.gcp.datastore.v1.datastoreio import WriteToDatastore
from tensorflow_transform import coders as tft_coders

# get_source_query, get_metadata, preprocess_fn and create_entity are
# helpers defined elsewhere in pipeline.py.


def run(pipeline_options, known_args):
  pipeline = beam.Pipeline(options=pipeline_options)
  gcp_project = pipeline_options.get_all_options()['project']

  # All tf.Transform steps run inside this context; intermediate analysis
  # artifacts are written under transform_temp_dir.
  with impl.Context(known_args.transform_temp_dir):
    # Read the source articles from BigQuery, using a query parameterised
    # by known_args.limit.
    articles = (
        pipeline
        | 'Read articles from BigQuery' >> beam.io.Read(beam.io.BigQuerySource(
            project=gcp_project,
            query=get_source_query(known_args.limit),
            use_standard_sql=True))
    )

    # Pair the raw data with its schema, then analyze and transform it with
    # preprocess_fn to produce the embeddings. The transform_fn that
    # AnalyzeAndTransformDataset also returns is discarded here.
    articles_dataset = (articles, get_metadata())
    embeddings_dataset, _ = (
        articles_dataset
        | 'Extract embeddings' >> impl.AnalyzeAndTransformDataset(preprocess_fn)
    )

    embeddings, transformed_metadata = embeddings_dataset

    # Encode the embeddings as tf.Example protos and write them to sharded
    # TFRecord files; int(limit / 25000) targets ~25k records per shard
    # (a value of 0 lets the runner choose the shard count).
    embeddings | 'Write embeddings to TFRecords' >> beam.io.tfrecordio.WriteToTFRecord(
        file_path_prefix=known_args.output_dir,
        file_name_suffix='.tfrecords',
        coder=tft_coders.example_proto_coder.ExampleProtoCoder(
            transformed_metadata.schema),
        num_shards=int(known_args.limit / 25000)
    )

    # In a parallel branch, convert each article to a Datastore entity of
    # the configured kind and write it to Cloud Datastore.
    (
        articles
        | 'Convert to entity' >> beam.Map(
            lambda input_features: create_entity(
                input_features, known_args.kind))
        | 'Write to Datastore' >> WriteToDatastore(project=gcp_project)
    )

    # Optionally dump the embeddings as plain text for debugging.
    if known_args.enable_debug:
      embeddings | 'Debug Output' >> beam.io.textio.WriteToText(
          file_path_prefix=known_args.debug_output_prefix,
          file_name_suffix='.txt')

  # Launch the pipeline. run() is asynchronous on distributed runners, so
  # block for completion only when running locally with the DirectRunner.
  job = pipeline.run()

  if pipeline_options.get_all_options()['runner'] == 'DirectRunner':
    job.wait_until_finish()
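
For reference, here is a minimal driver sketch showing how run() might be
invoked. The flag names mirror the known_args attributes read above, but the
argparse setup and defaults are illustrative assumptions, not the repo's
actual entry point.

# Hypothetical driver for run(). Flag names follow the known_args
# attributes used in the function; defaults are placeholders.
import argparse

from apache_beam.options.pipeline_options import PipelineOptions


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--limit', type=int, default=1000000)
  parser.add_argument('--transform_temp_dir', required=True)
  parser.add_argument('--output_dir', required=True)
  parser.add_argument('--kind', default='Article')
  parser.add_argument('--enable_debug', action='store_true')
  parser.add_argument('--debug_output_prefix', default=None)
  known_args, pipeline_args = parser.parse_known_args()

  # Any remaining flags (e.g. --project, --runner, --temp_location) are
  # treated as Beam pipeline options.
  pipeline_options = PipelineOptions(pipeline_args)

  run(pipeline_options, known_args)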