text-semantic-search/semantic_search/utils/search.py (59 lines of code) (raw):

#!/usr/bin/python # # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specif5ic language governing permissions and # limitations under the License. import embedding import matching import lookup import os import logging import googleapiclient from httplib2 import Http from oauth2client.client import GoogleCredentials # Configurable parameters GCS_BUCKET = '' KIND = 'wikipedia' GCS_INDEX_LOCATION = '{}/index/embeds.index'.format(KIND) INDEX_FILE = 'embeds.index' CHUNKSIZE = 16 * 1024 * 1024 def _download_from_gcs(gcs_services, bucket_name, gcs_location, local_file_name): print('Downloading file {} to {}...'.format( 'gs://{}/{}'.format(bucket_name, gcs_location), local_file_name)) with open(local_file_name, 'wb') as file_writer: request = gcs_services.objects().get_media( bucket=bucket_name, object=gcs_location) media = googleapiclient.http.MediaIoBaseDownload( file_writer, request, chunksize=CHUNKSIZE) download_complete = False while not download_complete: progress, download_complete = media.next_chunk() print('File {} downloaded to {}.'.format( 'gs://{}/{}'.format(bucket_name, gcs_location), local_file_name)) print('File size: {} GB'.format( round(os.path.getsize(local_file_name) / float(1024 ** 3), 2))) def download_artefacts(index_file, bucket_name, gcs_index_location): http = Http() credentials = GoogleCredentials.get_application_default() credentials.authorize(http) gcs_services = googleapiclient.discovery.build('storage', 'v1', http=http) _download_from_gcs(gcs_services, bucket_name, gcs_index_location, index_file) _download_from_gcs(gcs_services, bucket_name, gcs_index_location + '.mapping', index_file + '.mapping') class SearchUtil: def __init__(self): print('Initialising search utility...') dir_path = os.path.dirname(os.path.realpath(__file__)) index_file = os.path.join(dir_path, INDEX_FILE) print('Downloading index artefacts...') download_artefacts(index_file, GCS_BUCKET, GCS_INDEX_LOCATION) print('Index artefacts downloaded.') print('Initialising matching util...') self.match_util = matching.MatchingUtil(index_file) print('Matching util initialised.') print('Initialising embedding util...') self.embed_util = embedding.EmbedUtil() print('Embedding util initialised.') print('Initialising datastore util...') self.datastore_util = lookup.DatastoreUtil(KIND) print('Datastore util is initialised.') print('Search utility is up and running.') def search(self, query, num_matches=10): query_embedding = self.embed_util.extract_embeddings(query) item_ids = self.match_util.find_similar_items(query_embedding, num_matches) items = self.datastore_util.get_items(item_ids) return items