text-semantic-search/embeddings_extraction/etl/pipeline.py (86 lines of code) (raw):

#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam pipeline that embeds article titles and stores the results.

Titles are read from BigQuery, embedded through a TensorFlow Transform
preprocessing function, written to TFRecord files, and the raw text is
written to Datastore keyed by the generated id.
"""

import apache_beam as beam
from apache_beam.io.gcp.datastore.v1.datastoreio import WriteToDatastore
import tensorflow as tf
from google.cloud.proto.datastore.v1 import entity_pb2
from googledatastore import helper as datastore_helper
import tensorflow_transform.coders as tft_coders
from tensorflow_transform.beam import impl

# Lazily-created TF Hub module, shared by every call to embed_text().
encoder = None


def get_source_query(limit=1000000):
    """Return the BigQuery standard-SQL query that selects the titles.

    The query yields two columns: ``id`` (a generated UUID) and ``text``
    (the distinct, lower-cased title). Only English titles with at least
    five space-separated words and fewer than 500 characters are kept.

    Args:
        limit: Maximum number of rows the query returns.

    Returns:
        The query string with ``limit`` substituted into the LIMIT clause.
    """
    query = """
        SELECT
          GENERATE_UUID() as id,
          text
        FROM
        (
            SELECT
              DISTINCT LOWER(title) text
            FROM
              `bigquery-samples.wikipedia_benchmark.Wiki100B`
            WHERE
              ARRAY_LENGTH(split(title,' ')) >= 5
            AND
              language = 'en'
            AND
              LENGTH(title) < 500
        )
        LIMIT {0}
    """.format(limit)
    return query


def embed_text(text):
    """Apply the Universal Sentence Encoder TF Hub module to ``text``.

    The hub module is created on first use and cached in the module-level
    ``encoder`` so later calls reuse it.

    Args:
        text: Input passed straight to the hub module.

    Returns:
        Whatever the hub module returns for ``text``.
    """
    import tensorflow_hub as hub
    global encoder
    if encoder is None:
        encoder = hub.Module(
            'https://tfhub.dev/google/universal-sentence-encoder/2')
    embedding = encoder(text)
    return embedding


def parse_articles(csv_line):
    """Return ``(second comma-separated field of csv_line, None)``.

    Raises:
        IndexError: If ``csv_line`` contains no comma.
    """
    return csv_line.split(',')[1], None


def get_metadata():
    """Return the tf.Transform metadata describing the input rows.

    The schema declares two string columns, ``id`` and ``text``, each with
    an empty shape and a fixed column representation.
    """
    from tensorflow_transform.tf_metadata import dataset_schema
    from tensorflow_transform.tf_metadata import dataset_metadata

    metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema({
        'id': dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation()),
        'text': dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation())
    }))
    return metadata


def preprocess_fn(input_features):
    """tf.Transform preprocessing function producing the embeddings.

    Args:
        input_features: Mapping with at least the ``id`` and ``text``
            features.

    Returns:
        A dict with the unchanged ``id`` feature and an ``embedding``
        feature obtained by applying ``embed_text`` to ``text``.
    """
    import tensorflow_transform as tft

    embedding = tft.apply_function(embed_text, input_features['text'])
    output_features = {
        'id': input_features['id'],
        'embedding': embedding
    }
    return output_features


def create_entity(input_features, kind):
    """Build a Datastore entity holding the text of one input row.

    Args:
        input_features: Mapping with the ``id`` and ``text`` of the row.
        kind: Datastore kind used in the entity key path.

    Returns:
        An ``entity_pb2.Entity`` whose key path is ``(kind, id)`` and whose
        only property is ``text``.
    """
    entity = entity_pb2.Entity()
    datastore_helper.add_key_path(
        entity.key, kind, input_features['id'])
    # NOTE(review): `unicode` only exists on Python 2; on Python 3 this
    # line raises NameError.
    datastore_helper.add_properties(
        entity, {
            'text': unicode(input_features['text'])
        })
    return entity


def run(pipeline_options, known_args):
    """Build and launch the embedding-extraction pipeline.

    Args:
        pipeline_options: Beam pipeline options; the ``project`` and
            ``runner`` options are read from it.
        known_args: Parsed arguments providing ``transform_temp_dir``,
            ``limit``, ``output_dir``, ``kind``, ``enable_debug`` and
            ``debug_output_prefix``.
    """
    pipeline = beam.Pipeline(options=pipeline_options)
    all_options = pipeline_options.get_all_options()
    gcp_project = all_options['project']

    with impl.Context(known_args.transform_temp_dir):
        articles = (
            pipeline
            | 'Read articles from BigQuery' >> beam.io.Read(
                beam.io.BigQuerySource(
                    project=gcp_project,
                    query=get_source_query(known_args.limit),
                    use_standard_sql=True))
        )

        articles_dataset = (articles, get_metadata())
        embeddings_dataset, _ = (
            articles_dataset
            | 'Extract embeddings' >> impl.AnalyzeAndTransformDataset(
                preprocess_fn)
        )

        embeddings, transformed_metadata = embeddings_dataset

        # One shard per 25000 requested rows, truncated towards zero.
        # NOTE(review): this is 0 for limits below 25000 - presumably Beam
        # then chooses the shard count itself; confirm.
        embeddings | 'Write embeddings to TFRecords' >> (
            beam.io.tfrecordio.WriteToTFRecord(
                file_path_prefix='{0}'.format(known_args.output_dir),
                file_name_suffix='.tfrecords',
                coder=tft_coders.example_proto_coder.ExampleProtoCoder(
                    transformed_metadata.schema),
                num_shards=int(known_args.limit/25000)
            )
        )

        (
            articles
            | 'Convert to entity' >> beam.Map(
                lambda input_features: create_entity(
                    input_features, known_args.kind))
            | 'Write to Datastore' >> WriteToDatastore(project=gcp_project)
        )

        if known_args.enable_debug:
            embeddings | 'Debug Output' >> beam.io.textio.WriteToText(
                file_path_prefix=known_args.debug_output_prefix,
                file_name_suffix='.txt')

    job = pipeline.run()

    # Beam falls back to the DirectRunner when no runner is given, and the
    # 'runner' option is then None. Treat that like an explicit
    # 'DirectRunner' so a local run always blocks until the job finishes.
    if all_options['runner'] in (None, 'DirectRunner'):
        job.wait_until_finish()