backend/unified-cloud-search/services/unified_cloud_search_service.py (110 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
import functools
import logging
import random
import subprocess
from typing import Any, Dict, List, Optional
import requests
import tracer_helper
from services.search_service import CodeInfo, Item, SearchResult, SearchService
logger = logging.getLogger(__name__)
tracer = tracer_helper.get_tracer(__name__)
class UnifiedCloudSearchService(SearchService[str]):
    """Search service that queries a Discovery Engine data store over REST.

    Suggestions are sampled from a local words file loaded at construction
    time. Queries are POSTed to the ``servingConfigs/...:search`` endpoint of
    the configured data store, on either the production or the staging host
    depending on ``is_staging``. Subclasses implement
    ``convert_to_search_result`` to map the raw API results to
    ``SearchResult`` objects.
    """

    # Upper bound (seconds) for the search HTTP call so that a stalled
    # connection cannot block the caller indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 60

    project_id: str
    location: str
    datastore_id: str
    # Declared for type checkers only; nothing in this class assigns it.
    deployed_index_id: str
    is_staging: bool

    def __init__(
        self,
        id: str,
        name: str,
        description: str,
        words_file: str,
        project_id: str,
        location: str,
        datastore_id: str,
        is_staging: bool = False,
        code_info: Optional[CodeInfo] = None,
    ) -> None:
        """Store the service metadata and load the suggestion words.

        Args:
            id: Identifier returned by the ``id`` property.
            name: Name returned by the ``name`` property.
            description: Description returned by the ``description`` property.
            words_file: Path to a text file with one suggestion per line.
            project_id: Project segment of the search URL.
            location: Location segment of the search URL.
            datastore_id: Data store segment of the search URL.
            is_staging: If true, requests go to the staging endpoint.
            code_info: Optional info returned by the ``code_info`` property.

        Raises:
            OSError: If ``words_file`` cannot be opened.
        """
        self._id = id
        self._name = name
        self._description = description
        self._code_info = code_info

        # Explicit encoding so the word list reads the same on every platform.
        with open(words_file, "r", encoding="utf-8") as f:
            self.words = [line.strip() for line in f]

        self.project_id = project_id
        self.location = location
        self.datastore_id = datastore_id
        self.is_staging = is_staging

    @property
    def id(self) -> str:
        """Identifier passed to the constructor."""
        return self._id

    @property
    def name(self) -> str:
        """Name for this service that is shown on the frontend."""
        return self._name

    @property
    def description(self) -> str:
        """Description for this service that is shown on the frontend."""
        return self._description

    @property
    def allows_text_input(self) -> bool:
        """If true, this service allows text input. Always False here."""
        return False

    @property
    def code_info(self) -> Optional[CodeInfo]:
        """Info about code used to generate index."""
        return self._code_info

    @abc.abstractmethod
    def convert_to_search_result(
        self, results: List[Dict[str, Any]]
    ) -> List[Optional[SearchResult]]:
        """Convert raw API result dicts into search results.

        Entries returned as ``None`` are dropped by ``search``.
        """
        pass

    @tracer.start_as_current_span("get_all")
    def get_suggestions(self, num_items: int = 60) -> List[Item]:
        """Return a random sample of at most ``num_items`` words as items.

        Each item uses the word as both its id and its text and has no image.
        """
        candidates = [
            Item(id=word, text=word, image=None) for word in self.words
        ]
        sample_size = min(num_items, len(self.words))
        return random.sample(candidates, sample_size)

    @tracer.start_as_current_span("get_by_id")
    def get_by_id(self, id: str) -> Optional[str]:
        """Get an item by id; the id itself is the item, so it is echoed."""
        return id

    @staticmethod
    def _fetch_access_token() -> str:
        """Return a fresh access token printed by the gcloud CLI.

        Raises:
            RuntimeError: If gcloud produced no token (for example because it
                is not installed or no account is authenticated).
        """
        # The command is a fixed literal (no user input is interpolated), so
        # running it through the shell carries no injection risk.
        completed = subprocess.run(
            "gcloud auth print-access-token",
            shell=True,
            capture_output=True,
            text=True,
        )
        token = completed.stdout.strip()
        if not token:
            # Fail here with the real cause instead of sending an empty
            # bearer token and getting an opaque HTTP error later.
            raise RuntimeError(
                "Failed to obtain access token from gcloud "
                f"(exit code {completed.returncode}): "
                f"{completed.stderr.strip()}"
            )
        return token

    def _build_search_url(self) -> str:
        """Return the search endpoint URL for the configured data store."""
        # Staging differs from production in both host and serving config.
        if self.is_staging:
            host = "staging-discoveryengine.sandbox.googleapis.com"
            serving_config = "default_config"
        else:
            host = "discoveryengine.googleapis.com"
            serving_config = "default_search"
        return (
            f"https://{host}/v1alpha/projects/{self.project_id}"
            f"/locations/{self.location}/collections/default_collection"
            f"/dataStores/{self.datastore_id}"
            f"/servingConfigs/{serving_config}:search"
        )

    @tracer.start_as_current_span("match")
    def search(self, query: str, num_neighbors: int) -> List[SearchResult]:
        """Run ``query`` against the data store and return converted matches.

        Args:
            query: Text sent as the ``query`` field of the request.
            num_neighbors: Sent as ``page_size`` of the request.

        Returns:
            The results converted by ``convert_to_search_result`` with
            ``None`` entries removed; empty if the response has no results.

        Raises:
            RuntimeError: If no access token could be obtained or the
                endpoint answered with a status other than 200.
            requests.exceptions.RequestException: On connection failure or
                when the request exceeds the timeout.
        """
        logger.info("index_endpoint.match completed")

        # Retrieve latest access token
        access_token = self._fetch_access_token()
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }
        json_data = {
            "query": query,
            "page_size": num_neighbors,
            "offset": 0,
        }

        response = requests.post(
            self._build_search_url(),
            headers=headers,
            json=json_data,
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            raise RuntimeError(
                "Error retrieving search results "
                f"(HTTP {response.status_code})"
            )

        # A successful response may carry no "results" key (presumably when
        # nothing matches); treat that as an empty result list rather than
        # failing with KeyError.
        raw_results = response.json().get("results", [])
        matches_all = self.convert_to_search_result(results=raw_results)
        logger.info("matches converted")

        matches_all_nonoptional: List[SearchResult] = [
            match for match in matches_all if match is not None
        ]
        logger.info("matches none filtered")
        return matches_all_nonoptional

    @tracer.start_as_current_span("get_total_index_count")
    def get_total_index_count(self) -> int:
        """Return the index size; currently a hard-coded value of 1234."""
        return 1234