text-semantic-search/index_builder/builder/task.py (75 lines of code) (raw):

#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Index-builder task: build an index locally, then upload it to GCS.

The script calls ``index.build_index`` to write ``LOCAL_INDEX_FILE`` and then
uploads that file, plus a sibling ``<LOCAL_INDEX_FILE>.mapping`` file, to the
Cloud Storage location given by ``--index-file``.
"""

import logging
import argparse
from datetime import datetime
import index

from httplib2 import Http
from googleapiclient.http import MediaFileUpload
from googleapiclient.discovery import build
from oauth2client.client import GoogleCredentials

# Name of the index file written to the local working directory.
LOCAL_INDEX_FILE = 'embeds.index'
# Size in bytes of each chunk sent during a resumable upload (64 MiB).
CHUNKSIZE = 64 * 1024 * 1024

# URI scheme prefix expected on the upload destination.
_GCS_SCHEME = 'gs://'
# Suffix of the companion file uploaded next to the index.
# NOTE(review): presumably written by index.build_index -- confirm.
_MAPPING_SUFFIX = '.mapping'


def _parse_gcs_uri(gcs_uri):
  """Split a ``gs://bucket/object`` URI into its bucket and object path.

  Args:
    gcs_uri: Destination URI string, e.g. ``gs://my-bucket/dir/embeds.index``.

  Returns:
    A ``(bucket_name, blob_path)`` tuple of non-empty strings.

  Raises:
    ValueError: If the URI does not start with ``gs://`` or lacks either the
      bucket name or the object path.
  """
  if not gcs_uri.startswith(_GCS_SCHEME):
    raise ValueError(
      'Index file location must start with {!r}, got: {!r}'.format(
        _GCS_SCHEME, gcs_uri))

  # Split on the first '/' only, so the object path keeps its own slashes.
  bucket_name, _, blob_path = gcs_uri[len(_GCS_SCHEME):].partition('/')
  if not bucket_name or not blob_path:
    raise ValueError(
      'Index file location must have the form gs://<bucket>/<object>, '
      'got: {!r}'.format(gcs_uri))
  return bucket_name, blob_path


def _upload_to_gcs(gcs_services, local_file_name, bucket_name, gcs_location):
  """Upload one local file to Cloud Storage as a resumable, chunked upload.

  Args:
    gcs_services: Storage API client, as returned by
      ``build('storage', 'v1', ...)``.
    local_file_name: Path of the local file to upload.
    bucket_name: Name of the destination bucket.
    gcs_location: Object name (path inside the bucket) to write.
  """
  destination = "gs://{}/{}".format(bucket_name, gcs_location)
  logging.info('Uploading file %s to %s...', local_file_name, destination)

  media = MediaFileUpload(
    local_file_name,
    mimetype='application/octet-stream',
    chunksize=CHUNKSIZE,
    resumable=True)
  request = gcs_services.objects().insert(
    bucket=bucket_name, name=gcs_location, media_body=media)

  # next_chunk() sends one chunk per call; the loop ends once it returns a
  # response instead of None.
  response = None
  while response is None:
    _, response = request.next_chunk()

  logging.info('File %s uploaded to %s.', local_file_name, destination)


def upload_artefacts(gcs_index_file):
  """Upload the local index file and its ``.mapping`` companion to GCS.

  The destination is validated before any network call, so a malformed
  location fails with a clear error rather than partway through the upload.

  Args:
    gcs_index_file: Destination of the index file as ``gs://bucket/object``.
      The mapping file is written to the same location with ``.mapping``
      appended.

  Raises:
    ValueError: If ``gcs_index_file`` is not a ``gs://bucket/object`` URI.
  """
  bucket_name, blob_path = _parse_gcs_uri(gcs_index_file)

  http = Http()
  credentials = GoogleCredentials.get_application_default()
  credentials.authorize(http)
  gcs_services = build('storage', 'v1', http=http)

  _upload_to_gcs(gcs_services, LOCAL_INDEX_FILE, bucket_name, blob_path)
  _upload_to_gcs(
    gcs_services,
    LOCAL_INDEX_FILE + _MAPPING_SUFFIX,
    bucket_name,
    blob_path + _MAPPING_SUFFIX)


def get_args():
  """Parse the command-line arguments of the task.

  Returns:
    The ``argparse.Namespace`` with ``embedding_files``, ``index_file``,
    ``num_trees`` and ``job_dir`` attributes.
  """
  args_parser = argparse.ArgumentParser()

  args_parser.add_argument(
    '--embedding-files',
    help='GCS or local paths to embedding files',
    required=True
  )

  # NOTE(review): the upload step only accepts gs://bucket/object locations,
  # despite what this help text says.
  args_parser.add_argument(
    '--index-file',
    help='GCS or local paths to output index file',
    required=True
  )

  args_parser.add_argument(
    '--num-trees',
    help='Number of trees to build in the index',
    default=1000,
    type=int
  )

  # Parsed but not read anywhere in this file.
  args_parser.add_argument(
    '--job-dir',
    help='GCS or local paths to job package'
  )

  return args_parser.parse_args()


def main():
  """Build the index, upload the artefacts, and log the time each step took.

  Raises:
    ValueError: If ``--index-file`` is not a ``gs://bucket/object`` URI.
  """
  args = get_args()

  # Fail fast: reject a bad destination before running the index build.
  _parse_gcs_uri(args.index_file)

  time_start = datetime.utcnow()
  logging.info('Index building started...')
  index.build_index(args.embedding_files, LOCAL_INDEX_FILE, args.num_trees)
  time_end = datetime.utcnow()
  logging.info('Index building finished.')
  time_elapsed = time_end - time_start
  logging.info('Index building elapsed time: %s seconds',
               time_elapsed.total_seconds())

  time_start = datetime.utcnow()
  logging.info('Uploading index artefacts started...')
  upload_artefacts(args.index_file)
  time_end = datetime.utcnow()
  logging.info('Uploading index artefacts finished.')
  time_elapsed = time_end - time_start
  logging.info('Uploading index artefacts elapsed time: %s seconds',
               time_elapsed.total_seconds())


if __name__ == '__main__':
  # The root logger defaults to WARNING, which would drop every info message
  # above. basicConfig is a no-op if logging handlers are already installed.
  logging.basicConfig(level=logging.INFO)
  main()