text-semantic-search/index_builder/builder/task.py (75 lines of code) (raw):
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import shutil
from datetime import datetime

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
from httplib2 import Http
from oauth2client.client import GoogleCredentials

import index
# Local file that main() asks index.build_index() to write; upload_artefacts()
# then publishes it together with its sibling LOCAL_INDEX_FILE + '.mapping'.
LOCAL_INDEX_FILE = 'embeds.index'
# Chunk size (64 MiB) passed to MediaFileUpload for resumable uploads.
CHUNKSIZE = 64 * 1024 * 1024
def _upload_to_gcs(gcs_services, local_file_name, bucket_name, gcs_location):
    """Upload a local file to Cloud Storage with a chunked, resumable upload.

    Args:
      gcs_services: Cloud Storage JSON API client, as returned by
        ``build('storage', 'v1', ...)``.
      local_file_name: Path of the local file to upload.
      bucket_name: Name of the destination bucket.
      gcs_location: Object name (path inside the bucket) to write.
    """
    gcs_uri = 'gs://{}/{}'.format(bucket_name, gcs_location)
    logging.info('Uploading file %s to %s...', local_file_name, gcs_uri)
    media = MediaFileUpload(local_file_name,
                            mimetype='application/octet-stream',
                            chunksize=CHUNKSIZE, resumable=True)
    request = gcs_services.objects().insert(
        bucket=bucket_name, name=gcs_location, media_body=media)
    response = None
    # Each next_chunk() call sends one CHUNKSIZE piece and returns
    # (status, response); response stays None until the last chunk is sent,
    # and status is None on that final call.
    while response is None:
        status, response = request.next_chunk()
        if status is not None:
            logging.info('Uploaded %d%% of %s.',
                         int(status.progress() * 100), local_file_name)
    logging.info('File %s uploaded to %s.', local_file_name, gcs_uri)
def upload_artefacts(gcs_index_file):
    """Publish the locally built index and its ``.mapping`` file.

    Args:
      gcs_index_file: Destination of the index. A ``gs://bucket/path`` URI
        uploads LOCAL_INDEX_FILE to that object and the mapping file to
        ``path.mapping`` in the same bucket. Any other value is treated as a
        local file path (the ``--index-file`` flag accepts "GCS or local
        paths"), and both files are copied there instead.

    Raises:
      ValueError: If a ``gs://`` URI is missing its bucket or object path.
    """
    gcs_prefix = 'gs://'
    mapping_suffix = '.mapping'

    if not gcs_index_file.startswith(gcs_prefix):
        # Local destination: a plain file copy, no credentials required.
        logging.info('Copying %s to %s...', LOCAL_INDEX_FILE, gcs_index_file)
        shutil.copyfile(LOCAL_INDEX_FILE, gcs_index_file)
        shutil.copyfile(LOCAL_INDEX_FILE + mapping_suffix,
                        gcs_index_file + mapping_suffix)
        return

    # Validate before fetching credentials so a bad flag fails fast.
    bucket_name, _, blob_path = gcs_index_file[len(gcs_prefix):].partition('/')
    if not bucket_name or not blob_path:
        raise ValueError(
            'Expected an index file URI of the form gs://<bucket>/<path>, '
            'got {!r}'.format(gcs_index_file))

    http = Http()
    credentials = GoogleCredentials.get_application_default()
    credentials.authorize(http)
    gcs_services = build('storage', 'v1', http=http)
    _upload_to_gcs(gcs_services, LOCAL_INDEX_FILE, bucket_name, blob_path)
    _upload_to_gcs(gcs_services, LOCAL_INDEX_FILE + mapping_suffix,
                   bucket_name, blob_path + mapping_suffix)
def get_args():
    """Parse the command-line flags of the index-building task.

    Returns:
      argparse.Namespace with ``embedding_files`` and ``index_file``
      (both required), ``num_trees`` (int, default 1000) and ``job_dir``
      (default None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--embedding-files', required=True,
                        help='GCS or local paths to embedding files')
    parser.add_argument('--index-file', required=True,
                        help='GCS or local paths to output index file')
    parser.add_argument('--num-trees', type=int, default=1000,
                        help='Number of trees to build in the index')
    # NOTE(review): --job-dir is not read anywhere in this file; presumably it
    # is accepted because the job launcher passes it. TODO confirm.
    parser.add_argument('--job-dir',
                        help='GCS or local paths to job package')
    parsed = parser.parse_args()
    return parsed
def _run_timed(label, func, *args):
    """Call ``func(*args)``, logging start, finish and elapsed seconds under ``label``."""
    time_start = datetime.utcnow()
    logging.info('%s started...', label)
    result = func(*args)
    time_end = datetime.utcnow()
    logging.info('%s finished.', label)
    logging.info('%s elapsed time: %s seconds',
                 label, (time_end - time_start).total_seconds())
    return result


def main():
    """Build the index into LOCAL_INDEX_FILE, then publish it to --index-file."""
    args = get_args()
    _run_timed('Index building', index.build_index,
               args.embedding_files, LOCAL_INDEX_FILE, args.num_trees)
    _run_timed('Uploading index artefacts', upload_artefacts, args.index_file)
if __name__ == '__main__':
    # The root logger defaults to WARNING, which would silently drop every
    # logging.info() progress message above. basicConfig() does nothing if
    # the runtime has already attached handlers to the root logger.
    logging.basicConfig(level=logging.INFO)
    main()