example-apps/relevance-workbench/app-api/app.py (176 lines of code) (raw):

import functools
import os

from elasticsearch import Elasticsearch
from flask import Flask, jsonify, request

# Connection settings are mandatory: a missing variable fails fast at import.
CLOUD_ID = os.environ["CLOUD_ID"]
ES_USER = os.environ["ELASTICSEARCH_USERNAME"]
ES_PASSWORD = os.environ["ELASTICSEARCH_PASSWORD"]

# Model id sent with every text expansion clause.
ELSER_MODEL_ID = ".elser_model_1"

# Per-dataset search configuration, keyed by dataset id.
# "elser_search_fields" entries may carry an optional "^<boost>" suffix.
datasets = {
    "movies": {
        "id": "movies",
        "label": "Movies",
        "index": "search-movies",
        "search_fields": ["title", "overview", "keywords"],
        "elser_search_fields": [
            "ml.inference.overview_expanded.predicted_value",
            "ml.inference.title_expanded.predicted_value^0.5",
        ],
        "result_fields": ["title", "overview"],
        "mapping_fields": {"text": "overview", "title": "title"},
    }
}

app = Flask(__name__)


@app.route("/api/search/<index>")
def route_api_search(index):
    """
    Execute the search against `index` and return the transformed hits.

    Query string parameters:
      q       -- the search text (required, must not be blank)
      type    -- "bm25" (default) or "elser"
      rrf     -- "true" to run the hybrid RRF search (only used by "elser")
      k       -- integer, forwarded to the semantic search options
      dataset -- key into `datasets` (default "movies")

    Returns a JSON body `{"response": [...]}` on success, or a JSON error
    with status 400 when a parameter is missing or unsupported.
    """
    query = request.args.get("q")
    rrf = request.args.get(
        "rrf", default=False, type=lambda v: v.lower() == "true"
    )
    search_type = request.args.get("type", default="bm25")
    k = request.args.get("k", default=0, type=int)
    dataset_id = request.args.get("dataset", default="movies")

    # Reject bad input up front instead of failing later with a 500.
    if dataset_id not in datasets:
        return jsonify(error=f"Unknown dataset: {dataset_id}"), 400
    if query is None or query.strip() == "":
        return jsonify(error="Query cannot be empty"), 400

    if search_type == "elser":
        search_result = run_semantic_search(
            query, index, rrf=rrf, k=k, dataset=dataset_id
        )
    elif search_type == "bm25":
        search_result = run_full_text_search(query, index, dataset=dataset_id)
    else:
        return jsonify(error=f"Unsupported search type: {search_type}"), 400

    transformed_search_result = transform_search_response(
        search_result, datasets[dataset_id]["mapping_fields"]
    )
    return jsonify(response=transformed_search_result)


@app.route("/api/datasets", methods=["GET"])
def route_api_datasets():
    """
    Return the available datasets
    """
    return datasets


@app.errorhandler(404)
def resource_not_found(e):
    """
    Return a JSON response of the error and the URL that was requested
    """
    return jsonify(error=str(e)), 404


def _build_text_expansion_clauses(query, field_descriptors):
    """
    Builds one `text_expansion` clause per field descriptor.

    A descriptor is either "field" or "field^boost". The boost applies only
    to the field that declares it; fields without a suffix use a boost of 1.
    """
    clauses = []
    for descriptor in field_descriptors:
        field, _, raw_boost = descriptor.partition("^")
        # Resolve the boost per field so it never leaks to the next one.
        boost = float(raw_boost) if raw_boost else 1
        clauses.append(
            {
                "text_expansion": {
                    field: {
                        "model_text": query,
                        "model_id": ELSER_MODEL_ID,
                        "boost": boost,
                    }
                }
            }
        )
    return clauses


def get_text_expansion_request_body(query, size=10, **options):
    """
    Generates an ES text expansion search request.

    `options["dataset"]` selects the dataset whose "elser_search_fields" are
    queried and whose "result_fields" are returned. Other options are ignored.
    """
    dataset = datasets[options["dataset"]]
    text_expansions = _build_text_expansion_clauses(
        query, dataset["elser_search_fields"]
    )

    return {
        "_source": False,
        "fields": dataset["result_fields"],
        "size": size,
        "query": {"bool": {"should": text_expansions}},
    }


def get_text_search_request_body(query, size=10, **options):
    """
    Generates an ES full text search request.

    `options["dataset"]` selects the dataset whose "search_fields" are matched
    and whose "result_fields" are returned. Other options are ignored.
    """
    dataset = datasets[options["dataset"]]

    return {
        "_source": False,
        "fields": dataset["result_fields"],
        "size": size,
        "query": {
            "multi_match": {"query": query, "fields": dataset["search_fields"]}
        },
    }


def get_hybrid_search_rrf_request_body(query, size=10, **options):
    """
    Generates an ES hybrid search with RRF

    The request carries two sub searches (text expansion and multi_match)
    combined through the `rank.rrf` section. `options["dataset"]` selects the
    dataset configuration. Other options are ignored.
    """
    dataset = datasets[options["dataset"]]
    text_expansions = _build_text_expansion_clauses(
        query, dataset["elser_search_fields"]
    )

    return {
        "_source": False,
        "fields": dataset["result_fields"],
        "size": size,
        "rank": {"rrf": {"window_size": 10, "rank_constant": 2}},
        "sub_searches": [
            {"query": {"bool": {"should": text_expansions}}},
            {
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": dataset["search_fields"],
                    }
                }
            },
        ],
    }


@functools.lru_cache(maxsize=1)
def _get_es_client():
    """
    Returns the shared Elasticsearch client, creating it on first use.

    Reusing one client avoids building a new connection pool per request.
    """
    return Elasticsearch(cloud_id=CLOUD_ID, basic_auth=(ES_USER, ES_PASSWORD))


def execute_search_request(index, body):
    """
    Executes an ES search request and returns the JSON response.

    Only the "query", "fields", "size" and "_source" entries of `body` are
    sent; all four must be present.
    """
    response = _get_es_client().search(
        index=index,
        query=body["query"],
        fields=body["fields"],
        size=body["size"],
        source=body["_source"],
    )

    return response


def execute_search_request_using_raw_dsl(index, body):
    """
    Executes an ES search request by POSTing the raw request body to
    `/<index>/_search` and returns the JSON response.
    """
    response = _get_es_client().perform_request(
        "POST",
        f"/{index}/_search",
        headers={"content-type": "application/json", "accept": "application/json"},
        body=body,
    )

    return response


def run_full_text_search(query, index, **options):
    """
    Runs a full text search on the given index using the passed query.

    Returns the list of hits. Raises ValueError when the query is missing
    or blank.
    """
    if query is None or query.strip() == "":
        raise ValueError("Query cannot be empty")

    body = get_text_search_request_body(query, **options)
    response = execute_search_request(index, body)

    return response["hits"]["hits"]


def run_semantic_search(query, index, **options):
    """
    Runs a semantic (text expansion) search of the provided query on the
    target index and returns the list of hits.

    When `options["rrf"]` is truthy, a hybrid request combining the text
    expansion and full text queries through RRF is sent instead.
    """
    if options.get("rrf"):
        body = get_hybrid_search_rrf_request_body(query, **options)
        # Execute the request using the raw DSL to avoid the ES Python client
        # since sub_searches query are not supported yet
        response_json = execute_search_request_using_raw_dsl(index, body)
    else:
        body = get_text_expansion_request_body(query, **options)
        app.logger.debug("Text expansion request body: %s", body)
        response_json = execute_search_request(index, body)

    return response_json["hits"]["hits"]


def find_id_index(id: int, hits: list) -> int:
    """
    Finds the position of the object in `hits` which has _id == `id`.

    The position is 1-based; 0 is returned when no hit matches.
    """
    for i, v in enumerate(hits):
        if v["_id"] == id:
            return i + 1

    return 0


def transform_search_response(searchResults, mappingFields):
    """
    Rewrites each hit's "fields" in place to the {"text", "title"} shape.

    `mappingFields` maps "text" and "title" to the source field names. A hit
    with no "fields" entry, or lacking a mapped field, gets an empty list for
    that value instead of raising KeyError. Returns the same list object.
    """
    for hit in searchResults:
        fields = hit.get("fields", {})
        hit["fields"] = {
            "text": fields.get(mappingFields["text"], []),
            "title": fields.get(mappingFields["title"], []),
        }

    return searchResults