text-semantic-search/semantic_search/utils/search.py (59 lines of code) (raw):
#!/usr/bin/python
#
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import embedding
import matching
import lookup
import os
import logging
import googleapiclient
from httplib2 import Http
from oauth2client.client import GoogleCredentials
# Configurable parameters

# Bucket name handed to download_artefacts() by SearchUtil.
# NOTE(review): left empty here -- presumably has to be filled in before the
# module can download anything; confirm against the deployment setup.
GCS_BUCKET = ''
# Passed to lookup.DatastoreUtil and used as the leading path component of
# GCS_INDEX_LOCATION.
KIND = 'wikipedia'
# Object path of the index inside the bucket; the companion file is expected
# at the same path with a '.mapping' suffix (see download_artefacts).
GCS_INDEX_LOCATION = '{}/index/embeds.index'.format(KIND)
# Local file name for the downloaded index; SearchUtil joins it with the
# directory that contains this module.
INDEX_FILE = 'embeds.index'
# Chunk size in bytes (16 MiB) passed to MediaIoBaseDownload.
CHUNKSIZE = 16 * 1024 * 1024
def _download_from_gcs(gcs_services, bucket_name, gcs_location, local_file_name):
    """Download a single Cloud Storage object to a local file, chunk by chunk.

    Progress messages and the final file size (in GB, rounded to two
    decimals) are printed to stdout.

    Args:
        gcs_services: Storage API client; only ``objects().get_media(...)``
            is called on it.
        bucket_name: Name of the bucket that holds the object.
        gcs_location: Path of the object inside the bucket.
        local_file_name: Destination path. It is opened in ``'wb'`` mode, so
            an existing file at that path is overwritten.
    """
    # Import the submodule explicitly: a bare ``import googleapiclient`` does
    # not load ``googleapiclient.http``, so attribute access on the package
    # only works if some other module happened to import it first.
    from googleapiclient.http import MediaIoBaseDownload

    gcs_uri = 'gs://{}/{}'.format(bucket_name, gcs_location)
    print('Downloading file {} to {}...'.format(gcs_uri, local_file_name))

    with open(local_file_name, 'wb') as file_writer:
        request = gcs_services.objects().get_media(
            bucket=bucket_name, object=gcs_location)
        media = MediaIoBaseDownload(file_writer, request, chunksize=CHUNKSIZE)
        download_complete = False
        while not download_complete:
            # next_chunk() returns (status, done); the status is not used.
            _, download_complete = media.next_chunk()

    # Report the size only once the file has been closed, so buffered data
    # has been flushed to disk before it is measured.
    print('File {} downloaded to {}.'.format(gcs_uri, local_file_name))
    print('File size: {} GB'.format(
        round(os.path.getsize(local_file_name) / float(1024 ** 3), 2)))
def download_artefacts(index_file, bucket_name, gcs_index_location):
    """Download the index file and its ``.mapping`` companion from GCS.

    Builds a Cloud Storage (``storage`` v1) client authorised with the
    application default credentials, then downloads two objects:
    ``gcs_index_location`` to ``index_file`` and
    ``gcs_index_location + '.mapping'`` to ``index_file + '.mapping'``.

    Args:
        index_file: Local destination path for the index file.
        bucket_name: Name of the bucket that holds both objects.
        gcs_index_location: Path of the index object inside the bucket.
    """
    # Import the submodule explicitly: a bare ``import googleapiclient`` does
    # not load ``googleapiclient.discovery``, so ``googleapiclient.discovery``
    # would raise AttributeError unless another module imported it first.
    from googleapiclient import discovery

    http = Http()
    credentials = GoogleCredentials.get_application_default()
    credentials.authorize(http)
    gcs_services = discovery.build('storage', 'v1', http=http)

    # The index and its mapping file share the same base path, remotely and
    # locally; only the suffix differs.
    for suffix in ('', '.mapping'):
        _download_from_gcs(
            gcs_services, bucket_name,
            gcs_index_location + suffix, index_file + suffix)
class SearchUtil:
    """Facade that wires together embedding, matching and datastore lookup.

    Constructing an instance downloads the index artefacts into the
    directory that contains this module and then builds the three helper
    utilities, printing progress messages along the way.
    """

    def __init__(self):
        """Download the index artefacts and initialise the helper utilities."""
        print('Initialising search utility...')

        # The index is stored alongside this source file.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        local_index_path = os.path.join(module_dir, INDEX_FILE)

        print('Downloading index artefacts...')
        download_artefacts(local_index_path, GCS_BUCKET, GCS_INDEX_LOCATION)
        print('Index artefacts downloaded.')

        print('Initialising matching util...')
        self.match_util = matching.MatchingUtil(local_index_path)
        print('Matching util initialised.')

        print('Initialising embedding util...')
        self.embed_util = embedding.EmbedUtil()
        print('Embedding util initialised.')

        print('Initialising datastore util...')
        self.datastore_util = lookup.DatastoreUtil(KIND)
        print('Datastore util is initialised.')

        print('Search utility is up and running.')

    def search(self, query, num_matches=10):
        """Return the stored items most similar to ``query``.

        The query is embedded, the ids of up to ``num_matches`` similar
        items are looked up in the index, and the corresponding items are
        fetched from the datastore utility.

        Args:
            query: Query passed to ``embed_util.extract_embeddings``.
            num_matches: Number of matches requested from the matching
                utility (default 10).

        Returns:
            Whatever ``datastore_util.get_items`` returns for the matched ids.
        """
        embedded_query = self.embed_util.extract_embeddings(query)
        similar_ids = self.match_util.find_similar_items(
            embedded_query, num_matches)
        return self.datastore_util.get_items(similar_ids)